Class: TorchData::DataPipes::Iter::Util::InternalRandomSplitterIterDataPipe

Inherits:
IterDataPipe < Object
Defined in:
lib/torchdata/data_pipes/iter/util/random_splitter.rb

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(source_datapipe, total_length, weights, seed) ⇒ InternalRandomSplitterIterDataPipe

Returns a new instance of InternalRandomSplitterIterDataPipe.



# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', lines 30-41

# Sets up a splitter that routes items from +source_datapipe+ to the keys of
# +weights+.
#
# @param source_datapipe [Object] upstream pipe; kept and exposed via +source_datapipe+
# @param total_length [Numeric] the normalized weights are scaled to sum to this value
# @param weights [Hash] split key => relative weight; the key order fixes each split's index
# @param seed [Integer] seed for the internal +Random+ generator (also reused by +reset+)
def initialize(source_datapipe, total_length, weights, seed)
  @source_datapipe = source_datapipe
  @total_length = total_length
  @remaining_length = total_length
  @seed = seed

  # Freeze an ordering of the split keys; every weight array below is indexed by it.
  @keys = weights.keys
  @key_to_index = @keys.each_with_index.to_h
  raw_weights = @keys.map { |key| weights[key] }
  @norm_weights = self.class.normalize_weights(raw_weights, total_length)
  # Working copy that +draw+ mutates; the pristine normalized weights are kept for +reset+.
  @weights = @norm_weights.dup

  @rng = Random.new(seed)
  # Per-split length cache (presumably filled lazily by +get_length+).
  @lengths = []
end

Instance Attribute Details

#source_datapipe ⇒ Object (readonly)

Returns the value of attribute source_datapipe.



# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', lines 28-30

# Reader for the upstream datapipe whose items this splitter distributes.
#
# @return [Object] the +source_datapipe+ given to the constructor
def source_datapipe; @source_datapipe; end

Class Method Details

.normalize_weights(weights, total_length) ⇒ Object



# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', lines 55-58

# Rescales +weights+ proportionally so that they add up to +total_length+.
#
# Each weight is converted to a Float before scaling, so the result is an
# Array of Floats of the same length as +weights+. If the weights sum to zero,
# the Float division produces NaN/Infinity entries instead of raising.
#
# @param weights [Array<Numeric>] relative weights
# @param total_length [Numeric] the desired sum of the returned weights
# @return [Array<Float>] the rescaled weights
def self.normalize_weights(weights, total_length)
  denominator = weights.sum
  weights.map do |weight|
    scaled = weight.to_f * total_length
    scaled / denominator
  end
end

Instance Method Details

#draw ⇒ Object



# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', lines 43-53

# Picks the split that the next source item goes to and charges it against
# that split's remaining weight budget.
#
# The key comes from +choices+, which is defined elsewhere (presumably a
# weighted random pick using +@rng+ and the current +@weights+; TODO confirm).
# The chosen split's weight and the count of items still to be routed each
# drop by one. If the weight goes negative, it is clamped to zero and all
# weights are rescaled with +normalize_weights+ toward +@remaining_length+.
#
# @return [Object] the selected split key
def draw
  picked = choices(@rng, @keys, @weights)
  position = @key_to_index[picked]
  @weights[position] -= 1
  @remaining_length -= 1
  return picked unless @weights[position].negative?

  @weights[position] = 0
  @weights = self.class.normalize_weights(@weights, @remaining_length)
  picked
end

#get_length(target) ⇒ Object



# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', lines 71-73

# Returns how many items will be routed to the split +target+.
#
# Split lengths are only knowable in advance when every normalized weight is
# a whole number and together they sum to +total_length+. In that case each
# split is expected to receive exactly its normalized weight. Otherwise this
# raises. The computed lengths are cached in +@lengths+ after the first call.
#
# Whole-number detection uses +finite?+ and +round+ because
# Numeric#integer? is always false for Floats, and the normalized weights are
# Floats.
#
# @param target [Object] one of the keys of the +weights+ hash given to the constructor
# @return [Integer] the number of items assigned to +target+
# @raise [RuntimeError] if the split lengths cannot be determined in advance
# @raise [KeyError] if +target+ is not one of the split keys
def get_length(target)
  if @lengths.empty?
    exact = @norm_weights.all? { |w| w.finite? && w == w.round } &&
            @norm_weights.sum == @total_length
    unless exact
      raise "Lengths of the split cannot be known in advance. Please supply " \
            "integer `weights` that sum up to `total_length`."
    end

    @lengths = @norm_weights.map(&:to_i)
  end
  @lengths[@key_to_index.fetch(target)]
end

#override_seed(seed) ⇒ Object



# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', lines 66-69

# Replaces the seed used to build the random generator.
#
# Only +@seed+ is updated here; +@rng+ is not rebuilt, so the new seed takes
# effect the next time +reset+ recreates the generator.
#
# @param seed [Integer] the new seed
# @return [self] the receiver, so the call can be chained
def override_seed(seed)
  tap { @seed = seed }
end

#reset ⇒ Object



# File 'lib/torchdata/data_pipes/iter/util/random_splitter.rb', lines 60-64

# Rewinds the splitter to its freshly-constructed sampling state.
#
# Rebuilds the random generator from the current +@seed+, restores the working
# weights from an independent copy of the normalized weights, and resets the
# remaining-item counter to the full length.
#
# @return [Numeric] the restored remaining length (the value of +@total_length+)
def reset
  # Reseed first so a seed changed via +override_seed+ is honored.
  @rng = Random.new(@seed)

  # Copy, so later draws never mutate the pristine normalized weights.
  @weights = @norm_weights.dup
  @remaining_length = @total_length
end