Class: Transformers::EmbeddingPipeline
- Inherits: Pipeline
- Ancestor chain: Object → Pipeline → Transformers::EmbeddingPipeline
- Defined in:
- lib/transformers/pipelines/embedding.rb
Instance Method Summary
collapse
Methods inherited from Pipeline
#_ensure_tensor_on_device, #call, #check_model_type, #get_iterator, #initialize, #torch_dtype
Instance Method Details
#_forward(model_inputs) ⇒ Object
11
12
13
14
15
16
|
# File 'lib/transformers/pipelines/embedding.rb', line 11
# Run the model on the tokenized inputs and pair the resulting hidden
# states with the attention mask so #postprocess can pool correctly.
def _forward(model_inputs)
  hidden_states = @model.(**model_inputs)[0]
  mask = model_inputs[:attention_mask]
  { last_hidden_state: hidden_states, attention_mask: mask }
end
|
#_sanitize_parameters(**kwargs) ⇒ Object
3
4
5
|
# File 'lib/transformers/pipelines/embedding.rb', line 3
# No options are consumed at preprocess or forward time; every keyword
# argument (pooling:, normalize:) is routed to #postprocess unchanged.
def _sanitize_parameters(**kwargs)
  preprocess_params = {}
  forward_params = {}
  [preprocess_params, forward_params, kwargs]
end
|
#postprocess(model_outputs, pooling: "mean", normalize: true) ⇒ Object
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
# File 'lib/transformers/pipelines/embedding.rb', line 18
# Reduce the model's hidden states to a single embedding vector.
#
# pooling:   "mean" (mask-weighted average via mean_pooling), "cls"
#            (first token), or "none" (hidden states untouched).
# normalize: when true, L2-normalizes along dim 1 before extraction.
# Returns the first batch item's embedding as a plain Array.
# Raises Error for an unrecognized pooling strategy.
def postprocess(model_outputs, pooling: "mean", normalize: true)
  pooled =
    case pooling
    when "none" then model_outputs[:last_hidden_state]
    when "mean" then mean_pooling(model_outputs[:last_hidden_state], model_outputs[:attention_mask])
    when "cls"  then model_outputs[:last_hidden_state][0.., 0]
    else
      raise Error, "Pooling method '#{pooling}' not supported."
    end
  pooled = Torch::NN::Functional.normalize(pooled, p: 2, dim: 1) if normalize
  pooled[0].to_a
end
|
#preprocess(inputs) ⇒ Object
7
8
9
|
# File 'lib/transformers/pipelines/embedding.rb', line 7
# Tokenize the raw input, requesting tensors in the pipeline's
# configured framework format (@framework, e.g. "pt").
def preprocess(inputs)
  tokenizer_options = { return_tensors: @framework }
  @tokenizer.(inputs, **tokenizer_options)
end
|