Class: Transformers::Distilbert::FFN
- Inherits: Torch::NN::Module
- Ancestors (most general first): Object → Torch::NN::Module → Transformers::Distilbert::FFN
- Defined in: lib/transformers/models/distilbert/modeling_distilbert.rb
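Instance Method Summary
- #ff_chunk(input) ⇒ Object
  Applies lin1, the activation, lin2 and dropout to one chunk of the input.
- #forward(input) ⇒ Object
  Runs #ff_chunk over the input via TorchUtils.apply_chunking_to_forward.
- #initialize(config) ⇒ FFN (constructor)
  Returns a new instance of FFN.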
Constructor Details
#initialize(config) ⇒ FFN
Returns a new instance of FFN.
Source (lines 154–162):
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 154
# Builds the DistilBERT feed-forward block: two linear layers with an
# activation in between, followed by dropout.
#
# Reads config.dropout, config.chunk_size_feed_forward, config.dim,
# config.hidden_dim and config.activation (in that order).
#
# @param config [Object] presumably a DistilBERT configuration object — TODO confirm
def initialize(config)
  # Empty parens so that config is NOT forwarded to Torch::NN::Module#initialize.
  super()
  @dropout = Torch::NN::Dropout.new(p: config.dropout)
  @chunk_size_feed_forward = config.chunk_size_feed_forward
  # #forward chunks the input along this dimension.
  @seq_len_dim = 1

  # lin1 maps config.dim -> config.hidden_dim; lin2 maps it back.
  model_dim = config.dim
  hidden_dim = config.hidden_dim
  @lin1 = Torch::NN::Linear.new(model_dim, hidden_dim)
  @lin2 = Torch::NN::Linear.new(hidden_dim, model_dim)

  @activation = Activations.get_activation(config.activation)
end
Instance Method Details
#ff_chunk(input) ⇒ Object
Source (lines 168–174):
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 168
# Applies the feed-forward layers to a single chunk of input, in order:
# lin1, then the activation, then lin2, then dropout.
#
# @param input [Object] presumably a Torch::Tensor chunk of shape (batch, chunk_len, dim) — TODO confirm
# @return [Object] the output of the final (dropout) layer
def ff_chunk(input)
  [@lin1, @activation, @lin2, @dropout].reduce(input) { |acc, layer| layer.call(acc) }
end
#forward(input) ⇒ Object
Source (lines 164–166):
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 164
# Feed-forward pass. Delegates to TorchUtils.apply_chunking_to_forward,
# passing #ff_chunk, the configured chunk size and the chunking dimension
# (@seq_len_dim, set to 1 in the constructor).
# NOTE(review): presumably a chunk size of 0 disables chunking, as in the
# Python transformers helper — confirm against TorchUtils.
#
# @param input [Object] presumably a Torch::Tensor — TODO confirm
# @return [Object] whatever apply_chunking_to_forward returns
def forward(input)
  chunk_fn = method(:ff_chunk)
  TorchUtils.apply_chunking_to_forward(
    chunk_fn,
    @chunk_size_feed_forward,
    @seq_len_dim,
    input
  )
end