Class: Transformers::RerankingPipeline

Inherits:
Pipeline
  • Object
Defined in:
lib/transformers/pipelines/reranking.rb

Instance Method Summary

Methods inherited from Pipeline

#_ensure_tensor_on_device, #check_model_type, #get_iterator, #initialize, #torch_dtype

Constructor Details

This class inherits a constructor from Transformers::Pipeline

Instance Method Details

#_forward(model_inputs) ⇒ Object



Lines 16–19
# File 'lib/transformers/pipelines/reranking.rb', line 16

# Run the wrapped model on already-tokenized inputs.
#
# @param model_inputs [Hash, #to_hash] splatted into the model as keyword
#   arguments (presumably the encoding produced by #preprocess — confirm)
# @return [Object] the model's return value, passed through untouched
def _forward(model_inputs)
  @model.call(**model_inputs)
end

#_sanitize_parameters(**kwargs) ⇒ Object



Lines 3–5
# File 'lib/transformers/pipelines/reranking.rb', line 3

# Split caller options into the per-stage parameter hashes.
#
# This pipeline takes no preprocessing or forward options, so the first two
# hashes are always empty and every keyword is handed back in the third slot.
# NOTE(review): presumably the triple is (preprocess, forward, postprocess)
# params per the Pipeline convention — confirm against Pipeline.
#
# @return [Array(Hash, Hash, Hash)]
def _sanitize_parameters(**kwargs)
  preprocess_params = {}
  forward_params = {}
  [preprocess_params, forward_params, kwargs]
end

#call(query, documents) ⇒ Object



Lines 21–23
# File 'lib/transformers/pipelines/reranking.rb', line 21

# Score +documents+ against +query+.
#
# Packs both arguments into a single inputs hash and delegates to the
# inherited Pipeline#call (which presumably drives #preprocess, #_forward and
# #postprocess in turn — confirm against Pipeline).
#
# @param query [Object] the query text
# @param documents [Array] the candidate document texts
# @return [Object] whatever the inherited #call returns
def call(query, documents)
  inputs = {query: query, documents: documents}
  super(inputs)
end

#postprocess(model_outputs) ⇒ Object



Lines 25–32
# File 'lib/transformers/pipelines/reranking.rb', line 25

# Turn raw model outputs into per-document relevance scores, best first.
#
# @param model_outputs [Object] the value returned by #_forward; element 0 is
#   used as the score logits (assumed shape (num_documents, 1) — TODO confirm)
# @return [Array<Hash>] one +{index:, score:}+ hash per document, where +index+
#   is the document's position in the input list and +score+ is the sigmoid
#   of its logit; sorted by descending score, ties kept in input order
def postprocess(model_outputs)
  # Squeeze only the trailing dim (not every size-1 dim) so a single document
  # still yields a one-element list instead of a 0-dim scalar tensor.
  scores = model_outputs[0].sigmoid.squeeze(-1).to_a
  ranked = scores.each_with_index.map { |score, index| {index: index, score: score} }
  # Enumerable#sort_by is not stable; breaking ties on the original index makes
  # equal-score documents come back in a deterministic (input) order.
  ranked.sort_by { |entry| [-entry[:score], entry[:index]] }
end

#preprocess(inputs) ⇒ Object



Lines 7–14
# File 'lib/transformers/pipelines/reranking.rb', line 7

# Tokenize the query against each candidate document as a text pair.
#
# @param inputs [Hash] with +:query+ (a single text) and +:documents+
#   (a list of texts)
# @return [Object] the tokenizer's output, requested as +@framework+ tensors
#   with padding enabled
def preprocess(inputs)
  query = inputs[:query]
  documents = inputs[:documents]
  # Repeat the query once per document so the i-th query pairs with the i-th document.
  queries = Array.new(documents.length) { query }
  @tokenizer.call(
    queries,
    text_pair: documents,
    return_tensors: @framework,
    padding: true
  )
end