Class: Transformers::DebertaV2::DisentangledSelfAttention
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::DebertaV2::DisentangledSelfAttention
- Defined in: lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
Instance Method Summary
- #disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) ⇒ Object
- #forward(hidden_states, attention_mask, output_attentions: false, query_states: nil, relative_pos: nil, rel_embeddings: nil) ⇒ Object
- #initialize(config) ⇒ DisentangledSelfAttention (constructor): A new instance of DisentangledSelfAttention.
- #transpose_for_scores(x, attention_heads) ⇒ Object
Constructor Details
#initialize(config) ⇒ DisentangledSelfAttention
Returns a new instance of DisentangledSelfAttention.
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 462

def initialize(config)
  super()
  if config.hidden_size % config.num_attention_heads != 0
    raise ArgumentError, "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention heads (#{config.num_attention_heads})"
  end
  @num_attention_heads = config.num_attention_heads
  _attention_head_size = config.hidden_size / config.num_attention_heads
  @attention_head_size = config.getattr("attention_head_size", _attention_head_size)
  @all_head_size = @num_attention_heads * @attention_head_size
  @query_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
  @key_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
  @value_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)

  @share_att_key = config.getattr("share_att_key", false)
  @pos_att_type = !config.pos_att_type.nil? ? config.pos_att_type : []
  @relative_attention = config.getattr("relative_attention", false)

  if @relative_attention
    @position_buckets = config.getattr("position_buckets", -1)
    @max_relative_positions = config.getattr("max_relative_positions", -1)
    if @max_relative_positions < 1
      @max_relative_positions = config.max_position_embeddings
    end
    @pos_ebd_size = @max_relative_positions
    if @position_buckets > 0
      @pos_ebd_size = @position_buckets
    end

    @pos_dropout = StableDropout.new(config.hidden_dropout_prob)

    if !@share_att_key
      if @pos_att_type.include?("c2p")
        @pos_key_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
      end
      if @pos_att_type.include?("p2c")
        @pos_query_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
      end
    end
  end

  @dropout = StableDropout.new(config.attention_probs_dropout_prob)
end
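For orientation, a minimal construction sketch. The config class name (Transformers::DebertaV2::DebertaV2Config) and the option values below are assumptions based on the gem's module layout and the Hugging Face DeBERTa-v2 defaults; verify them against your installed version.

# Gemfile: gem "transformers-rb"  (pulls in the torch-rb gem used by Torch::NN)

# Hypothetical configuration; class name and option names are assumed.
config = Transformers::DebertaV2::DebertaV2Config.new(
  hidden_size: 768,
  num_attention_heads: 12,
  relative_attention: true,
  position_buckets: 256,
  pos_att_type: ["c2p"]
)

attn = Transformers::DebertaV2::DisentangledSelfAttention.new(config)

With relative_attention enabled, the constructor also builds the position dropout layer and, because share_att_key defaults to false, a separate @pos_key_proj for the "c2p" term.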
Instance Method Details
#disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 562

def disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
  if relative_pos.nil?
    q = query_layer.size(-2)
    relative_pos = DebertaV2.build_relative_position(q, key_layer.size(-2), bucket_size: @position_buckets, max_position: @max_relative_positions, device: query_layer.device)
  end
  if relative_pos.dim == 2
    relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
  elsif relative_pos.dim == 3
    relative_pos = relative_pos.unsqueeze(1)
  elsif relative_pos.dim != 4
    raise ArgumentError, "Relative position ids must be of dim 2 or 3 or 4. #{relative_pos.dim}"
  end

  att_span = @pos_ebd_size
  relative_pos = relative_pos.long.to(query_layer.device)

  rel_embeddings = rel_embeddings[0...att_span * 2, 0..].unsqueeze(0)
  if @share_att_key
    pos_query_layer = transpose_for_scores(@query_proj.(rel_embeddings), @num_attention_heads).repeat(query_layer.size(0) / @num_attention_heads, 1, 1)
    pos_key_layer = transpose_for_scores(@key_proj.(rel_embeddings), @num_attention_heads).repeat(query_layer.size(0) / @num_attention_heads, 1, 1)
  elsif @pos_att_type.include?("c2p")
    pos_key_layer = transpose_for_scores(@pos_key_proj.(rel_embeddings), @num_attention_heads).repeat(query_layer.size(0) / @num_attention_heads, 1, 1)
  end

  score = 0
  # content->position
  if @pos_att_type.include?("c2p")
    scale = Torch.sqrt(Torch.tensor(pos_key_layer.size(-1), dtype: Torch.float) * scale_factor)
    c2p_att = Torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
    c2p_pos = Torch.clamp(relative_pos + att_span, 0, (att_span * 2) - 1)
    c2p_att = Torch.gather(c2p_att, dim: -1, index: c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]))
    score += c2p_att / scale.to(dtype: c2p_att.dtype)
  end

  # position->content
  if @pos_att_type.include?("p2c")
    scale = Torch.sqrt(Torch.tensor(pos_query_layer.size(-1), dtype: Torch.float) * scale_factor)
    if key_layer.size(-2) != query_layer.size(-2)
      r_pos = DebertaV2.build_relative_position(key_layer.size(-2), key_layer.size(-2), bucket_size: @position_buckets, max_position: @max_relative_positions, device: query_layer.device)
      r_pos = r_pos.unsqueeze(0)
    else
      r_pos = relative_pos
    end

    p2c_pos = Torch.clamp(-r_pos + att_span, 0, (att_span * 2) - 1)
    p2c_att = Torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
    p2c_att = Torch.gather(p2c_att, dim: -1, index: p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)])).transpose(-1, -2)
    score += p2c_att / scale.to(dtype: p2c_att.dtype)
  end

  score
end
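A shape-level sketch of calling this method directly, continuing from the construction sketch above (relative_attention on, pos_att_type ["c2p"], so scale_factor is 2: one content-to-content term plus one content-to-position term). The batch/sequence sizes are hypothetical, and rel_embeddings is a random stand-in for the encoder's relative-embedding table, which has 2 * pos_ebd_size rows.

batch, seq = 2, 8
query_layer = attn.transpose_for_scores(Torch.randn(batch, seq, config.hidden_size), config.num_attention_heads)
key_layer = attn.transpose_for_scores(Torch.randn(batch, seq, config.hidden_size), config.num_attention_heads)

# Stand-in for the encoder's relative-embedding weights: (2 * pos_ebd_size, hidden_size).
rel_embeddings = Torch.randn(2 * 256, config.hidden_size)

# Passing nil for relative_pos lets the method build bucketed relative positions itself.
bias = attn.disentangled_attention_bias(query_layer, key_layer, nil, rel_embeddings, 2)
bias.shape  # => [batch * heads, seq, seq]; added to the content-content scores in #forward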
#forward(hidden_states, attention_mask, output_attentions: false, query_states: nil, relative_pos: nil, rel_embeddings: nil) ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 511

def forward(
  hidden_states,
  attention_mask,
  output_attentions: false,
  query_states: nil,
  relative_pos: nil,
  rel_embeddings: nil
)
  if query_states.nil?
    query_states = hidden_states
  end
  query_layer = transpose_for_scores(@query_proj.(query_states), @num_attention_heads)
  key_layer = transpose_for_scores(@key_proj.(hidden_states), @num_attention_heads)
  value_layer = transpose_for_scores(@value_proj.(hidden_states), @num_attention_heads)

  rel_att = nil
  # Take the dot product between "query" and "key" to get the raw attention scores.
  scale_factor = 1
  if @pos_att_type.include?("c2p")
    scale_factor += 1
  end
  if @pos_att_type.include?("p2c")
    scale_factor += 1
  end
  scale = Torch.sqrt(Torch.tensor(query_layer.size(-1), dtype: Torch.float) * scale_factor)
  attention_scores = Torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype: query_layer.dtype))
  if @relative_attention
    rel_embeddings = @pos_dropout.(rel_embeddings)
    rel_att = disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
  end

  if !rel_att.nil?
    attention_scores = attention_scores + rel_att
  end
  attention_scores = attention_scores
  attention_scores = attention_scores.view(-1, @num_attention_heads, attention_scores.size(-2), attention_scores.size(-1))

  # bsz x height x length x dimension
  attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
  attention_probs = @dropout.(attention_probs)
  context_layer = Torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer)
  context_layer = context_layer.view(-1, @num_attention_heads, context_layer.size(-2), context_layer.size(-1)).permute(0, 2, 1, 3).contiguous
  new_context_layer_shape = context_layer.size[...-2] + [-1]
  context_layer = context_layer.view(new_context_layer_shape)
  if output_attentions
    [context_layer, attention_probs]
  else
    context_layer
  end
end
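A usage sketch, again continuing from the construction sketch. The shapes and the all-ones mask are hypothetical; inside the model, DebertaV2Encoder supplies an extended 4D attention mask and the rel_embeddings tensor from its relative-embedding layer, which a random stand-in replaces here.

batch, seq = 2, 8
hidden_states = Torch.randn(batch, seq, config.hidden_size)
# All-ones 4D mask (batch, 1, seq, seq): every position attends to every position.
attention_mask = Torch.ones(batch, 1, seq, seq)
rel_embeddings = Torch.randn(2 * 256, config.hidden_size)  # stand-in relative embeddings

context = attn.forward(hidden_states, attention_mask, rel_embeddings: rel_embeddings)
context.shape  # => [batch, seq, hidden_size]

# With output_attentions: true, the method also returns the attention probabilities.
context, probs = attn.forward(hidden_states, attention_mask, output_attentions: true, rel_embeddings: rel_embeddings)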
#transpose_for_scores(x, attention_heads) ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 505

def transpose_for_scores(x, attention_heads)
  new_x_shape = x.size[...-1] + [attention_heads, -1]
  x = x.view(new_x_shape)
  x.permute(0, 2, 1, 3).contiguous.view(-1, x.size(1), x.size(-1))
end
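The method reshapes (batch, seq, all_head_size) activations into the flattened (batch * heads, seq, head_size) layout that the Torch.bmm calls in #forward expect. A shape walk-through with hypothetical sizes, reusing the attn instance from the construction sketch:

x = Torch.randn(2, 8, 768)              # (batch, seq, all_head_size)
out = attn.transpose_for_scores(x, 12)
out.shape  # => [24, 8, 64], i.e. (batch * attention_heads, seq, attention_head_size)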