Class: SVMKit::ModelSelection::KFold
- Inherits:
- Object
- Includes:
- Base::Splitter
- Defined in:
- lib/svmkit/model_selection/k_fold.rb
Overview
KFold is a class that generates the set of data indices for K-fold cross-validation.
Instance Attribute Summary
- #rng ⇒ Random (readonly)
  Return the random generator for shuffling the dataset.
- #shuffle ⇒ Boolean (readonly)
  Return the flag indicating whether to shuffle the dataset.
Attributes included from Base::Splitter
Instance Method Summary
- #initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold (constructor)
  Create a new data splitter for K-fold cross validation.
- #split(x, y) ⇒ Array
  Generate data indices for K-fold cross validation.
Constructor Details
#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold
Create a new data splitter for K-fold cross validation.
    # File 'lib/svmkit/model_selection/k_fold.rb', line 34
    def initialize(n_splits: 3, shuffle: false, random_seed: nil)
      @n_splits = n_splits
      @shuffle = shuffle
      @random_seed = random_seed
      @random_seed ||= srand
      @rng = Random.new(@random_seed)
    end
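The constructor builds its generator from random_seed, so passing a fixed seed should make the shuffled order repeatable. When random_seed is nil, the listing falls back to Kernel#srand, which reseeds the global generator and returns the previous seed. The following is a minimal sketch in plain Ruby, independent of SVMKit, showing that two generators created from the same seed produce the same permutation; the seed value 42 is an arbitrary choice for illustration.

```ruby
# Two Random instances built from the same seed yield the same shuffle.
seed = 42 # hypothetical seed, chosen only for this example

first  = (0...6).to_a.shuffle(random: Random.new(seed))
second = (0...6).to_a.shuffle(random: Random.new(seed))

puts first == second # same seed, same permutation
```

This is why passing an explicit random_seed is useful when cross-validation results need to be compared across runs.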
Instance Attribute Details
#rng ⇒ Random (readonly)
Return the random generator for shuffling the dataset.
    # File 'lib/svmkit/model_selection/k_fold.rb', line 27
    def rng
      @rng
    end
#shuffle ⇒ Boolean (readonly)
Return the flag indicating whether to shuffle the dataset.
    # File 'lib/svmkit/model_selection/k_fold.rb', line 23
    def shuffle
      @shuffle
    end
Instance Method Details
#split(x, y) ⇒ Array
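Generate data indices for K-fold cross validation. The result is an array of n_splits pairs, each of the form [train_ids, test_ids]. The argument y is accepted but not used.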
    # File 'lib/svmkit/model_selection/k_fold.rb', line 50
    def split(x, y) # rubocop:disable Lint/UnusedMethodArgument
      # Initialize and check some variables.
      n_samples, = x.shape
      unless @n_splits.between?(2, n_samples)
        raise ArgumentError,
              'The value of n_splits must be not less than 2 and not more than the number of samples.'
      end
      # Splits dataset ids to each fold.
      dataset_ids = [*0...n_samples]
      dataset_ids.shuffle!(random: @rng) if @shuffle
      fold_sets = Array.new(@n_splits) do |n|
        n_fold_samples = n_samples / @n_splits
        n_fold_samples += 1 if n < n_samples % @n_splits
        dataset_ids.shift(n_fold_samples)
      end
      # Returns array consisting of the training and testing ids for each fold.
      Array.new(@n_splits) do |n|
        train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
        test_ids = fold_sets[n]
        [train_ids, test_ids]
      end
    end
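To make the fold arithmetic in the listing easier to follow, here is a standalone sketch that mirrors it with plain Ruby arrays and no shuffling. The helper name kfold_indices is hypothetical and not part of SVMKit; it takes the sample count directly instead of reading it from x.shape.

```ruby
# Hypothetical helper (not part of SVMKit): mirrors the fold arithmetic
# of KFold#split using plain arrays and no shuffling.
def kfold_indices(n_samples, n_splits)
  unless n_splits.between?(2, n_samples)
    raise ArgumentError, 'n_splits must be between 2 and the number of samples.'
  end

  ids = (0...n_samples).to_a
  # The first (n_samples % n_splits) folds each receive one extra sample.
  folds = Array.new(n_splits) do |n|
    size = n_samples / n_splits
    size += 1 if n < n_samples % n_splits
    ids.shift(size)
  end
  # Each fold serves once as the test set; the other folds form the training set.
  Array.new(n_splits) do |n|
    train_ids = folds.select.with_index { |_, i| i != n }.flatten
    [train_ids, folds[n]]
  end
end

splits = kfold_indices(10, 3)
splits.each do |train_ids, test_ids|
  puts "train=#{train_ids.inspect} test=#{test_ids.inspect}"
end
```

With 10 samples and 3 splits, the integer division gives 3 per fold and the remainder of 1 goes to the first fold, so the test sets have sizes 4, 3, and 3. Every index appears in exactly one test set.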