Class: Transformers::TokenClassificationPipeline

Inherits:
ChunkPipeline
Extended by:
ClassAttribute
Defined in:
lib/transformers/pipelines/token_classification.rb

Instance Method Summary

Methods included from ClassAttribute

class_attribute

Methods inherited from ChunkPipeline

#run_single

Methods inherited from Pipeline

#_ensure_tensor_on_device, #call, #check_model_type, #get_iterator, #torch_dtype

Constructor Details

#initialize(*args, args_parser: TokenClassificationArgumentHandler.new, **kwargs) ⇒ TokenClassificationPipeline

Returns a new instance of TokenClassificationPipeline.



18
19
20
21
22
23
24
# File 'lib/transformers/pipelines/token_classification.rb', line 18

# Builds a token-classification pipeline.
#
# Every positional and keyword argument except +args_parser+ is forwarded
# unchanged to the parent constructor. Afterwards check_model_type is called
# with the token-classification model mapping names.
#
# @param args_parser [Object] stored as the pipeline's argument parser;
#   defaults to a fresh TokenClassificationArgumentHandler
def initialize(*args, args_parser: TokenClassificationArgumentHandler.new, **kwargs)
  super(*args, **kwargs)
  check_model_type(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)

  @args_parser = args_parser
  # Case-preserving basic tokenizer (do_lower_case: false)
  @basic_tokenizer = Bert::BertTokenizer::BasicTokenizer.new(do_lower_case: false)
end

Instance Method Details

#_forward(model_inputs) ⇒ Object



137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# File 'lib/transformers/pipelines/token_classification.rb', line 137

# Runs the model on one chunk of tokenized input.
#
# The bookkeeping keys :special_tokens_mask, :offset_mapping, :sentence and
# :is_last are removed from +model_inputs+ (the hash is mutated), so only the
# remaining entries are passed to the model. The bookkeeping values come back
# in the result, together with the logits and the remaining model inputs.
#
# @param model_inputs [Hash] one chunk as produced by #preprocess
# @return [Hash] :logits, the bookkeeping keys, and the remaining inputs
def _forward(model_inputs)
  passthrough = %i[special_tokens_mask offset_mapping sentence is_last].to_h do |key|
    [key, model_inputs.delete(key)]
  end

  raw_output = @model.(**model_inputs)
  logits =
    if @framework == "tf"
      raw_output[0]
    elsif raw_output.is_a?(Hash)
      raw_output[:logits]
    else
      # Tuple-like output: logits come first
      raw_output[0]
    end

  { logits: logits, **passthrough, **model_inputs }
end

#_sanitize_parameters(ignore_labels: nil, grouped_entities: nil, ignore_subwords: nil, aggregation_strategy: nil, offset_mapping: nil, stride: nil) ⇒ Object



26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# File 'lib/transformers/pipelines/token_classification.rb', line 26

