5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
|
# File 'lib/torchtext/data/metrics.rb', line 5
# Computes the corpus-level BLEU score between a candidate corpus and a
# references corpus using clipped n-gram precision and a brevity penalty.
#
# @param candidate_corpus [Array<Array>] one tokenized candidate translation
#   per entry (each entry is an array of tokens)
# @param references_corpus [Array<Array<Array>>] for each candidate, an array
#   of one or more tokenized reference translations
# @param max_n [Integer] maximum n-gram order to include (default: 4)
# @param weights [Array<Float>] weight per n-gram order for the weighted
#   geometric mean; must contain exactly max_n entries (default: uniform BLEU-4)
# @return [Float] the BLEU score; 0.0 when any n-gram order has no clipped matches
# @raise [RuntimeError] if weights.length != max_n or the corpora differ in length
def bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
  unless max_n == weights.length
    raise "Length of the \"weights\" list has to be equal to max_n"
  end
  unless candidate_corpus.length == references_corpus.length
    raise "The length of candidate and reference corpus should be the same"
  end

  # Per-order accumulators: clipped (reference-capped) and total candidate n-gram counts.
  clipped_counts = Torch.zeros(max_n)
  total_counts = Torch.zeros(max_n)
  weights = Torch.tensor(weights)

  candidate_len = 0.0
  refs_len = 0.0
  candidate_corpus.zip(references_corpus) do |candidate, refs|
    candidate_len += candidate.length
    # For the brevity penalty, use the reference length closest to the candidate's.
    refs_len_list = refs.map { |ref| ref.length.to_f }
    refs_len += refs_len_list.min_by { |x| (candidate.length - x).abs }

    # Merge reference counters, keeping the maximum count per n-gram across references.
    reference_counters = compute_ngram_counter(refs[0], max_n)
    refs[1..-1].each do |ref|
      reference_counters = reference_counters.merge(compute_ngram_counter(ref, max_n)) { |_, v1, v2| v1 > v2 ? v1 : v2 }
    end

    candidate_counter = compute_ngram_counter(candidate, max_n)

    # Clip each candidate n-gram count at the maximum reference count (modified precision).
    shared_keys = candidate_counter.keys & reference_counters.keys
    clipped_counter = candidate_counter.slice(*shared_keys).merge(reference_counters.slice(*shared_keys)) { |_, v1, v2| v1 < v2 ? v1 : v2 }

    clipped_counter.each_key do |ngram|
      clipped_counts[ngram.length - 1] += clipped_counter[ngram]
    end
    candidate_counter.each_key do |ngram|
      total_counts[ngram.length - 1] += candidate_counter[ngram]
    end
  end

  if clipped_counts.to_a.min == 0
    # Any order with zero matches drives the geometric mean to zero.
    0.0
  else
    # Weighted geometric mean of per-order precisions, computed in log space.
    pn = clipped_counts / total_counts
    log_pn = weights * Torch.log(pn)
    score = Torch.exp(log_pn.sum)
    # Brevity penalty: <= 1, penalizes candidates shorter than their references.
    bp = Math.exp([1 - refs_len / candidate_len, 0].min)
    bp * score.item
  end
end
|