Class: Informers::PreTrainedTokenizer

Inherits:
Object
Defined in:
lib/informers/tokenizers.rb

Instance Attribute Summary

#mask_token ⇒ Object (readonly)
#mask_token_id ⇒ Object (readonly)
#sep_token_id ⇒ Object (readonly)

Instance Method Summary

#batch_decode(batch, **decode_args) ⇒ Object
#call(text, text_pair: nil, add_special_tokens: true, padding: false, truncation: nil, max_length: nil, return_tensor: true, return_token_type_ids: true, return_offsets: false) ⇒ Object
#convert_tokens_to_ids(tokens) ⇒ Object
#convert_tokens_to_string(tokens) ⇒ Object
#decode(tokens, skip_special_tokens:) ⇒ Object
#get_token(*keys) ⇒ Object
#id_to_token(id) ⇒ Object
#initialize(tokenizer_json, tokenizer_config) ⇒ PreTrainedTokenizer constructor
#padding_side=(side) ⇒ Object

Constructor Details

#initialize(tokenizer_json, tokenizer_config) ⇒ PreTrainedTokenizer

Returns a new instance of PreTrainedTokenizer.



# File 'lib/informers/tokenizers.rb', line 5

def initialize(tokenizer_json, tokenizer_config)
  super()

  @tokenizer_config = tokenizer_config

  @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)

  # Collect added tokens from the tokenizer and track which ones are special
  @special_tokens = []
  @all_special_ids = []

  @added_tokens = []
  @tokenizer.added_tokens_decoder.each do |id, token|
    @added_tokens << token

    if token.special
      @special_tokens << token.content
      @all_special_ids << id
    end
  end

  # Update additional_special_tokens
  @additional_special_tokens = tokenizer_config["additional_special_tokens"] || []
  @special_tokens.concat(@additional_special_tokens)

  @mask_token = get_token("mask_token")
  @mask_token_id = @tokenizer.token_to_id(@mask_token) if @mask_token

  @sep_token = get_token("sep_token")
  @sep_token_id = @tokenizer.token_to_id(@sep_token) if @sep_token

  @model_max_length = tokenizer_config["model_max_length"]

  # Cap the near-infinite sentinel value some configs store (e.g. donut-base-finetuned-docvqa)
  if @model_max_length && @model_max_length > (1 << 63)
    @model_max_length = 1 << 63
  end
end
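
A minimal construction sketch; the paths below are hypothetical and point at a downloaded tokenizer.json plus its parsed tokenizer_config.json. Note that tokenizer_json is a file path (handed to Tokenizers::Tokenizer.from_file), while tokenizer_config is an already-parsed Hash.

require "informers"
require "json"

# Hypothetical paths to files fetched from a model repo
config = JSON.parse(File.read("path/to/tokenizer_config.json"))
tokenizer = Informers::PreTrainedTokenizer.new("path/to/tokenizer.json", config)
tokenizer.mask_token # => "[MASK]" for BERT-style configs, nil otherwise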

Instance Attribute Details

#mask_tokenObject (readonly)

Returns the mask token string (e.g. "[MASK]") resolved from the tokenizer config, or nil if none is defined.



# File 'lib/informers/tokenizers.rb', line 3

def mask_token
  @mask_token
end

#mask_token_idObject (readonly)

Returns the vocabulary id of the mask token, or nil when no mask token is defined.



# File 'lib/informers/tokenizers.rb', line 3

def mask_token_id
  @mask_token_id
end

#sep_token_idObject (readonly)

Returns the vocabulary id of the separator token, or nil when no separator token is defined.



# File 'lib/informers/tokenizers.rb', line 3

def sep_token_id
  @sep_token_id
end
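
These readers expose values resolved in the constructor. A quick sketch (ids are illustrative and vocabulary-specific):

tokenizer.mask_token    # => "[MASK]" or nil
tokenizer.mask_token_id # => 103 (illustrative)
tokenizer.sep_token_id  # => 102 (illustrative)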

Instance Method Details

#batch_decode(batch, **decode_args) ⇒ Object



# File 'lib/informers/tokenizers.rb', line 137

def batch_decode(batch, **decode_args)
  @tokenizer.decode_batch(batch, **decode_args)
end
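
Delegates to Tokenizers::Tokenizer#decode_batch, passing decode_args straight through. A sketch, reusing a tokenizer built as above and assuming the underlying gem accepts skip_special_tokens::

ids = tokenizer.(["Hello world", "Goodbye world"])[:input_ids]
tokenizer.batch_decode(ids, skip_special_tokens: true)
# => ["hello world", "goodbye world"] (exact strings depend on the vocabulary and normalizer)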

#call(text, text_pair: nil, add_special_tokens: true, padding: false, truncation: nil, max_length: nil, return_tensor: true, return_token_type_ids: true, return_offsets: false) ⇒ Object



# File 'lib/informers/tokenizers.rb', line 65

def call(
  text,
  text_pair: nil,
  add_special_tokens: true,
  padding: false,
  truncation: nil,
  max_length: nil,
  return_tensor: true,
  return_token_type_ids: true, # TODO change default
  return_offsets: false
)
  is_batched = text.is_a?(Array)

  if is_batched
    if text.length == 0
      raise Error, "text array must be non-empty"
    end

    if !text_pair.nil?
      if !text_pair.is_a?(Array)
        raise Error, "text_pair must also be an array"
      elsif text.length != text_pair.length
        raise Error, "text and text_pair must have the same length"
      end
    end
  end

  if padding
    @tokenizer.enable_padding
  else
    @tokenizer.no_padding
  end

  if truncation
    @tokenizer.enable_truncation(max_length || @model_max_length)
  else
    @tokenizer.no_truncation
  end

  if is_batched
    input = text_pair ? text.zip(text_pair) : text
    encoded = @tokenizer.encode_batch(input, add_special_tokens: add_special_tokens)
  else
    encoded = [@tokenizer.encode(text, text_pair, add_special_tokens: add_special_tokens)]
  end

  result = {input_ids: encoded.map(&:ids), attention_mask: encoded.map(&:attention_mask)}
  if return_token_type_ids
    result[:token_type_ids] = encoded.map(&:type_ids)
  end
  if return_offsets
    result[:offsets] = encoded.map(&:offsets)
  end
  result
end
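
This is the main encoding entry point: it accepts a single string, an array of strings, or sentence pairs, and toggles the underlying tokenizer's padding and truncation on each invocation. A usage sketch (all ids are illustrative):

enc = tokenizer.("Hello world")
enc[:input_ids]      # => [[101, 7592, 2088, 102]] (illustrative)
enc[:attention_mask] # => [[1, 1, 1, 1]]

# Batched sentence pairs, padded and truncated
tokenizer.(
  ["first premise", "second premise"],
  text_pair: ["first hypothesis", "second hypothesis"],
  padding: true,
  truncation: true,
  max_length: 128
)

Results are always arrays of arrays, even for a single string, because non-batched input is wrapped in a one-element batch before encoding.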

#convert_tokens_to_ids(tokens) ⇒ Object



# File 'lib/informers/tokenizers.rb', line 129

def convert_tokens_to_ids(tokens)
  tokens.map { |t| @tokenizer.token_to_id(t) }
end
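
Maps token strings to ids one at a time via Tokenizers::Tokenizer#token_to_id. A sketch (ids are vocabulary-specific; tokens absent from the vocabulary come back as nil):

tokenizer.convert_tokens_to_ids(["hello", "world"])
# => [7592, 2088] (illustrative)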

#convert_tokens_to_string(tokens) ⇒ Object



# File 'lib/informers/tokenizers.rb', line 125

def convert_tokens_to_string(tokens)
  @tokenizer.decoder.decode(tokens)
end
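
Joins pre-tokenized pieces back into text using the tokenizer's configured decoder, so the joining rules are model-specific (WordPiece, BPE, and so on). A WordPiece-flavored sketch:

tokenizer.convert_tokens_to_string(["play", "##ing"])
# => "playing" (illustrative; depends on the configured decoder)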

#decode(tokens, skip_special_tokens:) ⇒ Object



# File 'lib/informers/tokenizers.rb', line 121

def decode(tokens, skip_special_tokens:)
  @tokenizer.decode(tokens, skip_special_tokens: skip_special_tokens)
end
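
Decodes a single sequence of ids back into a string; note that skip_special_tokens is a required keyword in this wrapper. A sketch with illustrative ids:

tokenizer.decode([101, 7592, 2088, 102], skip_special_tokens: true)
# => "hello world" (illustrative)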

#get_token(*keys) ⇒ Object



# File 'lib/informers/tokenizers.rb', line 44

def get_token(*keys)
  keys.each do |key|
    item = @tokenizer_config[key]
    if !item
      next
    end

    if item.is_a?(Hash)
      if item["__type"] == "AddedToken"
        return item["content"]
      else
        raise Error, "Unknown token: #{item}"
      end
    else
      return item
    end
  end

  nil
end
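
Looks each key up in the tokenizer config and returns the first hit; values may be plain strings or serialized AddedToken hashes, whose "content" field is unwrapped. A sketch against a hypothetical config:

# Assuming tokenizer_config contains:
#   {"mask_token" => {"__type" => "AddedToken", "content" => "[MASK]"}, "sep_token" => "[SEP]"}
tokenizer.get_token("mask_token")             # => "[MASK]"
tokenizer.get_token("sep_token")              # => "[SEP]"
tokenizer.get_token("cls_token", "bos_token") # => nil when neither key is present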

#id_to_token(id) ⇒ Object



# File 'lib/informers/tokenizers.rb', line 133

def id_to_token(id)
  @tokenizer.id_to_token(id)
end
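
The inverse of convert_tokens_to_ids, for a single id:

tokenizer.id_to_token(101) # => "[CLS]" (illustrative; nil for ids outside the vocabulary)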

#padding_side=(side) ⇒ Object



# File 'lib/informers/tokenizers.rb', line 141

def padding_side=(side)
  @tokenizer.enable_padding(direction: side)
end
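
Re-enables padding with the given direction via Tokenizers::Tokenizer#enable_padding. A sketch:

tokenizer.padding_side = "left" # pad on the left, as decoder-style generation usually expects

Note from the source above that #call re-enables or disables padding on every invocation, so a direction set here can be reset by a subsequent call; set it again if you toggle padding.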