Class: Torch::Utils::Data::DataLoader
- Inherits: Object
- Includes: Enumerable
- Defined in: lib/torch/utils/data/data_loader.rb
Instance Attribute Summary
- #dataset ⇒ Object (readonly)
  Returns the value of attribute dataset.
Instance Method Summary
- #each ⇒ Object
  Yields the dataset one collated batch at a time.
- #initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil) ⇒ DataLoader (constructor)
  A new instance of DataLoader.
- #size ⇒ Object (also: #length, #count)
  Returns the number of batches.
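Because DataLoader includes Enumerable and defines #each, the usual Enumerable methods (map, first, each_with_index, and so on) work on a loader. Aliasing #count to #size also replaces Enumerable#count, so count returns the batch count without iterating and takes no arguments. The sketch below uses a hypothetical pure-Ruby ToyLoader that mimics this structure so it runs without torch-rb; it is not the library class.

```ruby
# Hypothetical stand-in for a loader: includes Enumerable and defines
# #each plus a #size aliased as #length and #count, as DataLoader does.
class ToyLoader
  include Enumerable

  def initialize(items, batch_size)
    @items = items
    @batch_size = batch_size
  end

  # Yields one Array per batch, in order.
  def each
    @items.each_slice(@batch_size) { |batch| yield batch }
  end

  # Number of batches, computed without iterating.
  def size
    (@items.size / @batch_size.to_f).ceil
  end
  alias length size
  alias count size
end

loader = ToyLoader.new((1..5).to_a, 2)
loader.map(&:sum) # Enumerable#map, built on #each
loader.first      # first batch only
loader.count      # the alias of #size, so no iteration happens
# loader.count(x) raises ArgumentError: the alias takes no arguments
```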
Constructor Details
#initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil) ⇒ DataLoader
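Returns a new instance of DataLoader.
- dataset: the object to load from; the loader only calls its #size and #[] methods.
- batch_size: number of samples per batch (default 1).
- shuffle: when true, #each visits the samples in a random order from Torch.randperm, drawing a fresh permutation on every call; otherwise it visits them in index order.
- collate_fn: a callable that receives the Array of samples for one batch and returns the batch. When nil, it defaults to #default_collate if auto_collation? is true and to #default_convert otherwise.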
```ruby
# File 'lib/torch/utils/data/data_loader.rb', line 9
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
  @dataset = dataset
  @batch_size = batch_size
  @shuffle = shuffle

  @batch_sampler = nil

  if collate_fn.nil?
    if auto_collation?
      collate_fn = method(:default_collate)
    else
      collate_fn = method(:default_convert)
    end
  end

  @collate_fn = collate_fn
end
```
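The constructor places only two requirements on its arguments: the dataset must answer #size and #[], and collate_fn must answer #call with the Array of samples for a batch. A minimal sketch of both, assuming a hypothetical PairDataset class and a hypothetical transpose_collate lambda (neither is part of torch-rb), in plain Ruby:

```ruby
# Hypothetical dataset: DataLoader only calls #size and #[] on it.
class PairDataset
  def initialize(xs, ys)
    @xs = xs
    @ys = ys
  end

  def size
    @xs.size
  end

  # One sample is an [x, y] pair.
  def [](i)
    [@xs[i], @ys[i]]
  end
end

# Hypothetical collate_fn: receives the Array of samples for one batch and
# regroups [[x0, y0], [x1, y1], ...] into [[x0, x1, ...], [y0, y1, ...]].
transpose_collate = ->(samples) { samples.transpose }

dataset = PairDataset.new([1.0, 2.0, 3.0], [0, 1, 0])
batch = transpose_collate.call([dataset[0], dataset[1]])
# batch == [[1.0, 2.0], [0, 1]]
```

With torch-rb installed, the same objects could be passed as `Torch::Utils::Data::DataLoader.new(dataset, batch_size: 2, collate_fn: transpose_collate)`, since a lambda responds to #call just like the Method objects the constructor uses by default.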
Instance Attribute Details
#dataset ⇒ Object (readonly)
Returns the value of attribute dataset.
```ruby
# File 'lib/torch/utils/data/data_loader.rb', line 7
def dataset
  @dataset
end
```
Instance Method Details
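#each ⇒ Object
Yields one batch per iteration. The indices (a Torch.randperm permutation when shuffle is true, otherwise 0 through dataset.size - 1 in order) are split into slices of batch_size, and each slice's samples are passed to collate_fn. The last batch is smaller when the dataset size is not a multiple of batch_size. Before iterating, the method draws one random number (base_seed, otherwise unused here) to keep the random number generator in step with PyTorch's DataLoader, which makes results easier to compare against Python.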
```ruby
# File 'lib/torch/utils/data/data_loader.rb', line 27
def each
  # try to keep the random number generator in sync with Python
  # this makes it easy to compare results
  base_seed = Torch.empty([], dtype: :int64).random!.item

  indexes =
    if @shuffle
      Torch.randperm(@dataset.size).to_a
    else
      @dataset.size.times
    end

  indexes.each_slice(@batch_size) do |idx|
    # TODO improve performance
    yield @collate_fn.call(idx.map { |i| @dataset[i] })
  end
end
```
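A plain-Ruby sketch of the batching loop in #each, runnable without torch-rb. Assumptions: each_batch is a hypothetical helper, not part of the library; Array#shuffle with a seeded Random (the hypothetical rng parameter) stands in for Torch.randperm, so shuffled orders will not match torch.rb; and the base_seed draw is left out because it only advances torch's generator. Like the original, it must be called with a block.

```ruby
# Sketch of the index/batch logic in DataLoader#each, in plain Ruby.
# Assumption: Array#shuffle with a seeded Random replaces Torch.randperm.
def each_batch(dataset, batch_size: 1, shuffle: false,
               collate_fn: ->(samples) { samples }, rng: Random.new(42))
  # Visit every index exactly once, either permuted or in order.
  indexes = shuffle ? (0...dataset.size).to_a.shuffle(random: rng) : dataset.size.times

  # each_slice leaves a shorter final slice when size % batch_size != 0.
  indexes.each_slice(batch_size) do |idx|
    yield collate_fn.call(idx.map { |i| dataset[i] })
  end
end

data = %w[a b c d e]
each_batch(data, batch_size: 2) { |batch| p batch }
# prints ["a", "b"], then ["c", "d"], then ["e"]
```

A plain Array works as the dataset here because it already responds to #size and #[].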
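#size ⇒ Object
Also known as: length, count
Returns the number of batches #each yields: the dataset size divided by batch_size, rounded up. Because #count is an alias of this method, it replaces Enumerable#count and accepts no arguments.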
```ruby
# File 'lib/torch/utils/data/data_loader.rb', line 45
def size
  (@dataset.size / @batch_size.to_f).ceil
end
```
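A worked version of the arithmetic in #size: a dataset of 10 samples with batch_size 3 gives ceil(10 / 3.0) = 4 batches (three full ones and one holding the leftover sample). The helper names below are hypothetical; num_batches_int is an integer-only alternative that agrees with the float version for dataset sizes well below 2**53, and both match the number of slices each_slice produces.

```ruby
# Batch count as computed by DataLoader#size: ceil(n / batch_size).
def num_batches(n, batch_size)
  (n / batch_size.to_f).ceil
end

# Hypothetical integer-only equivalent for a positive batch_size.
def num_batches_int(n, batch_size)
  (n + batch_size - 1) / batch_size
end

num_batches(10, 3) # => 4
num_batches(9, 3)  # => 3
```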