Class: Informers::PreTrainedTokenizer
- Inherits: Object
- Class hierarchy: Object → Informers::PreTrainedTokenizer
- Defined in:
- lib/informers/tokenizers.rb
Direct Known Subclasses
BartTokenizer, BertTokenizer, CLIPTokenizer, DebertaV2Tokenizer, DistilBertTokenizer, GPT2Tokenizer, M2M100Tokenizer, MPNetTokenizer, NllbTokenizer, RobertaTokenizer, SpeechT5Tokenizer, T5Tokenizer, XLMRobertaTokenizer
Instance Attribute Summary collapse
-
#mask_token ⇒ Object
readonly
Returns the value of attribute mask_token.
-
#mask_token_id ⇒ Object
readonly
Returns the value of attribute mask_token_id.
-
#sep_token_id ⇒ Object
readonly
Returns the value of attribute sep_token_id.
Instance Method Summary collapse
- #batch_decode(batch, **decode_args) ⇒ Object
- #call(text, text_pair: nil, add_special_tokens: true, padding: false, truncation: nil, max_length: nil, return_tensor: true, return_token_type_ids: true, return_offsets: false) ⇒ Object
- #convert_tokens_to_ids(tokens) ⇒ Object
- #convert_tokens_to_string(tokens) ⇒ Object
- #decode(tokens, skip_special_tokens:) ⇒ Object
- #get_token(*keys) ⇒ Object
- #id_to_token(id) ⇒ Object
-
#initialize(tokenizer_json, tokenizer_config) ⇒ PreTrainedTokenizer
constructor
A new instance of PreTrainedTokenizer.
- #padding_side=(side) ⇒ Object
Constructor Details
#initialize(tokenizer_json, tokenizer_config) ⇒ PreTrainedTokenizer
Returns a new instance of PreTrainedTokenizer.
# File 'lib/informers/tokenizers.rb', line 5

# Builds a tokenizer from a serialized tokenizer.json file and its
# parsed tokenizer_config.json hash.
#
# @param tokenizer_json [String] path to the tokenizer.json file
# @param tokenizer_config [Hash] parsed contents of tokenizer_config.json
def initialize(tokenizer_json, tokenizer_config)
  super()

  @tokenizer_config = tokenizer_config

  # Delegate the heavy lifting to the tokenizers gem.
  @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)

  # Add added_tokens to model
  @special_tokens = []
  @all_special_ids = []

  @added_tokens = []
  @tokenizer.added_tokens_decoder.each do |id, token|
    @added_tokens << token

    # Only tokens flagged as special contribute to the special-token lists.
    if token.special
      @special_tokens << token.content
      @all_special_ids << id
    end
  end

  # Update additional_special_tokens
  @additional_special_tokens = tokenizer_config["additional_special_tokens"] || []
  @special_tokens.concat(@additional_special_tokens)

  # Token ids are resolved only when the corresponding token string exists.
  @mask_token = get_token("mask_token")
  @mask_token_id = @tokenizer.token_to_id(@mask_token) if @mask_token
  @sep_token = get_token("sep_token")
  @sep_token_id = @tokenizer.token_to_id(@sep_token) if @sep_token

  @model_max_length = tokenizer_config["model_max_length"]
  # for donut-base-finetuned-docvqa
  # NOTE(review): clamps absurdly large max-length sentinels from the config;
  # presumably guards against "effectively infinite" values — confirm upstream.
  if @model_max_length && @model_max_length > (1 << 63)
    @model_max_length = 1 << 63
  end
end
Instance Attribute Details
#mask_token ⇒ Object (readonly)
Returns the value of attribute mask_token.
# File 'lib/informers/tokenizers.rb', line 3

# @return [String, nil] the mask token string (e.g. "[MASK]"), or nil when
#   the tokenizer config does not define one
def mask_token
  @mask_token
end
#mask_token_id ⇒ Object (readonly)
Returns the value of attribute mask_token_id.
# File 'lib/informers/tokenizers.rb', line 3

# @return [Integer, nil] vocabulary id of the mask token; set only when a
#   mask token is configured
def mask_token_id
  @mask_token_id
end
#sep_token_id ⇒ Object (readonly)
Returns the value of attribute sep_token_id.
# File 'lib/informers/tokenizers.rb', line 3

# @return [Integer, nil] vocabulary id of the separator token; set only when
#   a sep token is configured
def sep_token_id
  @sep_token_id
end
Instance Method Details
#batch_decode(batch, **decode_args) ⇒ Object
# File 'lib/informers/tokenizers.rb', line 137

# Decodes several token-id sequences in a single call.
#
# @param batch [Array<Array<Integer>>] one array of token ids per sequence
# @param decode_args keyword options forwarded to the underlying tokenizer's
#   decode_batch
# @return [Array<String>] decoded text, one string per input sequence
def batch_decode(batch, **decode_args)
  @tokenizer.decode_batch(batch, **decode_args)
end
#call(text, text_pair: nil, add_special_tokens: true, padding: false, truncation: nil, max_length: nil, return_tensor: true, return_token_type_ids: true, return_offsets: false) ⇒ Object
# File 'lib/informers/tokenizers.rb', line 65

