Module: Torch::Utils::Data

Defined in:
lib/torch/utils/data.rb,
lib/torch/utils/data/subset.rb,
lib/torch/utils/data/dataset.rb,
lib/torch/utils/data/data_loader.rb,
lib/torch/utils/data/tensor_dataset.rb

Defined Under Namespace

Classes: DataLoader, Dataset, Subset, TensorDataset

Class Method Summary

Class Method Details

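.random_split(dataset, lengths) ⇒ Array<Subset>

Randomly splits `dataset` into subsets whose sizes are given by `lengths`. Raises `ArgumentError` unless the lengths sum to the dataset's length.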



# File 'lib/torch/utils/data.rb', lines 5-12

# Randomly partitions +dataset+ into subsets whose sizes are given by +lengths+.
#
# One random permutation of all indices (0...lengths.sum) is drawn, then cut
# into consecutive slices, one per entry of +lengths+, in order.
#
# @param dataset [#length] the dataset to split
# @param lengths [Array<Integer>] requested size of each subset, in order
# @return [Array<Subset>] one Subset per entry in +lengths+
# @raise [ArgumentError] if +lengths+ does not sum to +dataset.length+
def random_split(dataset, lengths)
  total = lengths.sum
  raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!" unless total == dataset.length

  shuffled = Torch.randperm(total).to_a
  # NOTE(review): _accumulate is defined elsewhere; it is assumed to return the
  # running totals of +lengths+, i.e. the exclusive end offset of each slice.
  _accumulate(lengths).zip(lengths).map do |stop, size|
    Subset.new(dataset, shuffled[(stop - size)...stop])
  end
end