8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', line 8
# Splits +source_datapipe+ into randomly sampled, non-overlapping parts
# according to +weights+.
#
# @param source_datapipe [Object] the datapipe to split; must respond to
#   +length+ unless +total_length+ is given
# @param weights [Hash] split key => relative weight (presumably numeric;
#   interpreted by InternalRandomSplitterIterDataPipe — verify there)
# @param seed [Object] seed forwarded to InternalRandomSplitterIterDataPipe
# @param total_length [Integer, nil] number of elements in the source;
#   inferred via +source_datapipe.length+ when nil
# @param target [Object, nil] when given, only the split with this key is
#   returned
# @return [Array<SplitterIterator>, SplitterIterator] one iterator per key of
#   +weights+ (in +weights+ order) when +target+ is nil, otherwise the single
#   iterator for +target+
# @raise [TypeError] if +total_length+ is nil and cannot be inferred
# @raise [KeyError] if +target+ is not one of the keys of +weights+
def self.new(source_datapipe, weights:, seed:, total_length: nil, target: nil)
  if total_length.nil?
    begin
      total_length = source_datapipe.length
    # Rescue rather than `respond_to?`: wrapper datapipes may define `length`
    # by delegating to a source that lacks it, which also surfaces here.
    rescue NoMethodError
      raise TypeError, "RandomSplitter needs `total_length`, but it is unable to infer it from the `source_datapipe`: #{source_datapipe}."
    end
  end

  container = InternalRandomSplitterIterDataPipe.new(source_datapipe, total_length, weights, seed)
  keys = weights.map { |k, _| k }
  return keys.map { |k| SplitterIterator.new(container, k) } if target.nil?

  unless keys.include?(target)
    raise KeyError, "`target=#{target.inspect}` does not match any key in `weights`."
  end

  SplitterIterator.new(container, target)
end
|