# File 'lib/torchtext/nn/multihead_attention_container.rb', line 13
def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
  # Convert batch-first inputs to the (seq_len, batch, embed_dim) layout used internally.
  if @batch_first
    query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
  end

  tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)

  # Apply the input projections to obtain the attention inputs.
  q, k, v = @in_proj_container.call(query, key, value)

  # Split each projection into @nhead heads by folding the head dimension into the batch dimension.
  unless q.size(-1) % @nhead == 0
    raise "query's embed_dim must be divisible by the number of heads"
  end
  head_dim = q.size(-1).div(@nhead)
  q = q.reshape(tgt_len, bsz * @nhead, head_dim)

  unless k.size(-1) % @nhead == 0
    raise "key's embed_dim must be divisible by the number of heads"
  end
  head_dim = k.size(-1).div(@nhead)
  k = k.reshape(src_len, bsz * @nhead, head_dim)

  unless v.size(-1) % @nhead == 0
    raise "value's embed_dim must be divisible by the number of heads"
  end
  head_dim = v.size(-1).div(@nhead)
  v = v.reshape(src_len, bsz * @nhead, head_dim)

  # Run the attention layer, then merge the heads back into a single embedding dimension.
  attn_output, attn_output_weights = @attention_layer.call(q, k, v, attn_mask: attn_mask, bias_k: bias_k, bias_v: bias_v)
  attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
  attn_output = @out_proj.call(attn_output)

  # Restore batch-first layout if requested.
  if @batch_first
    attn_output = attn_output.transpose(-3, -2)
  end

  [attn_output, attn_output_weights]
end
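
Below is a minimal usage sketch. It assumes the constructor mirrors torchtext's Python API, taking (nhead, in_proj_container, attention_layer, out_proj), and that InProjContainer and ScaledDotProduct are available under TorchText::NN; check the gem's own docs for the exact signatures.

require "torchtext"

embed_dim = 10
nhead = 5

# Separate linear projections for query, key, and value (assumed argument order).
in_proj_container = TorchText::NN::InProjContainer.new(
  Torch::NN::Linear.new(embed_dim, embed_dim),
  Torch::NN::Linear.new(embed_dim, embed_dim),
  Torch::NN::Linear.new(embed_dim, embed_dim)
)

mha = TorchText::NN::MultiheadAttentionContainer.new(
  nhead,
  in_proj_container,
  TorchText::NN::ScaledDotProduct.new,
  Torch::NN::Linear.new(embed_dim, embed_dim)
)

# Sequence-first layout: (seq_len, batch, embed_dim).
query = Torch.rand(21, 64, embed_dim)
key = Torch.rand(16, 64, embed_dim)
value = Torch.rand(16, 64, embed_dim)

attn_output, attn_weights = mha.call(query, key, value)
attn_output.shape # => [21, 64, 10]

With batch_first enabled at construction, the same call would instead accept and return (batch, seq_len, embed_dim) tensors, as handled by the transposes at the start and end of forward.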