Class: Rumale::ModelSelection::StratifiedShuffleSplit
- Inherits: Object
  - Object
  - Rumale::ModelSelection::StratifiedShuffleSplit
- Includes: Base::Splitter
- Defined in: lib/rumale/model_selection/stratified_shuffle_split.rb
Overview
StratifiedShuffleSplit generates sets of data indices for random permutation cross-validation. Within each fold, every class contributes approximately the same fraction of its samples to the test set, so class proportions are roughly preserved.
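To make the stratification concrete, here is a minimal plain-Ruby sketch of the idea, assuming labels are held in an ordinary Array rather than a Numo::NArray. The helper name `stratified_split_sketch` is hypothetical and not part of Rumale; it mirrors only the default case where `train_size` is nil and skips the validation that the real class performs.

```ruby
# Hypothetical helper (not part of Rumale): split sample indices so that each
# class contributes the same fraction of its samples to the test set.
def stratified_split_sketch(labels, test_size:, rng:)
  train_ids = []
  test_ids = []
  # Group sample indices by class label.
  ids_by_class = labels.each_index.group_by { |i| labels[i] }
  ids_by_class.each_value do |ids|
    n_test = (test_size * ids.size).to_i # truncates toward zero
    picked = ids.sample(n_test, random: rng)
    test_ids.concat(picked)
    train_ids.concat(ids - picked) # everything not picked goes to training
  end
  [train_ids, test_ids]
end

labels = %w[a a a a a a a a b b b b]
train_ids, test_ids = stratified_split_sketch(labels, test_size: 0.25, rng: Random.new(1))
# Class "a" (8 samples) contributes 2 test indices, class "b" (4 samples) contributes 1.
p test_ids.map { |i| labels[i] }.sort # prints ["a", "a", "b"]
```

Which indices are drawn depends on the generator, but the per-class counts do not.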
Instance Attribute Summary
- #n_splits ⇒ Integer (readonly): Return the number of folds.
- #rng ⇒ Random (readonly): Return the random generator for shuffling the dataset.
Instance Method Summary
- #initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil) ⇒ StratifiedShuffleSplit (constructor): Create a new data splitter for random permutation cross-validation.
- #split(x, y) ⇒ Array: Generate data indices for stratified random permutation cross-validation.
Constructor Details
#initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil) ⇒ StratifiedShuffleSplit
Create a new data splitter for random permutation cross-validation.
# File 'lib/rumale/model_selection/stratified_shuffle_split.rb', line 35

def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
  check_params_integer(n_splits: n_splits)
  check_params_float(test_size: test_size)
  check_params_type_or_nil(Float, train_size: train_size)
  check_params_type_or_nil(Integer, random_seed: random_seed)
  check_params_positive(n_splits: n_splits)
  check_params_positive(test_size: test_size)
  check_params_positive(train_size: train_size) unless train_size.nil?
  @n_splits = n_splits
  @test_size = test_size
  @train_size = train_size
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
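The seed handling matters for reproducibility. The sketch below uses only Ruby's standard `Random` class to illustrate that generators built from the same seed produce the same samples, and that a duplicated generator continues from the same state. `sample_with_seed` is a hypothetical helper, not a Rumale method.

```ruby
# Hypothetical helper (not part of Rumale): draw n ids with a freshly seeded generator.
def sample_with_seed(ids, n, seed)
  ids.sample(n, random: Random.new(seed))
end

ids = (0...10).to_a

# Two generators built from the same seed yield the same sample.
puts sample_with_seed(ids, 3, 42) == sample_with_seed(ids, 3, 42) # prints "true"

# A duplicated generator starts from the state of the original.
rng = Random.new(42)
copy = rng.dup
puts ids.sample(3, random: rng) == ids.sample(3, random: copy) # prints "true"
```

Passing an explicit Integer `random_seed` is therefore the straightforward way to get repeatable splits. Since `#split` draws from `@rng.dup` rather than from `@rng` itself, repeated calls on the same instance with the same data should return the same folds.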
Instance Attribute Details
#n_splits ⇒ Integer (readonly)
Return the number of folds.
# File 'lib/rumale/model_selection/stratified_shuffle_split.rb', line 23

def n_splits
  @n_splits
end
#rng ⇒ Random (readonly)
Return the random generator for shuffling the dataset.
# File 'lib/rumale/model_selection/stratified_shuffle_split.rb', line 27

def rng
  @rng
end
Instance Method Details
#split(x, y) ⇒ Array
Generate data indices for stratified random permutation cross-validation.
# File 'lib/rumale/model_selection/stratified_shuffle_split.rb', line 59

def split(x, y)
  check_sample_array(x)
  check_label_array(y)
  check_sample_label_size(x, y)
  # Initialize and check some variables.
  train_sz = @train_size.nil? ? 1.0 - @test_size : @train_size
  sub_rng = @rng.dup
  # Check the number of samples in each class.
  unless valid_n_splits?(y)
    raise ArgumentError,
          'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
  end
  unless enough_data_size_each_class?(y, @test_size)
    raise RangeError,
          'The number of sample in test split must be not less than 1 and not more than the number of samples in each class.'
  end
  unless enough_data_size_each_class?(y, train_sz)
    raise RangeError,
          'The number of sample in train split must be not less than 1 and not more than the number of samples in each class.'
  end
  unless enough_data_size_each_class?(y, train_sz + @test_size)
    raise RangeError,
          'The total number of samples in test split and train split must be not more than the number of samples in each class.'
  end
  # Returns array consisting of the training and testing ids for each fold.
  sample_ids_each_class = y.to_a.uniq.map { |label| y.eq(label).where.to_a }
  Array.new(@n_splits) do
    train_ids = []
    test_ids = []
    sample_ids_each_class.each do |sample_ids|
      n_samples = sample_ids.size
      n_test_samples = (@test_size * n_samples).to_i
      n_train_samples = (train_sz * n_samples).to_i
      test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
      train_ids += if @train_size.nil?
                     sample_ids - test_ids
                   else
                     (sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
                   end
    end
    [train_ids, test_ids]
  end
end
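The per-class sizes follow from the truncating `to_i` calls in the listing above. Here is a small sketch of that arithmetic, using a hypothetical helper `split_sizes_sketch` that is not part of Rumale and that omits the range checks done by `#split`.

```ruby
# Hypothetical helper (not part of Rumale): number of train/test samples drawn
# from a single class of n_samples, following the arithmetic in #split.
def split_sizes_sketch(n_samples, test_size:, train_size: nil)
  n_test = (test_size * n_samples).to_i # fractional parts are dropped
  n_train = if train_size.nil?
              n_samples - n_test # all remaining samples go to training
            else
              (train_size * n_samples).to_i
            end
  { train: n_train, test: n_test }
end

default_sizes = split_sizes_sketch(10, test_size: 0.25)
puts "train=#{default_sizes[:train]} test=#{default_sizes[:test]}" # prints "train=8 test=2"

explicit_sizes = split_sizes_sketch(6, test_size: 0.25, train_size: 0.5)
puts "train=#{explicit_sizes[:train]} test=#{explicit_sizes[:test]}" # prints "train=3 test=1"
```

In the second call, two of the six samples end up in neither set. This is expected whenever `train_size` is given and `train_size + test_size` is less than 1.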