Class: Transformers::XlmRoberta::XLMRobertaSelfAttention

Inherits:
Torch::NN::Module
  • Object
Defined in:
lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb

Direct Known Subclasses

XLMRobertaSdpaSelfAttention

Instance Method Summary

Constructor Details

#initialize(config, position_embedding_type: nil) ⇒ XLMRobertaSelfAttention

Returns a new instance of XLMRobertaSelfAttention.



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 101

def initialize(config, position_embedding_type: nil)
  super()
  if config.hidden_size % config.num_attention_heads != 0 && !config.hasattr("embedding_size")
    raise ArgumentError, "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention heads (#{config.num_attention_heads})"
  end

  @num_attention_heads = config.num_attention_heads
  @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
  @all_head_size = @num_attention_heads * @attention_head_size

  @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
  @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
  @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size)

  @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
  @position_embedding_type = position_embedding_type || config.getattr("position_embedding_type", "absolute")
  if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
    @max_position_embeddings = config.max_position_embeddings
    @distance_embedding = Torch::NN::Embedding.new((2 * config.max_position_embeddings) - 1, @attention_head_size)
  end

  @is_decoder = config.is_decoder
end
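
Example (a minimal sketch, not part of the generated listing above): the snippet below builds a small configuration and instantiates the module. It assumes the transformers-rb gem is loaded and that XLMRobertaConfig accepts the standard BERT-style keyword arguments shown; adjust names if your version differs.

require "transformers-rb"  # require path assumed; load the gem however your project does

# Hypothetical toy configuration: 64 hidden units split across 4 attention heads.
config = Transformers::XlmRoberta::XLMRobertaConfig.new(
  hidden_size: 64,
  num_attention_heads: 4,
  attention_probs_dropout_prob: 0.1
)

attention = Transformers::XlmRoberta::XLMRobertaSelfAttention.new(config)

# To use relative position embeddings instead of the default "absolute":
# Transformers::XlmRoberta::XLMRobertaSelfAttention.new(config, position_embedding_type: "relative_key")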

Instance Method Details

#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 131

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  mixed_query_layer = @query.(hidden_states)

  # If this is instantiated as a cross-attention module, the keys
  # and values come from an encoder; the attention mask needs to be
  # such that the encoder's padding tokens are not attended to.
  is_cross_attention = !encoder_hidden_states.nil?

  if is_cross_attention && !past_key_value.nil?
    # reuse k,v, cross_attentions
    key_layer = past_key_value[0]
    value_layer = past_key_value[1]
    attention_mask = encoder_attention_mask
  elsif is_cross_attention
    key_layer = transpose_for_scores(@key.(encoder_hidden_states))
    value_layer = transpose_for_scores(@value.(encoder_hidden_states))
    attention_mask = encoder_attention_mask
  elsif !past_key_value.nil?
    key_layer = transpose_for_scores(@key.(hidden_states))
    value_layer = transpose_for_scores(@value.(hidden_states))
    key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
    value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
  else
    key_layer = transpose_for_scores(@key.(hidden_states))
    value_layer = transpose_for_scores(@value.(hidden_states))
  end

  query_layer = transpose_for_scores(mixed_query_layer)

  use_cache = !past_key_value.nil?
  if @is_decoder
    # If this is a cross-attention layer, cache all cross-attention key/value
    # states so further calls to the layer can reuse them (first branch above).
    # If this is uni-directional (decoder) self-attention, cache all previous
    # decoder key/value states so further calls can concatenate them with the
    # current projected key/value states (third branch above).
    # For encoder bi-directional self-attention, `past_key_value` is always nil.
    past_key_value = [key_layer, value_layer]
  end

  # Take the dot product between "query" and "key" to get the raw attention scores.
  attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))

  if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
    query_length, key_length = [query_layer.shape[2], key_layer.shape[2]]
    if use_cache
      position_ids_l = Torch.tensor(key_length - 1, dtype: Torch.long, device: hidden_states.device).view(-1, 1)
    else
      position_ids_l = Torch.arange(query_length, dtype: Torch.long, device: hidden_states.device).view(-1, 1)
    end
    position_ids_r = Torch.arange(key_length, dtype: Torch.long, device: hidden_states.device).view(1, -1)
    distance = position_ids_l - position_ids_r

    positional_embedding = @distance_embedding.((distance + @max_position_embeddings) - 1)
    positional_embedding = positional_embedding.to(dtype: query_layer.dtype)

    if @position_embedding_type == "relative_key"
      relative_position_scores = Torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
      attention_scores = attention_scores + relative_position_scores
    elsif @position_embedding_type == "relative_key_query"
      relative_position_scores_query = Torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
      relative_position_scores_key = Torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
      attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
    end
  end

  attention_scores = attention_scores / Math.sqrt(@attention_head_size)
  if !attention_mask.nil?
    # Apply the attention mask (precomputed for all layers in XLMRobertaModel's forward method)
    attention_scores = attention_scores + attention_mask
  end

  # Normalize the attention scores to probabilities.
  attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = @dropout.(attention_probs)

  # Mask heads if we want to
  if !head_mask.nil?
    attention_probs = attention_probs * head_mask
  end

  context_layer = Torch.matmul(attention_probs, value_layer)

  context_layer = context_layer.permute(0, 2, 1, 3).contiguous
  new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
  context_layer = context_layer.view(new_context_layer_shape)

  outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]

  if @is_decoder
    outputs = outputs + [past_key_value]
  end
  outputs
end
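
Example (continuing the sketch from the constructor section, with the same 64-dimensional, 4-head module): hidden_states is [batch, seq_len, hidden_size] and the attention mask is additive, broadcastable to [batch, num_heads, seq_len, seq_len], with 0.0 for visible positions and a large negative value for masked ones. The shape comments are assumptions based on the code above.

attention.eval  # disable dropout so the sketch is deterministic

hidden_states = Torch.randn(2, 8, 64)     # [batch, seq_len, hidden_size]
attention_mask = Torch.zeros(2, 1, 1, 8)  # nothing masked

outputs = attention.forward(
  hidden_states,
  attention_mask: attention_mask,
  output_attentions: true
)
context_layer, attention_probs = outputs
context_layer.shape    # => [2, 8, 64]  (batch, seq_len, all_head_size)
attention_probs.shape  # => [2, 4, 8, 8] (batch, num_heads, seq_len, seq_len)

Since is_decoder is false in this toy configuration, the output has two elements when output_attentions is true; for a decoder configuration the cached [key_layer, value_layer] pair is appended as a third element.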

#transpose_for_scores(x) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 125

def transpose_for_scores(x)
  new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
  x = x.view(new_x_shape)
  x.permute(0, 2, 1, 3)
end