Class: TorchData::DataPipes::Iter::Util::InternalRandomSplitterIterDataPipe
- Inherits: IterDataPipe (Object > IterDataPipe > TorchData::DataPipes::Iter::Util::InternalRandomSplitterIterDataPipe)
- Defined in: lib/torchdata/data_pipes/iter/util/random_splitter.rb
Instance Attribute Summary
- #source_datapipe ⇒ Object (readonly)
  Returns the value of attribute source_datapipe.
Class Method Summary
- .normalize_weights(weights, total_length) ⇒ Object
Instance Method Summary
- #draw ⇒ Object
- #get_length(target) ⇒ Object
- #initialize(source_datapipe, total_length, weights, seed) ⇒ InternalRandomSplitterIterDataPipe (constructor)
  A new instance of InternalRandomSplitterIterDataPipe.
- #override_seed(seed) ⇒ Object
- #reset ⇒ Object
Constructor Details
#initialize(source_datapipe, total_length, weights, seed) ⇒ InternalRandomSplitterIterDataPipe
Returns a new instance of InternalRandomSplitterIterDataPipe.
```ruby
# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', line 30

def initialize(source_datapipe, total_length, weights, seed)
  @source_datapipe = source_datapipe
  @total_length = total_length
  @remaining_length = @total_length
  @seed = seed
  # Fixed ordering of the split keys and a key => position lookup.
  @keys = weights.keys
  @key_to_index = @keys.map.with_index.to_h
  # Weights rescaled to sum to total_length; kept so #reset can restore them.
  @norm_weights = self.class.normalize_weights(@keys.map { |k| weights[k] }, total_length)
  # Working copy that #draw decrements.
  @weights = @norm_weights.dup
  @rng = Random.new(@seed)
  # Per-split lengths; not read by any method documented on this page.
  @lengths = []
end
```
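The constructor turns the `weights` hash into three pieces of state: an ordered key list, a key-to-index lookup, and weights rescaled to `total_length`. A minimal standalone sketch of that bookkeeping, run outside the class on illustrative values (the split names `:train`/`:valid` and the local variable names are assumptions for the example):

```ruby
# Illustrative split weights: relative proportions 3:1.
weights = { train: 3, valid: 1 }
total_length = 8

# Fixed ordering of the split keys (Hash preserves insertion order).
keys = weights.keys

# Map each key to its position in `keys`, e.g. :train => 0, :valid => 1.
key_to_index = keys.map.with_index.to_h

# Rescale so the weights sum to total_length, mirroring
# InternalRandomSplitterIterDataPipe.normalize_weights.
raw = keys.map { |k| weights[k] }
total_weight = raw.sum
norm_weights = raw.map { |w| w.to_f * total_length / total_weight }
# norm_weights is [6.0, 2.0]: 6 of the 8 items are expected in :train.
```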
Instance Attribute Details
#source_datapipe ⇒ Object (readonly)
Returns the value of attribute source_datapipe.
```ruby
# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', line 28

def source_datapipe
  @source_datapipe
end
```
Class Method Details
.normalize_weights(weights, total_length) ⇒ Object
```ruby
# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', line 55

def self.normalize_weights(weights, total_length)
  total_weight = weights.sum
  weights.map { |w| w.to_f * total_length / total_weight }
end
```
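Each weight w_i becomes w_i · total_length / Σw, so the rescaled weights sum to `total_length` (up to floating-point rounding). A standalone copy of the method, applied to illustrative values, shows both the initial normalization done by the constructor and the renormalization `#draw` performs after clamping a weight to zero:

```ruby
# Standalone copy of the class method so it runs without torchdata.
def normalize_weights(weights, total_length)
  total_weight = weights.sum
  weights.map { |w| w.to_f * total_length / total_weight }
end

# Initial normalization: relative weights 3:1 spread over 8 items.
initial = normalize_weights([3, 1], 8) # [6.0, 2.0]

# Renormalization as in #draw: one weight was clamped to 0, so the
# 3 items still to be drawn all go to the other split.
rebalanced = normalize_weights([0, 1.5], 3) # [0.0, 3.0]
```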
Instance Method Details
#draw ⇒ Object
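```ruby
# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', line 43

def draw
  # Weighted pick among @keys using @weights (`choices` is not documented on this page).
  selected_key = choices(@rng, @keys, @weights)
  index = @key_to_index[selected_key]
  # Charge one item to the chosen split.
  @weights[index] -= 1
  @remaining_length -= 1
  # A fractional weight can drop below zero: clamp it and rescale the
  # weights to the number of items still left to draw.
  if @weights[index] < 0
    @weights[index] = 0
    @weights = self.class.normalize_weights(@weights, @remaining_length)
  end
  selected_key
end
```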
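`#draw` picks a key by weight and then spends one unit of that key's weight. The simulation below repeats those steps outside the class. Because `choices` is not documented on this page, it is replaced by a hypothetical stand-in that does cumulative-weight sampling (probability proportional to weight); the seed, split names, and counting hash are illustrative assumptions.

```ruby
# Standalone copy of the class method.
def normalize_weights(weights, total_length)
  total_weight = weights.sum
  weights.map { |w| w.to_f * total_length / total_weight }
end

# Hypothetical stand-in for `choices`: cumulative-weight sampling, so a
# key is chosen with probability proportional to its current weight.
def choices(rng, keys, weights)
  r = rng.rand * weights.sum
  cumulative = 0.0
  keys.each_with_index do |key, i|
    cumulative += weights[i]
    return key if r < cumulative
  end
  # Floating-point guard: fall back to the last key with positive weight.
  keys.zip(weights).reverse.find { |_, w| w > 0 }.first
end

keys = [:train, :valid]
key_to_index = keys.map.with_index.to_h
total_length = 8
weights = normalize_weights([3, 1], total_length) # [6.0, 2.0]
remaining = total_length
rng = Random.new(0)
counts = Hash.new(0)

total_length.times do
  # Same steps as #draw.
  key = choices(rng, keys, weights)
  index = key_to_index[key]
  weights[index] -= 1
  remaining -= 1
  if weights[index] < 0
    weights[index] = 0
    weights = normalize_weights(weights, remaining)
  end
  counts[key] += 1
end
# counts now holds exactly 6 :train draws and 2 :valid draws.
```

With whole-number normalized weights, a key whose weight reaches 0 is never picked again, so the per-split counts are fixed regardless of the seed; only the order varies. Fractional weights can go negative, which is the case where `#draw` clamps and renormalizes.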
#get_length(target) ⇒ Object
```ruby
# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', line 71

def get_length(target)
  raise "todo"
end
```
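`#get_length` is unimplemented here (it raises `"todo"`), while the constructor initializes an empty `@lengths` array. One plausible approach, stated as an assumption rather than the gem's behavior: report a split's length only when every normalized weight is a whole number and the weights sum to `total_length`, because only then is the number of draws per split fixed in advance. The helpers `split_lengths` and `sketch_get_length` are hypothetical names for this sketch.

```ruby
# Hypothetical helper (not part of the gem): per-split lengths are only
# known ahead of time when each normalized weight is a whole number and
# the weights add up to total_length.
def split_lengths(norm_weights, total_length)
  whole = norm_weights.all? { |w| w == w.floor }
  unless whole && norm_weights.sum == total_length
    raise ArgumentError, "split lengths cannot be known in advance; use integer weights that sum to total_length"
  end
  norm_weights.map(&:to_i)
end

# Hypothetical get_length: find the target split's position, then its length.
def sketch_get_length(target, keys, norm_weights, total_length)
  key_to_index = keys.map.with_index.to_h
  split_lengths(norm_weights, total_length).fetch(key_to_index.fetch(target))
end

keys = [:train, :valid]
sketch_get_length(:train, keys, [6.0, 2.0], 8) # 6
sketch_get_length(:valid, keys, [6.0, 2.0], 8) # 2
```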
#override_seed(seed) ⇒ Object
```ruby
# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', line 66

def override_seed(seed)
  @seed = seed
  self
end
```
#reset ⇒ Object
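```ruby
# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', line 60

def reset
  # Re-seed the RNG and restore the initial weights and remaining length,
  # so the same split sequence is replayed.
  @rng = Random.new(@seed)
  @weights = @norm_weights.dup
  @remaining_length = @total_length
end
```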
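`#reset` rebuilds `@rng` from `@seed`, whereas `#override_seed` only stores the new seed and returns `self`; it does not touch `@rng`. The new seed therefore takes effect at the next `#reset`. A sketch of that interplay, using a hypothetical stand-in class `SeededStandIn` that copies only the seed handling (not the weights) and a hypothetical `sample` method in place of `#draw`:

```ruby
# Hypothetical stand-in copying only the seed handling of
# InternalRandomSplitterIterDataPipe.
class SeededStandIn
  def initialize(seed)
    @seed = seed
    @rng = Random.new(@seed)
  end

  # Same shape as #override_seed: store the seed, return self.
  def override_seed(seed)
    @seed = seed
    self
  end

  # Same RNG handling as #reset.
  def reset
    @rng = Random.new(@seed)
  end

  # Hypothetical: n draws from the current RNG stream.
  def sample(n)
    Array.new(n) { @rng.rand(100) }
  end
end

pipe = SeededStandIn.new(42)
first = pipe.sample(5)
pipe.reset
replay = pipe.sample(5)         # identical to first: reset re-seeds @rng

pipe.override_seed(7)           # @rng is untouched, so...
after_override = pipe.sample(5) # ...this continues the seed-42 stream
pipe.reset                      # the new seed is applied only here
reseeded = pipe.sample(5)
```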