Class: Transformers::Distilbert::TransformerBlock

Inherits:
  Torch::NN::Module < Object
Defined in:
lib/transformers/models/distilbert/modeling_distilbert.rb

Instance Method Summary

  #forward(x:, attn_mask: nil, head_mask: nil, output_attentions: false) ⇒ Object
  #initialize(config) ⇒ TransformerBlock constructor

Constructor Details

#initialize(config) ⇒ TransformerBlock

Returns a new instance of TransformerBlock.



# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 183

def initialize(config)
  super()

  # The number of attention heads must evenly divide the hidden dimension
  if config.dim % config.n_heads != 0
    raise ArgumentError, "config.n_heads #{config.n_heads} must divide config.dim #{config.dim} evenly"
  end

  @attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation].new(config)
  @sa_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)

  @ffn = FFN.new(config)
  @output_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
end
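
A minimal construction sketch, not part of the generated documentation. The config class name shown here (Transformers::DistilBertConfig) and its no-argument defaults (dim: 768, n_heads: 12, _attn_implementation: "eager") are assumptions mirroring the Python DistilBertConfig; any object exposing the attributes read by this constructor and by its sub-modules (FFN and the attention class) would serve the same purpose.

# Hypothetical usage sketch; class name, require name, and defaults are assumptions.
require "transformers-rb"  # gem require name assumed

config = Transformers::DistilBertConfig.new  # assumed defaults: dim: 768, n_heads: 12
block = Transformers::Distilbert::TransformerBlock.new(config)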

Instance Method Details

#forward(x:, attn_mask: nil, head_mask: nil, output_attentions: false) ⇒ Object



# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 198

def forward(
  x:,
  attn_mask: nil,
  head_mask: nil,
  output_attentions: false
)
  # Self-Attention
  sa_output =
    @attention.(
      query: x,
      key: x,
      value: x,
      mask: attn_mask,
      head_mask: head_mask,
      output_attentions: output_attentions,
    )
  if output_attentions
    sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
  else  # without output_attentions, the attention layer still returns its output wrapped in an array
    if !sa_output.is_a?(Array)
      raise TypeError, "sa_output must be an array but it is #{sa_output.class.name} type"
    end

    sa_output = sa_output[0]
  end
  sa_output = @sa_layer_norm.(sa_output + x)  # (bs, seq_length, dim)

  # Feed Forward Network
  ffn_output = @ffn.(sa_output)  # (bs, seq_length, dim)
  ffn_output = @output_layer_norm.(ffn_output + sa_output)  # (bs, seq_length, dim)

  output = [ffn_output]
  if output_attentions
    output = [sa_weights] + output
  end
  output
end
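
A hypothetical call sketch for #forward, reusing block from the constructor sketch above. Tensor shapes follow the inline comments; the all-ones attn_mask (no padding) and the dim value of 768 are assumptions, since in normal use the parent model supplies both.

# Hypothetical usage sketch; shapes and values are assumptions.
x = Torch.randn(2, 16, 768)   # (bs, seq_length, dim); dim assumed 768
mask = Torch.ones(2, 16)      # (bs, seq_length); 1 = attend, 0 = padding

hidden = block.forward(x: x, attn_mask: mask)[0]  # (bs, seq_length, dim)

weights, hidden = block.forward(x: x, attn_mask: mask, output_attentions: true)
# weights: (bs, n_heads, seq_length, seq_length)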