Class: Transformers::RerankingPipeline
- Inherits: Pipeline
- Inheritance chain: Object → Pipeline → Transformers::RerankingPipeline
- Defined in:
- lib/transformers/pipelines/reranking.rb
Instance Method Summary
collapse
Methods inherited from Pipeline
#_ensure_tensor_on_device, #check_model_type, #get_iterator, #initialize, #torch_dtype
Instance Method Details
#_forward(model_inputs) ⇒ Object
16
17
18
19
|
# File 'lib/transformers/pipelines/reranking.rb', line 16
# Invokes the model with +model_inputs+ expanded as keyword arguments.
#
# @param model_inputs [Hash] symbol-keyed inputs (splatted with **)
# @return [Object] whatever the model call returns, unmodified
def _forward(model_inputs)
  @model.call(**model_inputs)
end
|
#_sanitize_parameters(**kwargs) ⇒ Object
3
4
5
|
# File 'lib/transformers/pipelines/reranking.rb', line 3
# Splits caller-supplied keyword arguments into a three-element array.
# The first two slots are always empty hashes; every keyword argument is
# placed in the third slot.
# NOTE(review): the slots presumably map to preprocess / forward /
# postprocess parameters in the parent Pipeline — confirm there.
#
# @param kwargs [Hash] arbitrary keyword arguments
# @return [Array<Hash>] +[{}, {}, kwargs]+
def _sanitize_parameters(**kwargs)
  first_params = {}
  second_params = {}
  [first_params, second_params, kwargs]
end
|
#call(query, documents) ⇒ Object
21
22
23
|
# File 'lib/transformers/pipelines/reranking.rb', line 21
# Packs the query and the documents into a single hash and hands that hash
# to the parent implementation of +call+ as its only argument.
#
# @param query [Object] stored under the +:query+ key
# @param documents [Object] stored under the +:documents+ key
# @return [Object] whatever the parent +call+ returns
def call(query, documents)
  payload = {query: query, documents: documents}
  super(payload)
end
|
#postprocess(model_outputs) ⇒ Object
25
26
27
28
29
30
31
32
|
# File 'lib/transformers/pipelines/reranking.rb', line 25
# Converts raw model outputs into a list of scored entries, best first.
#
# Takes the first element of +model_outputs+, applies +sigmoid+, +squeeze+
# and +to_a+ to it, tags each resulting value with its position, and sorts
# the entries by descending score.
# NOTE(review): assumes +model_outputs[0]+ is a tensor of one logit per
# document — confirm against the model's output layout.
#
# @param model_outputs [#[]] container whose first element responds to
#   +sigmoid+, +squeeze+ and +to_a+
# @return [Array<Hash>] hashes of the form +{index:, score:}+
def postprocess(model_outputs)
  logits = model_outputs[0]
  scores = logits.sigmoid.squeeze.to_a

  entries = scores.each_with_index.map do |score, position|
    {index: position, score: score}
  end

  # Negating the score yields descending order.
  entries.sort_by { |entry| -entry[:score] }
end
|
#preprocess(inputs) ⇒ Object
7
8
9
10
11
12
13
14
|
# File 'lib/transformers/pipelines/reranking.rb', line 7
# Tokenizes the query against every document.
#
# The query is repeated once per document so the tokenizer receives two
# equally long lists: the repeated query as the primary text and the
# documents as +text_pair+. Padding is enabled and tensors are requested
# for +@framework+.
#
# @param inputs [Hash] must provide +:query+ and +:documents+
#   (+:documents+ must respond to +length+)
# @return [Object] whatever the tokenizer call returns
def preprocess(inputs)
  query = inputs[:query]
  documents = inputs[:documents]
  repeated_query = Array.new(documents.length, query)

  @tokenizer.call(
    repeated_query,
    text_pair: documents,
    return_tensors: @framework,
    padding: true
  )
end
|