def _sanitize_parameters(
  ignore_labels: nil,
  grouped_entities: nil,
  ignore_subwords: nil,
  aggregation_strategy: nil,
  offset_mapping: nil,
  stride: nil
)
  preprocess_params = {}
  if !offset_mapping.nil?
    preprocess_params[:offset_mapping] = offset_mapping
  end

  postprocess_params = {}
  if !grouped_entities.nil? || !ignore_subwords.nil?
    if grouped_entities && ignore_subwords
      aggregation_strategy = AggregationStrategy::FIRST
    elsif grouped_entities && !ignore_subwords
      aggregation_strategy = AggregationStrategy::SIMPLE
    else
      aggregation_strategy = AggregationStrategy::NONE
    end

    if !grouped_entities.nil?
      warn(
        "`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to" +
        " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
      )
    end
    if !ignore_subwords.nil?
      warn(
        "`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to" +
        " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
      )
    end
  end

  if !aggregation_strategy.nil?
    if aggregation_strategy.is_a?(String)
      aggregation_strategy = AggregationStrategy.new(aggregation_strategy.downcase).to_s
    end
    if (
      [AggregationStrategy::FIRST, AggregationStrategy::MAX, AggregationStrategy::AVERAGE].include?(aggregation_strategy) &&
      !@tokenizer.is_fast
    )
      raise ArgumentError,
        "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option" +
        ' to `"simple"` or use a fast tokenizer.'
    end
    postprocess_params[:aggregation_strategy] = aggregation_strategy
  end
  if !ignore_labels.nil?
    postprocess_params[:ignore_labels] = ignore_labels
  end
  if !stride.nil?
    if stride >= @tokenizer.model_max_length
      raise ArgumentError,
        "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
    end
    if aggregation_strategy == AggregationStrategy::NONE
      raise ArgumentError,
        "`stride` was provided to process all the text but `aggregation_strategy=" +
        "\"#{aggregation_strategy}\"`, please select another one instead."
    else
      if @tokenizer.is_fast
        tokenizer_params = {
          return_overflowing_tokens: true,
          padding: true,
          stride: stride
        }
        preprocess_params[:tokenizer_params] = tokenizer_params
      else
        raise ArgumentError,
          "`stride` was provided to process all the text but you're using a slow tokenizer." +
          " Please use a fast tokenizer."
      end
    end
  end
  [preprocess_params, {}, postprocess_params]
end

#aggregate(pre_entities, aggregation_strategy) ⇒ Object



256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
# File 'lib/transformers/pipelines/token_classification.rb', line 256

# Converts gathered pre-entities into final entities.
#
# With NONE or SIMPLE, each token becomes one entity labelled with its
# highest-scoring class. Any other strategy first merges sub-word tokens into
# words via #aggregate_words. For every strategy except NONE, the result is
# then merged into adjacent-entity groups by #group_entities.
#
# @param pre_entities [Array<Hash>] hashes with :scores, :index, :word, :start, :end
# @param aggregation_strategy one of the AggregationStrategy constants
# @return [Array<Hash>] token-level entities (NONE) or grouped entities
def aggregate(pre_entities, aggregation_strategy)
  entities =
    if [AggregationStrategy::NONE, AggregationStrategy::SIMPLE].include?(aggregation_strategy)
      pre_entities.map do |pre_entity|
        token_scores = pre_entity[:scores]
        best = token_scores.argmax
        {
          entity: @model.config.id2label[best],
          score: token_scores[best],
          index: pre_entity[:index],
          word: pre_entity[:word],
          start: pre_entity[:start],
          end: pre_entity[:end]
        }
      end
    else
      aggregate_words(pre_entities, aggregation_strategy)
    end

  return entities if aggregation_strategy == AggregationStrategy::NONE

  group_entities(entities)
end

#aggregate_word(entities, aggregation_strategy) ⇒ Object

Raises:



282
283
284
# File 'lib/transformers/pipelines/token_classification.rb', line 282

# Collapses the sub-word pre-entities of a single word into one entity.
#
# The label is picked from the tokens' score vectors according to
# +aggregation_strategy+:
# - FIRST:   highest-scoring label of the first token
# - MAX:     highest-scoring label of the token whose best score is highest
# - AVERAGE: highest-scoring label of the element-wise mean of all tokens' scores
#
# @param entities [Array<Hash>] non-empty pre-entities of one word
#   (:word, :scores, :start, :end)
# @param aggregation_strategy AggregationStrategy::FIRST, ::MAX or ::AVERAGE
# @return [Hash] :entity, :score, :word, :start, :end
# @raise [ArgumentError] for any other strategy
def aggregate_word(entities, aggregation_strategy)
  word = @tokenizer.convert_tokens_to_string(entities.map { |e| e[:word] })

  scores =
    if aggregation_strategy == AggregationStrategy::FIRST
      entities.first[:scores]
    elsif aggregation_strategy == AggregationStrategy::MAX
      # max_by keeps the first token on ties
      entities.max_by { |e| e[:scores].max }[:scores]
    elsif aggregation_strategy == AggregationStrategy::AVERAGE
      # Element-wise mean of the per-token score vectors
      entities.map { |e| e[:scores] }.reduce(:+) / entities.length.to_f
    else
      raise ArgumentError, "Invalid aggregation_strategy"
    end

  label_idx = scores.argmax
  {
    entity: @model.config.id2label[label_idx],
    score: scores[label_idx],
    word: word,
    start: entities.first[:start],
    end: entities.last[:end]
  }
