Class: Torch::Utils::Data::TensorDataset

Inherits:
Dataset
  • Object
show all
Defined in:
lib/torch/utils/data/tensor_dataset.rb

Instance Method Summary collapse

Constructor Details

#initialize(*tensors) ⇒ TensorDataset

Returns a new instance of TensorDataset.



5
6
7
8
9
10
# File 'lib/torch/utils/data/tensor_dataset.rb', line 5

# Wraps one or more tensors that must agree on their leading (dim 0) length.
#
# @param tensors [Array] objects responding to +size(0)+
# @raise [Error] if any tensor's +size(0)+ differs from the first tensor's
def initialize(*tensors)
  # Compare every tensor against the first one. With no tensors there is
  # nothing to compare, so construction succeeds and stores an empty list.
  expected_rows = tensors.first&.size(0)
  mismatched = tensors.any? { |tensor| tensor.size(0) != expected_rows }
  raise Error, "Tensors must all have same dim 0 size" if mismatched

  @tensors = tensors
end

Instance Method Details

#[](index) ⇒ Object



12
13
14
# File 'lib/torch/utils/data/tensor_dataset.rb', line 12

# Builds one sample by indexing every wrapped tensor with the same index.
#
# @param index [Object] forwarded unchanged to each tensor's +[]+
# @return [Array] one element per wrapped tensor, in construction order
def [](index)
  @tensors.each_with_object([]) do |tensor, sample|
    sample << tensor[index]
  end
end

#size ⇒ Object  (also known as: length, count)



16
17
18
# File 'lib/torch/utils/data/tensor_dataset.rb', line 16

# Number of samples, read as the dim 0 length of the first wrapped tensor.
#
# @return [Object] the result of +size(0)+ on the first tensor
#   (NOTE(review): presumably an Integer for Torch tensors — confirm)
def size
  first_tensor = @tensors.first
  first_tensor.size(0)
end