Class: Rumale::ModelSelection::KFold
- Inherits: Object (Object → Rumale::ModelSelection::KFold)
- Includes: Base::Splitter
- Defined in: lib/rumale/model_selection/k_fold.rb
Overview
KFold is a class that generates the set of data indices for K-fold cross-validation.
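To show how the generated indices are typically consumed, the sketch below hard-codes the pairs that split returns for 6 samples with n_splits: 3 and shuffle: false (three contiguous folds of two samples each). It then scores a trivial mean predictor on each fold. The toy targets y and the mean predictor are illustrative assumptions, not part of Rumale. With real data the pairs would come from KFold#split.

```ruby
# Toy targets (illustrative assumption, not real data).
y = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# [train_ids, test_ids] pairs as KFold#split returns them for 6 samples,
# n_splits: 3, shuffle: false (folds [0, 1], [2, 3], [4, 5]).
splits = [
  [[2, 3, 4, 5], [0, 1]],
  [[0, 1, 4, 5], [2, 3]],
  [[0, 1, 2, 3], [4, 5]]
]

# For each fold, "fit" a mean predictor on the training targets and
# compute the mean squared error on the held-out targets.
scores = splits.map do |train_ids, test_ids|
  prediction = train_ids.sum { |i| y[i] } / train_ids.size
  test_ids.sum { |i| (y[i] - prediction)**2 } / test_ids.size
end

# The cross-validation score is the average of the per-fold scores.
cv_score = scores.sum / scores.size
puts "per-fold MSE: #{scores.inspect}, mean: #{cv_score}"
```

Every sample is held out exactly once, so each target contributes to exactly one test score.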
Instance Attribute Summary
- #n_splits ⇒ Integer (readonly): Return the number of folds.
- #rng ⇒ Random (readonly): Return the random generator for shuffling the dataset.
- #shuffle ⇒ Boolean (readonly): Return the flag indicating whether to shuffle the dataset.
Instance Method Summary
- #initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold (constructor): Create a new data splitter for K-fold cross-validation.
- #split(x, _y = nil) ⇒ Array: Generate data indices for K-fold cross-validation.
Constructor Details
#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold
Create a new data splitter for K-fold cross-validation.
# File 'lib/rumale/model_selection/k_fold.rb', line 38
def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  check_params_integer(n_splits: n_splits)
  check_params_boolean(shuffle: shuffle)
  check_params_type_or_nil(Integer, random_seed: random_seed)
  check_params_positive(n_splits: n_splits)
  @n_splits = n_splits
  @shuffle = shuffle
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
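The last three lines of the constructor decide how shuffling is seeded. An explicit random_seed is used as-is. A nil seed falls back to the value returned by Kernel#srand, which reseeds Ruby's global generator and returns its previous seed. The sketch below isolates that logic in a hypothetical helper, build_rng (not part of Rumale). It shows that two generators built from the same seed shuffle indices identically, which is what makes a fixed random_seed reproducible.

```ruby
# Hypothetical helper mirroring the seeding logic of KFold#initialize.
def build_rng(random_seed = nil)
  # With no seed, Kernel#srand reseeds the global generator and returns
  # the previous seed, which is then used for this splitter's Random.
  random_seed ||= srand
  Random.new(random_seed)
end

ids = (0...10).to_a

# Same seed, same permutation.
a = ids.shuffle(random: build_rng(42))
b = ids.shuffle(random: build_rng(42))
puts a == b # prints "true"

# Without a seed the permutation depends on the previous global seed,
# but shuffling still only reorders the ids.
c = ids.shuffle(random: build_rng)
puts c.sort == ids # prints "true"
```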
Instance Attribute Details
#n_splits ⇒ Integer (readonly)
Return the number of folds.
# File 'lib/rumale/model_selection/k_fold.rb', line 23
def n_splits
  @n_splits
end
#rng ⇒ Random (readonly)
Return the random generator for shuffling the dataset.
# File 'lib/rumale/model_selection/k_fold.rb', line 31
def rng
  @rng
end
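KFold#split never shuffles with this generator directly. It calls @rng.dup and shuffles with the copy. Random#dup copies the generator's internal state, so every call to split on the same KFold object starts from the same state. With shuffle: true, repeated calls therefore yield the same folds, and rng itself is never advanced. A plain-Ruby illustration of that behavior follows (seed 1 is an arbitrary choice):

```ruby
rng = Random.new(1)
ids = (0...8).to_a

# Shuffle with copies of rng, as KFold#split does with @rng.dup.
first  = ids.shuffle(random: rng.dup)
second = ids.shuffle(random: rng.dup)
puts first == second # prints "true": both copies start from the same state

# The original generator was not advanced by either shuffle.
puts rng.rand == Random.new(1).rand # prints "true"
```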
#shuffle ⇒ Boolean (readonly)
Return the flag indicating whether to shuffle the dataset.
# File 'lib/rumale/model_selection/k_fold.rb', line 27
def shuffle
  @shuffle
end
Instance Method Details
#split(x, _y = nil) ⇒ Array
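Generate data indices for K-fold cross-validation. Only the number of samples (the first dimension of x) is used, and _y is ignored. The return value holds one [train_ids, test_ids] pair per fold. ArgumentError is raised unless n_splits is between 2 and the number of samples.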
# File 'lib/rumale/model_selection/k_fold.rb', line 55
def split(x, _y = nil)
  check_sample_array(x)
  # Initialize and check some variables.
  n_samples, = x.shape
  unless @n_splits.between?(2, n_samples)
    raise ArgumentError, 'The value of n_splits must be not less than 2 and not more than the number of samples.'
  end
  sub_rng = @rng.dup
  # Splits dataset ids to each fold.
  dataset_ids = [*0...n_samples]
  dataset_ids.shuffle!(random: sub_rng) if @shuffle
  fold_sets = Array.new(@n_splits) do |n|
    n_fold_samples = n_samples / @n_splits
    n_fold_samples += 1 if n < n_samples % @n_splits
    dataset_ids.shift(n_fold_samples)
  end
  # Returns array consisting of the training and testing ids for each fold.
  Array.new(@n_splits) do |n|
    train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
    test_ids = fold_sets[n]
    [train_ids, test_ids]
  end
end
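The fold sizes follow from integer division. Each fold receives n_samples / n_splits samples, and the first n_samples % n_splits folds receive one more. For example, 10 samples with n_splits: 3 give test folds of 4, 3, and 3 samples. The sketch below restates the method as a hypothetical plain-Ruby function, kfold_indices, which takes a sample count instead of the sample array x. This lets the resulting partition be checked without Rumale or Numo.

```ruby
# Hypothetical plain-Ruby function mirroring KFold#split, taking a sample
# count instead of a sample array.
def kfold_indices(n_samples, n_splits, shuffle: false, rng: Random.new(0))
  unless n_splits.between?(2, n_samples)
    raise ArgumentError, 'n_splits must be between 2 and the number of samples'
  end

  ids = (0...n_samples).to_a
  ids.shuffle!(random: rng.dup) if shuffle

  # The first (n_samples % n_splits) folds receive one extra sample.
  folds = Array.new(n_splits) do |n|
    size = n_samples / n_splits
    size += 1 if n < n_samples % n_splits
    ids.shift(size)
  end

  # Each fold serves once as the test set; the other folds form the training set.
  Array.new(n_splits) do |n|
    train = folds.select.with_index { |_, i| i != n }.flatten
    [train, folds[n]]
  end
end

splits = kfold_indices(10, 3)
p splits.map { |_, test| test } # prints [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]

# With shuffling, the test folds still partition all sample ids.
shuffled = kfold_indices(10, 3, shuffle: true, rng: Random.new(7))
p shuffled.flat_map { |_, test| test }.sort == (0...10).to_a # prints true
```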