end

#aggregate_words(entities, aggregation_strategy) ⇒ Object

Raises:



286
287
288
# File 'lib/transformers/pipelines/token_classification.rb', line 286

# Merges consecutive sub-word pre-entities into whole words.
#
# A new word starts at the first pre-entity and at every pre-entity whose
# :is_subword flag is falsy (nil included). Each word's tokens are collapsed
# by #aggregate_word using +aggregation_strategy+.
#
# @param entities [Array<Hash>] pre-entities carrying :is_subword
# @param aggregation_strategy AggregationStrategy::FIRST, ::MAX or ::AVERAGE
# @return [Array<Hash>] one entity per word; empty for empty input
# @raise [ArgumentError] for NONE or SIMPLE, which do not aggregate words
def aggregate_words(entities, aggregation_strategy)
  if [AggregationStrategy::NONE, AggregationStrategy::SIMPLE].include?(aggregation_strategy)
    raise ArgumentError, "NONE and SIMPLE strategies are invalid for word aggregation"
  end

  entities
    .slice_when { |_previous, current| !current[:is_subword] }
    .map { |word_group| aggregate_word(word_group, aggregation_strategy) }
end

#gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy) ⇒ Object



200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# File 'lib/transformers/pipelines/token_classification.rb', line 200

# Builds one pre-entity hash for every token of a chunk that is not flagged
# in +special_tokens_mask+.
#
# @param sentence [String] input text, sliced with the token offsets
# @param input_ids per-token ids; element idx is converted with #to_i
# @param scores per-token label probabilities, walked row by row with
#   each_over_axis(0) (presumably a 2-D Numo array of tokens x labels — confirm)
# @param offset_mapping per-token [start, end] offsets into +sentence+, or nil
# @param special_tokens_mask per-token flags; tokens with a non-zero flag are skipped
# @param aggregation_strategy not read by this method
# @return [Array<Hash>] hashes with :word, :scores, :start, :end, :index and
#   :is_subword (:start, :end and :is_subword are nil when offset_mapping is nil)
def gather_pre_entities(
  sentence,
  input_ids,
  scores,
  offset_mapping,
  special_tokens_mask,
  aggregation_strategy
)
  pre_entities = []
  scores.each_over_axis(0).with_index do |token_scores, idx|
    # Skip tokens flagged in special_tokens_mask
    if special_tokens_mask[idx] != 0
      next
    end

    word = @tokenizer.convert_ids_to_tokens(input_ids[idx].to_i)
    if !offset_mapping.nil?
      start_ind, end_ind = offset_mapping[idx].to_a
      # Non-Integer offsets are unwrapped with #item under the "pt" framework.
      # NOTE(review): #to_a on a tensor row may already yield Integers, in
      # which case this branch never runs — confirm.
      if !start_ind.is_a?(Integer)
        if @framework == "pt"
          start_ind = start_ind.item
          end_ind = end_ind.item
        end
      end
      # Substring of the input covered by this token (end offset exclusive)
      word_ref = sentence[start_ind...end_ind]
      # NOTE(review): this probes the tokenizer's internal @tokenizer object
      # for continuing_subword_prefix; confirm that this object (and not,
      # e.g., its model) exposes it, otherwise the space heuristic below is
      # always used.
      if @tokenizer.instance_variable_get(:@tokenizer).respond_to?(:continuing_subword_prefix)
        # This is a BPE, word aware tokenizer, there is a correct way
        # to fuse tokens: a continuation token's text differs in length
        # from the substring it covers
        is_subword = word.length != word_ref.length
      else
        # Heuristic: a sub-word does not start the text, and neither the
        # character before it nor its first character is a space
        is_subword = start_ind > 0 && !sentence[(start_ind - 1)...(start_ind + 1)].include?(" ")
      end

      # Unknown tokens report the original text and always start a new word
      if input_ids[idx].to_i == @tokenizer.unk_token_id
        word = word_ref
        is_subword = false
      end
    else
      # Without offsets, position and sub-word status cannot be determined
      start_ind = nil
      end_ind = nil
      is_subword = nil
    end

    pre_entity = {
      word: word,
      scores: token_scores,
      start: start_ind,
      end: end_ind,
      index: idx,
      is_subword: is_subword
    }
    pre_entities << pre_entity
  end
  pre_entities
