Class: Transformers::EmbeddingPipeline

Inherits:
Pipeline (which in turn inherits from Object)
Defined in:
lib/transformers/pipelines/embedding.rb

Instance Method Summary

Methods inherited from Pipeline

#_ensure_tensor_on_device, #call, #check_model_type, #get_iterator, #initialize, #torch_dtype

Constructor Details

This class inherits a constructor from Transformers::Pipeline

Instance Method Details

#_forward(model_inputs) ⇒ Object



11
12
13
14
15
16
# File 'lib/transformers/pipelines/embedding.rb', line 11

# Invokes the model with the preprocessed inputs and packages the result for
# the postprocessing step.
#
# @param model_inputs [#to_hash, #[]] inputs that are splatted into the model
#   as keyword arguments; also indexed with +:attention_mask+
# @return [Hash] +:last_hidden_state+ (element 0 of the model's return value)
#   and +:attention_mask+ (passed through unchanged from +model_inputs+)
def _forward(model_inputs)
  # Only the first element of the model output is kept.
  hidden_state = @model.call(**model_inputs)[0]
  mask = model_inputs[:attention_mask]

  { last_hidden_state: hidden_state, attention_mask: mask }
end

#_sanitize_parameters(**kwargs) ⇒ Object



3
4
5
# File 'lib/transformers/pipelines/embedding.rb', line 3

# Splits caller-supplied options into three parameter hashes.
#
# The first two slots are always empty and every keyword argument lands in
# the third one.
# NOTE(review): the slots presumably map to preprocess / forward / postprocess
# parameters as consumed by Pipeline — confirm against the base class.
#
# @param kwargs [Hash] arbitrary keyword options
# @return [Array(Hash, Hash, Hash)] two empty hashes followed by +kwargs+
def _sanitize_parameters(**kwargs)
  first_params = {}
  second_params = {}

  [first_params, second_params, kwargs]
end

#postprocess(model_outputs, pooling: "mean", normalize: true) ⇒ Object



18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# File 'lib/transformers/pipelines/embedding.rb', line 18

# Pools the model's hidden states, optionally L2-normalizes them, and returns
# the first batch entry converted with +to_a+.
#
# Pooling modes:
# * "none" - hidden states are used as-is
# * "mean" - delegates to +mean_pooling+ (defined elsewhere) together with
#   the attention mask
# * "cls"  - keeps index 0 along the second axis for every row of the first
#   axis
#
# NOTE(review): the "cls" indexing implies the hidden states are laid out as
# (batch, tokens, hidden) — confirm against the model's output.
#
# @param model_outputs [Hash] must provide +:last_hidden_state+;
#   +:attention_mask+ is read only for "mean" pooling
# @param pooling [String] one of "none", "mean", "cls"
# @param normalize [Boolean] when truthy, apply L2 (p = 2) normalization
# @return [Array] first batch entry of the (pooled, normalized) result
# @raise [Error] when +pooling+ is not a supported mode
def postprocess(model_outputs, pooling: "mean", normalize: true)
  hidden_state = model_outputs[:last_hidden_state]

  pooled =
    case pooling
    when "none"
      hidden_state
    when "mean"
      mean_pooling(hidden_state, model_outputs[:attention_mask])
    when "cls"
      hidden_state[0.., 0]
    else
      raise Error, "Pooling method '#{pooling}' not supported."
    end

  # Normalize along the LAST axis so each embedding vector gets unit length.
  # For pooled 2-D results this is the same axis as dim 1; for un-pooled
  # results (pooling "none") dim 1 would be the token axis, which would
  # normalize across tokens instead of within each vector.
  pooled = Torch::NN::Functional.normalize(pooled, p: 2, dim: -1) if normalize

  pooled[0].to_a
end

#preprocess(inputs) ⇒ Object



7
8
9
# File 'lib/transformers/pipelines/embedding.rb', line 7

# Tokenizes the raw inputs for the model.
#
# @param inputs [Object] value handed to the tokenizer as its sole positional
#   argument
# @return [Object] whatever the tokenizer returns when called with
#   +return_tensors:+ set to the pipeline's +@framework+
def preprocess(inputs)
  tensor_format = @framework
  @tokenizer.call(inputs, return_tensors: tensor_format)
end