Class: Rumale::ModelSelection::StratifiedShuffleSplit

Inherits:
Object
Includes:
Base::Splitter
Defined in:
lib/rumale/model_selection/stratified_shuffle_split.rb

Overview

StratifiedShuffleSplit generates sets of data indices for random permutation cross-validation. Samples are drawn separately from each class, so the class proportions in every fold stay approximately equal to those of the full dataset.

Examples:

ss = Rumale::ModelSelection::StratifiedShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
ss.split(samples, labels).each do |train_ids, test_ids|
  train_samples = samples[train_ids, true]
  test_samples = samples[test_ids, true]
  ...
end

Constructor Details

#initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil) ⇒ StratifiedShuffleSplit

Create a new data splitter for stratified random permutation cross-validation.

Parameters:

  • n_splits (Integer) (defaults to: 3)

    The number of folds.

  • test_size (Float) (defaults to: 0.1)

    The fraction of the samples to use as test data.

  • train_size (Float) (defaults to: nil)

    The fraction of the samples to use as training data. If nil, all samples not selected for the test set are used for training.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random generator.
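
The effect of a fixed random_seed can be illustrated with Ruby's standard Random class alone. The following is a minimal sketch, not Rumale code: the variable names are hypothetical, and it only shows that two generators created from the same seed drive Array#sample to the same selection, which is why a fixed seed makes splits reproducible.

```ruby
# Minimal sketch using only the Ruby standard library (names are illustrative).
seed = 1
ids = (0...10).to_a

# Two independent generators initialized with the same seed.
first_draw = ids.sample(3, random: Random.new(seed))
second_draw = ids.sample(3, random: Random.new(seed))

# Both draws select the same indices in the same order.
same_draw = (first_draw == second_draw)
puts same_draw
```

When random_seed is nil, the constructor shown in the source falls back to a seed obtained from srand, so the splits differ from run to run.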



# File 'lib/rumale/model_selection/stratified_shuffle_split.rb', line 35

def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
  check_params_integer(n_splits: n_splits)
  check_params_float(test_size: test_size)
  check_params_type_or_nil(Float, train_size: train_size)
  check_params_type_or_nil(Integer, random_seed: random_seed)
  check_params_positive(n_splits: n_splits)
  check_params_positive(test_size: test_size)
  check_params_positive(train_size: train_size) unless train_size.nil?
  @n_splits = n_splits
  @test_size = test_size
  @train_size = train_size
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end

Instance Attribute Details

#n_splits ⇒ Integer (readonly)

Return the number of folds.

Returns:

  • (Integer)


# File 'lib/rumale/model_selection/stratified_shuffle_split.rb', line 23

def n_splits
  @n_splits
end

#rng ⇒ Random (readonly)

Return the random generator for shuffling the dataset.

Returns:

  • (Random)


# File 'lib/rumale/model_selection/stratified_shuffle_split.rb', line 27

def rng
  @rng
end

Instance Method Details

#split(x, y) ⇒ Array

Generate data indices for stratified random permutation cross validation.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset to be used to generate data indices for stratified random permutation cross validation. This argument exists to unify the interface with the K-fold methods; it is validated but not otherwise used when generating the indices.

  • y (Numo::Int32)

    (shape: [n_samples]) The labels to be used to generate data indices for stratified random permutation cross validation.

Returns:

  • (Array)

    An array containing one [train_ids, test_ids] pair per fold, where each element holds the sample indices for the training and test sets of that fold.
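
The per-class sampling and the shape of the return value can be sketched with plain Ruby arrays. This is an illustrative sketch under stated assumptions, not the library's implementation: stratified_shuffle_split_sketch is a hypothetical helper, it takes an ordinary Array of labels instead of a Numo::Int32, it omits the validation steps, and it covers only the case where train_size is nil.

```ruby
# Hypothetical helper (not part of Rumale): stratified shuffle split on plain arrays.
def stratified_shuffle_split_sketch(labels, n_splits:, test_size:, seed:)
  rng = Random.new(seed)
  # Group the sample indices by class label.
  ids_by_class = labels.each_index.group_by { |i| labels[i] }.values
  Array.new(n_splits) do
    train_ids = []
    test_ids = []
    ids_by_class.each do |ids|
      # Draw the test indices for this class; the rest become training indices.
      n_test = (test_size * ids.size).to_i
      picked = ids.sample(n_test, random: rng)
      test_ids += picked
      train_ids += ids - picked
    end
    [train_ids, test_ids]
  end
end

labels = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
folds = stratified_shuffle_split_sketch(labels, n_splits: 3, test_size: 0.2, seed: 1)
train_ids, test_ids = folds.first
puts "train: #{train_ids.size}, test: #{test_ids.size}"
```

With five samples per class and test_size: 0.2, each fold draws one test index from each class, so both classes are represented in every test set.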



# File 'lib/rumale/model_selection/stratified_shuffle_split.rb', line 59

def split(x, y)
  check_sample_array(x)
  check_label_array(y)
  check_sample_label_size(x, y)
  # Initialize and check some variables.
  train_sz = @train_size.nil? ? 1.0 - @test_size : @train_size
  sub_rng = @rng.dup
  # Check the number of samples in each class.
  unless valid_n_splits?(y)
    raise ArgumentError,
          'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
  end
  unless enough_data_size_each_class?(y, @test_size)
    raise RangeError,
          'The number of sample in test split must be not less than 1 and not more than the number of samples in each class.'
  end
  unless enough_data_size_each_class?(y, train_sz)
    raise RangeError,
          'The number of sample in train split must be not less than 1 and not more than the number of samples in each class.'
  end
  unless enough_data_size_each_class?(y, train_sz + @test_size)
    raise RangeError,
          'The total number of samples in test split and train split must be not more than the number of samples in each class.'
  end
  # Returns array consisting of the training and testing ids for each fold.
  sample_ids_each_class = y.to_a.uniq.map { |label| y.eq(label).where.to_a }
  Array.new(@n_splits) do
    train_ids = []
    test_ids = []
    sample_ids_each_class.each do |sample_ids|
      n_samples = sample_ids.size
      n_test_samples = (@test_size * n_samples).to_i
      n_train_samples = (train_sz * n_samples).to_i
      test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
      train_ids += if @train_size.nil?
                     sample_ids - test_ids
                   else
                     (sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
                   end
    end
    [train_ids, test_ids]
  end
end