end

#get_tag(entity_name) ⇒ Object



306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
# File 'lib/transformers/pipelines/token_classification.rb', line 306

# Splits an entity label into its BIO prefix and its bare tag.
#
#   get_tag("B-PER") # => ["B", "PER"]
#   get_tag("I-PER") # => ["I", "PER"]
#   get_tag("PER")   # => ["I", "PER"]  (unprefixed labels count as continuations)
#
# @param entity_name [String] a label such as those in id2label
# @return [Array(String, String)] [bi, tag]
def get_tag(entity_name)
  bio_prefix = %w[B- I-].find { |prefix| entity_name.start_with?(prefix) }
  return ["I", entity_name] unless bio_prefix

  [bio_prefix[0], entity_name[2..]]
end

#group_entities(entities) ⇒ Object



322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
# File 'lib/transformers/pipelines/token_classification.rb', line 322

# Groups consecutive entities that share a tag into single entities.
#
# A new group starts whenever the tag changes or the current entity carries
# an explicit "B" prefix, so two adjacent "B-" entities are never merged.
# Each group is collapsed with #group_sub_entities.
#
# @param entities [Array<Hash>] entities with an :entity label
# @return [Array<Hash>] one grouped entity per run; empty for empty input
def group_entities(entities)
  entities
    .slice_when do |previous, current|
      bi, tag = get_tag(current[:entity])
      _previous_bi, previous_tag = get_tag(previous[:entity])
      tag != previous_tag || bi == "B"
    end
    .map { |run| group_sub_entities(run) }
end

#group_sub_entities(entities) ⇒ Object



290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
# File 'lib/transformers/pipelines/token_classification.rb', line 290

# Collapses a run of entities into a single grouped entity.
#
# The group label is the first entity's label with everything up to and
# including the first "-" removed (the whole label if it has no "-"). The
# score is the arithmetic mean of the member scores. The text is rebuilt from
# the member words with the tokenizer, and the span runs from the first
# member's :start to the last member's :end.
#
# @param entities [Array<Hash>] non-empty run with :entity, :score, :word, :start, :end
# @return [Hash] :entity_group, :score, :word, :start, :end
def group_sub_entities(entities)
  head = entities.first
  tail = entities.last
  member_scores = entities.map { |member| member[:score] }

  {
    entity_group: head[:entity].split("-", 2).last,
    score: member_scores.sum / member_scores.size.to_f,
    word: @tokenizer.convert_tokens_to_string(entities.map { |member| member[:word] }),
    start: head[:start],
    end: tail[:end]
  }
end

#postprocess(all_outputs, aggregation_strategy: AggregationStrategy::NONE, ignore_labels: nil) ⇒ Object



160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# File 'lib/transformers/pipelines/token_classification.rb', line 160