# Tokenizes +text+ (a String or an Array of Strings) and returns a hash of
# encoded outputs.
#
# @param text [String, Array<String>] input text, or a batch of inputs
# @param text_pair [String, Array<String>, nil] optional second sequence(s);
#   when batched it must be an array of the same length as +text+
# @param add_special_tokens [Boolean] whether the model's special tokens
#   (CLS/SEP etc.) are inserted by the underlying tokenizer
# @param padding [Boolean] pad sequences in a batch to a common length
# @param truncation [Boolean, nil] truncate to +max_length+ (or the model max)
# @param max_length [Integer, nil] truncation length override
# @param return_token_type_ids [Boolean] include :token_type_ids in the result
#   (TODO change default)
# @param return_offsets [Boolean] include character :offsets in the result
# @return [Hash] :input_ids, :attention_mask, and optionally
#   :token_type_ids / :offsets — each an array with one entry per sequence
# @raise [Error] on an empty batch or mismatched text/text_pair batches
def call(
  text,
  text_pair: nil,
  add_special_tokens: true,
  padding: false,
  truncation: nil,
  max_length: nil,
  return_tensor: true,
  return_token_type_ids: true, # TODO change default
  return_offsets: false
)
  batched = text.is_a?(Array)

  # Validate batch inputs before touching the tokenizer.
  if batched
    raise Error, "text array must be non-empty" if text.length == 0

    unless text_pair.nil?
      raise Error, "text_pair must also be an array" unless text_pair.is_a?(Array)
      raise Error, "text and text_pair must have the same length" if text.length != text_pair.length
    end
  end

  # Configure padding/truncation on the shared tokenizer for this call.
  padding ? @tokenizer.enable_padding : @tokenizer.no_padding

  if truncation
    @tokenizer.enable_truncation(max_length || @model_max_length)
  else
    @tokenizer.no_truncation
  end

  encoded =
    if batched
      pairs = text_pair ? text.zip(text_pair) : text
      @tokenizer.encode_batch(pairs, add_special_tokens: add_special_tokens)
    else
      # Single input: wrap in an array so extraction below is uniform.
      [@tokenizer.encode(text, text_pair, add_special_tokens: add_special_tokens)]
    end

  output = {
    input_ids: encoded.map(&:ids),
    attention_mask: encoded.map(&:attention_mask)
  }
  output[:token_type_ids] = encoded.map(&:type_ids) if return_token_type_ids
  output[:offsets] = encoded.map(&:offsets) if return_offsets
  output
end
#convert_tokens_to_ids(tokens) ⇒ Object
# File 'lib/informers/tokenizers.rb', line 129

# Looks up the vocabulary id for each token string.
#
# @param tokens [Array<String>] token strings
# @return [Array<Integer, nil>] one id per token (nil for unknown tokens,
#   per the underlying tokenizer's token_to_id)
def convert_tokens_to_ids(tokens)
  tokens.map do |token|
    @tokenizer.token_to_id(token)
  end
end
#convert_tokens_to_string(tokens) ⇒ Object
# File 'lib/informers/tokenizers.rb', line 125

# Joins token strings back into text via the tokenizer's configured decoder.
# NOTE(review): exact joining behavior (e.g. subword-marker handling) depends
# on which decoder the tokenizer.json configures — not visible here.
#
# @param tokens [Array<String>] token strings
# @return [String] the reassembled text
def convert_tokens_to_string(tokens)
  @tokenizer.decoder.decode(tokens)
end
#decode(tokens, skip_special_tokens:) ⇒ Object
# File 'lib/informers/tokenizers.rb', line 121

# Converts a sequence of token ids back into a string.
#
# @param tokens [Array<Integer>] token ids to decode
# @param skip_special_tokens [Boolean] whether special tokens are dropped
#   from the output (required keyword)
# @return [String] the decoded text
def decode(tokens, skip_special_tokens:)
  @tokenizer.decode(tokens, skip_special_tokens: skip_special_tokens)
end
#get_token(*keys) ⇒ Object
# File 'lib/informers/tokenizers.rb', line 44

# Fetches the first configured token value among +keys+ from the tokenizer
# config. A config entry may be a plain string or a serialized AddedToken
# hash, in which case its "content" string is returned.
#
# @param keys [Array<String>] config keys to try, in order
# @return [String, nil] the token string, or nil when no key is configured
# @raise [Error] when an entry is a hash that is not an AddedToken
def get_token(*keys)
  keys.each do |key|
    item = @tokenizer_config[key]
    next unless item

    # Plain string entries are returned as-is.
    return item unless item.is_a?(Hash)

    raise Error, "Unknown token: #{item}" unless item["__type"] == "AddedToken"

    return item["content"]
  end
  nil
end
#id_to_token(id) ⇒ Object
# File 'lib/informers/tokenizers.rb', line 133

# Looks up the token string for a vocabulary id.
#
# @param id [Integer] vocabulary id
# @return [String, nil] the token string, per the underlying tokenizer's
#   id_to_token
def id_to_token(id)
  @tokenizer.id_to_token(id)
end
#padding_side=(side) ⇒ Object
# File 'lib/informers/tokenizers.rb', line 141

# Sets which side sequences are padded on.
# NOTE(review): this calls enable_padding, so assigning the side also turns
# padding on for the underlying tokenizer — confirm callers expect that.
#
# @param side [Symbol, String] padding direction passed to enable_padding
def padding_side=(side)
  @tokenizer.enable_padding(direction: side)
end