Class: Transformers::Bert::BertSelfAttention

Inherits:
  Torch::NN::Module
    • Object
Defined in:
lib/transformers/models/bert/modeling_bert.rb

Instance Method Summary

  #forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
  #initialize(config, position_embedding_type: nil) ⇒ BertSelfAttention constructor
  #transpose_for_scores(x) ⇒ Object

Constructor Details

#initialize(config, position_embedding_type: nil) ⇒ BertSelfAttention

Returns a new instance of BertSelfAttention.



# File 'lib/transformers/models/bert/modeling_bert.rb', line 80

def initialize(config, position_embedding_type: nil)
  super()
  if config.hidden_size % config.num_attention_heads != 0 && !config.embedding_size
    raise ArgumentError,
      "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention " +
      "heads (#{config.num_attention_heads})"
  end

  @num_attention_heads = config.num_attention_heads
  @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
  @all_head_size = @num_attention_heads * @attention_head_size

  @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
  @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
  @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size)

  @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
  @position_embedding_type = position_embedding_type || config.position_embedding_type || "absolute"
  if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
    @max_position_embeddings = config.max_position_embeddings
    @distance_embedding = Torch::NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
  end

  @is_decoder = config.is_decoder
end
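
A minimal instantiation sketch (not part of the gem itself): the constructor only reads the config attributes used above, so an OpenStruct stand-in is enough for illustration. The values below are illustrative, not taken from a real checkpoint, and the gem and torch-rb are assumed to be loaded already.

require "ostruct"

# Illustrative config double: only the attributes read by #initialize are set.
config = OpenStruct.new(
  hidden_size: 768,                      # must be divisible by num_attention_heads
  num_attention_heads: 12,
  embedding_size: nil,                   # nil, so the divisibility check applies
  attention_probs_dropout_prob: 0.1,
  position_embedding_type: "absolute",
  max_position_embeddings: 512,
  is_decoder: false
)

attn = Transformers::Bert::BertSelfAttention.new(config)
# query/key/value are 768x768 linear layers, split into 12 heads of size 64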

Instance Method Details

#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object



# File 'lib/transformers/models/bert/modeling_bert.rb', line 112

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  mixed_query_layer = @query.(hidden_states)

  # If this is instantiated as a cross-attention module, the keys
  # and values come from an encoder; the attention mask needs to be
  # such that the encoder's padding tokens are not attended to.
  is_cross_attention = !encoder_hidden_states.nil?

  if is_cross_attention && !past_key_value.nil?
    # reuse cached cross-attention key/value states
    key_layer = past_key_value[0]
    value_layer = past_key_value[1]
    attention_mask = encoder_attention_mask
  elsif is_cross_attention
    key_layer = transpose_for_scores(@key.(encoder_hidden_states))
    value_layer = transpose_for_scores(@value.(encoder_hidden_states))
    attention_mask = encoder_attention_mask
  elsif !past_key_value.nil?
    key_layer = transpose_for_scores(@key.(hidden_states))
    value_layer = transpose_for_scores(@value.(hidden_states))
    key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
    value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
  else
    key_layer = transpose_for_scores(@key.(hidden_states))
    value_layer = transpose_for_scores(@value.(hidden_states))
  end

  query_layer = transpose_for_scores(mixed_query_layer)

  _use_cache = !past_key_value.nil?
  if @is_decoder
    # For cross-attention, save [key_layer, value_layer] of all cross-attention
    # key/value states; later calls to the cross-attention layer can then reuse
    # them (first branch above).
    # For uni-directional self-attention (decoder), save [key_layer, value_layer]
    # of all previous decoder key/value states; later calls can concatenate the
    # previous states with the current projected ones (third branch above).
    # For encoder bi-directional self-attention, `past_key_value` is always nil.
    past_key_value = [key_layer, value_layer]
  end

  # Take the dot product between "query" and "key" to get the raw attention scores.
  attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))

  if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
    raise Todo
  end

  attention_scores = attention_scores / Math.sqrt(@attention_head_size)
  if !attention_mask.nil?
    # Apply the attention mask (precomputed for all layers in BertModel's forward method)
    attention_scores = attention_scores + attention_mask
  end

  # Normalize the attention scores to probabilities.
  attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = @dropout.(attention_probs)

  # Mask heads if we want to
  if !head_mask.nil?
    attention_probs = attention_probs * head_mask
  end

  context_layer = Torch.matmul(attention_probs, value_layer)

  context_layer = context_layer.permute(0, 2, 1, 3).contiguous
  new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
  context_layer = context_layer.view(new_context_layer_shape)

  outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]

  if @is_decoder
    outputs = outputs + [past_key_value]
  end
  outputs
end
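
Given an instance like the one sketched after #initialize, a rough usage sketch (shapes are illustrative; only methods shown on this page are used):

hidden_states = Torch.randn(2, 8, 768)   # [batch, seq_len, hidden_size]

outputs = attn.forward(hidden_states, output_attentions: true)
context_layer   = outputs[0]             # [2, 8, 768] (same shape as the input)
attention_probs = outputs[1]             # [2, 12, 8, 8] (batch, heads, seq, seq)

# When the config has is_decoder true, the returned array additionally ends with
# [key_layer, value_layer]; passing that back via past_key_value: on the next call
# lets the layer reuse cached keys/values instead of recomputing them.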

#transpose_for_scores(x) ⇒ Object



# File 'lib/transformers/models/bert/modeling_bert.rb', line 106

def transpose_for_scores(x)
  new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
  x = x.view(new_x_shape)
  x.permute(0, 2, 1, 3)
end