Class: Transformers::Distilbert::FFN

Inherits:
Torch::NN::Module
  • Object
show all
Defined in:
lib/transformers/models/distilbert/modeling_distilbert.rb

Instance Method Summary collapse

Constructor Details

#initialize(config) ⇒ FFN

Returns a new instance of FFN.



154
155
156
157
158
159
160
161
162
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 154

# Sets up the feed-forward sub-layer from a model configuration.
#
# Reads +dropout+, +chunk_size_feed_forward+, +dim+, +hidden_dim+ and
# +activation+ from +config+.
#
# @param config [Object] configuration object exposing the readers listed above
def initialize(config)
  super()

  drop_prob = config.dropout
  model_dim = config.dim
  inner_dim = config.hidden_dim
  activation_name = config.activation

  # NOTE(review): instance variables are assigned in the original order on
  # purpose; the parent module may enumerate sub-modules by assignment
  # order -- confirm before reordering.
  @dropout = Torch::NN::Dropout.new(p: drop_prob)
  @chunk_size_feed_forward = config.chunk_size_feed_forward
  # Dimension index handed to the chunking helper in #forward.
  @seq_len_dim = 1
  # Two linear layers: dim -> hidden_dim, then hidden_dim -> dim.
  @lin1 = Torch::NN::Linear.new(model_dim, inner_dim)
  @lin2 = Torch::NN::Linear.new(inner_dim, model_dim)
  @activation = Activations.get_activation(activation_name)
end

Instance Method Details

#ff_chunk(input) ⇒ Object



168
169
170
171
172
173
174
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 168

# Runs +input+ through the feed-forward stack.
#
# The callables are applied strictly in this order: +@lin1+, +@activation+,
# +@lin2+, +@dropout+; each one receives the previous one's result.
#
# @param input [Object] value accepted by +@lin1+
# @return [Object] whatever +@dropout+ returns for the final stage
def ff_chunk(input)
  stages = [@lin1, @activation, @lin2, @dropout]
  stages.reduce(input) { |acc, stage| stage.call(acc) }
end

#forward(input) ⇒ Object



164
165
166
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 164

# Entry point for the layer: hands the work to
# +TorchUtils.apply_chunking_to_forward+.
#
# The helper receives #ff_chunk as a Method object, followed by
# +@chunk_size_feed_forward+, +@seq_len_dim+ and +input+, and its return
# value is returned unchanged.
#
# @param input [Object] value forwarded untouched to the chunking helper
# @return [Object] result of +TorchUtils.apply_chunking_to_forward+
def forward(input)
  chunk_fn = method(:ff_chunk)
  TorchUtils.apply_chunking_to_forward(
    chunk_fn,
    @chunk_size_feed_forward,
    @seq_len_dim,
    input
  )
end