Class: Transformers::PipelineIterator

Inherits:
Torch::Utils::Data::IterableDataset < Object
Defined in:
lib/transformers/pipelines/pt_utils.rb

Instance Method Summary

Constructor Details

#initialize(loader, infer, params, loader_batch_size: nil) ⇒ PipelineIterator

Returns a new instance of PipelineIterator.



21
22
23
24
25
26
27
28
29
30
31
32
33
34
# File 'lib/transformers/pipelines/pt_utils.rb', line 21

# Wraps a loader so that every element it produces is passed through an
# inference callable.
#
# @param loader [Object] the wrapped source; the other methods of this class
#   call +[]+, +each+ and +size+ on it
# @param infer [#call] invoked as +infer.(item, **params)+ for each element
# @param params [Hash] keyword arguments forwarded to +infer+ on every call
# @param loader_batch_size [Integer, nil] a value of 1 is normalised to +nil+
#   (treated the same as "not set")
def initialize(loader, infer, params, loader_batch_size: nil)
  @loader = loader
  @infer = infer
  @params = params

  # A batch size of exactly 1 means there is nothing to unbatch, so the
  # feature is switched off entirely to save work.
  @loader_batch_size = loader_batch_size == 1 ? nil : loader_batch_size

  # Unbatching bookkeeping, starting empty. No method visible here reads
  # these; presumably consumed elsewhere (subclass/unbatch logic) — confirm.
  @loader_batch_index = nil
  @loader_batch_data = nil
end

Instance Method Details

#[](i) ⇒ Object



40
41
42
# File 'lib/transformers/pipelines/pt_utils.rb', line 40

# Random access: fetches the +i+-th element from the loader and returns the
# result of running the inference callable on it.
#
# @param i [Object] index forwarded as-is to the loader's +[]+
# @return [Object] whatever +infer+ returns for that element
def [](i)
  item = @loader[i]
  @infer.call(item, **@params)
end

#each ⇒ Object



44
45
46
47
48
49
50
51
# File 'lib/transformers/pipelines/pt_utils.rb', line 44

# Runs the inference callable over every element of the wrapped loader and
# yields each processed result, in loader order.
#
# When called without a block, returns an Enumerator instead of raising
# LocalJumpError. Callers can then chain Enumerable methods such as +to_a+ or
# +first+. The Enumerator's +size+ delegates to the loader when it responds
# to +size+.
#
# @yieldparam processed [Object] the value returned by +infer+ for one item
# @return [Enumerator, Object] an Enumerator when no block is given;
#   otherwise whatever the loader's +each+ returns
def each
  unless block_given?
    return enum_for(:each) { @loader.size if @loader.respond_to?(:size) }
  end

  # NOTE(review): @iterator is not read by any method visible here; it is
  # kept for compatibility (possibly used by subclasses) — confirm before removing.
  @iterator = @loader

  @iterator.each do |item|
    processed = @infer.(item, **@params)
    yield processed
  end
end

#size ⇒ Object



36
37
38
# File 'lib/transformers/pipelines/pt_utils.rb', line 36

# Number of elements, delegated entirely to the wrapped loader.
#
# @return [Object] whatever the loader's +size+ returns
def size
  loader = @loader
  loader.size
end