# Converts the raw model outputs of every chunk into entities.
#
# For each chunk, the logits are turned into per-label probabilities with a
# max-shifted softmax over the last axis. Tokens are then gathered into
# pre-entities and aggregated with +aggregation_strategy+. Entities whose
# :entity or :entity_group is in +ignore_labels+ are dropped. When there are
# several chunks, the combined result is passed to
# aggregate_overlapping_entities.
#
# @param all_outputs [Array<Hash>] per-chunk outputs (presumably #_forward results)
# @param aggregation_strategy one of the AggregationStrategy constants
# @param ignore_labels [Array<String>, nil] labels to drop; nil means ["O"]
# @return [Array<Hash>] the surviving entities
# @raise [Todo] for the "tf" framework, which is not implemented
def postprocess(all_outputs, aggregation_strategy: AggregationStrategy::NONE, ignore_labels: nil)
  ignore_labels = ["O"] if ignore_labels.nil?

  entities = all_outputs.flat_map do |chunk|
    token_logits = chunk[:logits][0].numo
    # Every chunk is mapped against the text stored on the first chunk
    sentence = all_outputs[0][:sentence]
    input_ids = chunk[:input_ids][0]
    offset_mapping = chunk[:offset_mapping]
    offset_mapping = offset_mapping[0] unless offset_mapping.nil?
    special_tokens_mask = chunk[:special_tokens_mask][0].numo

    # Numerically stable softmax: subtract the row max before exponentiating
    shifted_exp = Numo::NMath.exp(token_logits - token_logits.max(axis: -1).expand_dims(-1))
    scores = shifted_exp / shifted_exp.sum(axis: -1).expand_dims(-1)

    raise Todo if @framework == "tf"

    pre_entities = gather_pre_entities(
      sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
    )
    aggregate(pre_entities, aggregation_strategy).reject do |entity|
      ignore_labels.include?(entity[:entity]) || ignore_labels.include?(entity[:entity_group])
    end
  end

  # NOTE(review): aggregate_overlapping_entities is not among the methods
  # documented for this class; confirm it is defined before relying on
  # multi-chunk (stride) input.
  entities = aggregate_overlapping_entities(entities) if all_outputs.length > 1
  entities
end

#preprocess(sentence, offset_mapping: nil, **preprocess_params) ⇒ Object



107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# File 'lib/transformers/pipelines/token_classification.rb', line 107

# Tokenizes +sentence+ and yields one model-input hash per chunk.
#
# Each row returned by the tokenizer becomes its own chunk, with a leading
# dimension of 1. Several rows appear, for example, when
# +tokenizer_params+ requests overflowing tokens. Each yielded hash also carries:
# - :offset_mapping: the caller-supplied +offset_mapping+, when given
#   (replacing any offsets produced by the tokenizer)
# - :sentence: the original text on the first chunk, nil on the others
# - :is_last: true only for the final chunk
#
# @param sentence [String] text to tokenize
# @param offset_mapping [Object, nil] explicit offsets attached to every chunk
# @param preprocess_params [Hash] may contain :tokenizer_params, extra keyword
#   arguments forwarded to the tokenizer call
# @yieldparam model_inputs [Hash] inputs for one chunk
# @return [Enumerator] when called without a block
def preprocess(sentence, offset_mapping: nil, **preprocess_params)
  unless block_given?
    return enum_for(__method__, sentence, offset_mapping: offset_mapping, **preprocess_params)
  end

  tokenizer_params = preprocess_params.delete(:tokenizer_params) { {} }
  # Only request truncation when the tokenizer reports a positive max length
  truncation = @tokenizer.model_max_length && @tokenizer.model_max_length > 0
  inputs = @tokenizer.(
    sentence,
    return_tensors: @framework,
    truncation: truncation,
    return_special_tokens_mask: true,
    return_offsets_mapping: @tokenizer.is_fast,
    **tokenizer_params
  )
  # Dropped before slicing so it is not forwarded as a model input
  inputs.delete(:overflow_to_sample_mapping)
  num_chunks = inputs[:input_ids].length

  num_chunks.times do |i|
    if @framework == "tf"
      raise Todo
    else
      # Take row i of every entry and restore a leading dimension of 1
      model_inputs = inputs.to_h { |k, v| [k, v[i].unsqueeze(0)] }
    end
    # The offset_mapping argument (not an instance variable) decides whether
    # a caller-supplied mapping replaces the tokenizer's offsets.
    model_inputs[:offset_mapping] = offset_mapping unless offset_mapping.nil?
    model_inputs[:sentence] = i == 0 ? sentence : nil
    model_inputs[:is_last] = (i == num_chunks - 1)

    yield model_inputs
  end
end