Class: Transformers::DebertaV2::DebertaV2Attention

Inherits:
Torch::NN::Module
  • Object
show all
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Method Summary collapse

Constructor Details

#initialize(config) ⇒ DebertaV2Attention

Returns a new instance of DebertaV2Attention.



165
166
167
168
169
170
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 165

# Builds the attention wrapper from a model configuration.
#
# Creates one DisentangledSelfAttention and one DebertaV2SelfOutput, both
# constructed from the same +config+, and keeps +config+ on the instance.
#
# @param config configuration object handed unchanged to both sub-modules
def initialize(config)
  super()
  @config = config
  # Keep @self ahead of @output: the parent Module may depend on the order
  # in which sub-modules are assigned — TODO confirm.
  @self = DisentangledSelfAttention.new(config)
  @output = DebertaV2SelfOutput.new(config)
end

Instance Method Details

#forward(hidden_states, attention_mask, output_attentions: false, query_states: nil, relative_pos: nil, rel_embeddings: nil) ⇒ Object



172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 172

# Runs the self-attention sub-module, then feeds its result to the output
# sub-module.
#
# @param hidden_states first positional input of the self-attention sub-module
# @param attention_mask second positional input, forwarded unchanged
# @param output_attentions when truthy, the self-attention result is unpacked
#   into two values and the second is returned next to the final output
# @param query_states forwarded to the self-attention sub-module; when nil,
#   +hidden_states+ is used as the second argument of the output sub-module
# @param relative_pos forwarded unchanged to the self-attention sub-module
# @param rel_embeddings forwarded unchanged to the self-attention sub-module
# @return the output sub-module's result, or the two-element Array
#   +[result, att_matrix]+ when +output_attentions+ is truthy
def forward(
  hidden_states,
  attention_mask,
  output_attentions: false,
  query_states: nil,
  relative_pos: nil,
  rel_embeddings: nil
)
  attended = @self.call(
    hidden_states,
    attention_mask,
    output_attentions:,
    query_states:,
    relative_pos:,
    rel_embeddings:
  )

  # Only split the result when attentions were requested; otherwise the
  # value is passed on untouched.
  context, att_matrix = output_attentions ? attended : [attended, nil]

  # An explicit nil check (not truthiness) decides the fallback.
  second_input = query_states.nil? ? hidden_states : query_states
  attention_output = @output.call(context, second_input)

  output_attentions ? [attention_output, att_matrix] : attention_output
end