Class: TorchData::DataPipes::Iter::Util::RandomSplitter

Inherits:
IterDataPipe (ancestor chain: RandomSplitter < IterDataPipe < Object)
Defined in:
lib/torchdata/data_pipes/iter/util/random_splitter.rb

Class Method Summary

Class Method Details

.new(source_datapipe, weights:, seed:, total_length: nil, target: nil) ⇒ Object



Source (lines 8–24 of the file):
# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', line 8

# Builds the random-split pipes for +source_datapipe+.
#
# A single internal splitter container is shared by every returned
# SplitterIterator. Each iterator is identified by one key of +weights+.
#
# @param source_datapipe [Object] the datapipe to split; must respond to
#   +#length+ unless +total_length+ is given
# @param weights [Hash] split name => relative weight; the keys name the
#   returned splits (normalization, if any, is done by the internal
#   container — TODO confirm)
# @param seed [Object] passed through unchanged to the internal splitter
# @param total_length [Integer, nil] number of elements in
#   +source_datapipe+; inferred from +source_datapipe.length+ when nil
# @param target [Object, nil] when given, only the split for this key of
#   +weights+ is returned
# @return [Array<SplitterIterator>] one iterator per key of +weights+, in
#   key order, when +target+ is nil
# @return [SplitterIterator] the iterator for +target+ otherwise
# @raise [TypeError] if +total_length+ is nil and it cannot be inferred
# @raise [KeyError] if +target+ is given but is not a key of +weights+
def self.new(source_datapipe, weights:, seed:, total_length: nil, target: nil)
  if total_length.nil?
    begin
      total_length = source_datapipe.length
    rescue NoMethodError
      # A rescue (rather than respond_to?) also maps a NoMethodError raised
      # while computing the length into the same TypeError for callers.
      raise TypeError, "RandomSplitter needs `total_length`, but it is unable to infer it from the `source_datapipe`: #{source_datapipe}."
    end
  end

  # Fail fast on an unknown target before constructing the shared container.
  if !target.nil? && !weights.key?(target)
    raise KeyError, "`target=#{target}` does not match any key in `weights`."
  end

  container = InternalRandomSplitterIterDataPipe.new(source_datapipe, total_length, weights, seed)

  return SplitterIterator.new(container, target) unless target.nil?

  weights.map { |key, _weight| SplitterIterator.new(container, key) }
end