Class: SVMKit::ModelSelection::KFold

Inherits:
Object
Includes:
Base::Splitter
Defined in:
lib/svmkit/model_selection/k_fold.rb

Overview

KFold is a class that generates the set of data indices for K-fold cross-validation.

Examples:

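require 'svmkit'

# samples: Numo::DFloat (shape: [n_samples, n_features]); labels: Numo::Int32 (shape: [n_samples])
kf = SVMKit::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
kf.split(samples, labels).each do |train_ids, test_ids|
  train_samples = samples[train_ids, true]
  test_samples = samples[test_ids, true]
  # Fit an estimator on train_samples and evaluate it on test_samples here.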
end
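The loop above is the usual K-fold cross-validation pattern: fit on the training rows, score on the held-out rows, then average the scores over the folds. Below is a dependency-free sketch of that pattern with plain Ruby arrays, where Array#values_at plays the role of samples[ids, true]. The helper fold_pairs, the toy labels, and the majority-class "model" are hypothetical stand-ins, not SVMKit APIs; fold_pairs only imitates the fold sizing of KFold#split with shuffle: false.

```ruby
# Hypothetical helper (not part of SVMKit): contiguous folds sized like
# KFold#split with shuffle: false, returned as [train_ids, test_ids] pairs.
def fold_pairs(n_samples, n_splits)
  ids = [*0...n_samples]
  folds = Array.new(n_splits) do |n|
    size = n_samples / n_splits
    size += 1 if n < n_samples % n_splits
    ids.shift(size)
  end
  Array.new(n_splits) do |n|
    train_ids = folds.each_with_index.reject { |_, i| i == n }.flat_map(&:first)
    [train_ids, folds[n]]
  end
end

labels = [0, 1, 1, 1, 0, 1]

# One score per fold: "fit" a majority-class baseline on the training labels,
# then measure its accuracy on the held-out labels.
scores = fold_pairs(labels.size, 2).map do |train_ids, test_ids|
  majority = labels.values_at(*train_ids).tally.max_by { |_, count| count }.first
  test_labels = labels.values_at(*test_ids)
  test_labels.count(majority).fdiv(test_labels.size)
end

# The cross-validation score is the mean of the per-fold scores.
mean_score = scores.sum / scores.size
puts mean_score
```

With a real estimator, the majority-class step would be replaced by fitting on samples[train_ids, true] and predicting on samples[test_ids, true].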

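Instance Attribute Summary

  • #rng ⇒ Random (readonly)

    Return the random generator for shuffling the dataset.

  • #shuffle ⇒ Boolean (readonly)

    Return the flag indicating whether to shuffle the dataset.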

Attributes included from Base::Splitter

#n_splits

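Instance Method Summary

  • #initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold (constructor)

    Create a new data splitter for K-fold cross validation.

  • #split(x, y) ⇒ Array

    Generate data indices for K-fold cross validation.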

Constructor Details

#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold

Create a new data splitter for K-fold cross validation.

Parameters:

  • n_splits (Integer) (defaults to: 3)

    The number of folds. It must be at least 2 and at most the number of samples; #split raises ArgumentError otherwise.

  • shuffle (Boolean) (defaults to: false)

    The flag indicating whether to shuffle the dataset.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random generator. If nil, the seed returned by Kernel#srand is used instead.
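The random_seed and shuffle parameters work together: the constructor builds its generator as Random.new(random_seed), and #split passes that generator to Array#shuffle! when shuffle is true. A minimal sketch of why a fixed seed makes the folds reproducible; the helper shuffled_ids is hypothetical and only mirrors those two steps of the source listings below.

```ruby
# Hypothetical helper mirroring the constructor (Random.new(random_seed))
# and the shuffle step of #split (Array#shuffle! with that generator).
def shuffled_ids(n_samples, random_seed)
  rng = Random.new(random_seed)
  ids = [*0...n_samples]
  ids.shuffle!(random: rng)
  ids
end

first = shuffled_ids(10, 1)
second = shuffled_ids(10, 1)
puts first == second  # same seed, same permutation, hence the same folds
```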



# File 'lib/svmkit/model_selection/k_fold.rb', line 34

def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  @n_splits = n_splits
  @shuffle = shuffle
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
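The line @random_seed ||= srand relies on Kernel#srand: called without an argument, it reseeds the global generator from an entropy source and returns the previous seed. So when no seed is passed, the splitter still ends up with a concrete integer seed. A small sketch of that fallback in isolation; the srand(42) call is an assumption added only to make the demo deterministic.

```ruby
srand(42)               # pin the global seed so this demo is deterministic

random_seed = nil
random_seed ||= srand   # as in the constructor: srand reseeds and returns the previous seed (42 here)
rng = Random.new(random_seed)

# The recorded seed is enough to rebuild an identical generator later.
replay = Random.new(random_seed)
draws = Array.new(3) { rng.rand(100) }
replayed = Array.new(3) { replay.rand(100) }
p random_seed           # → 42
p draws == replayed     # → true
```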

Instance Attribute Details

#rng ⇒ Random (readonly)

Return the random generator for shuffling the dataset.

Returns:

  • (Random)


# File 'lib/svmkit/model_selection/k_fold.rb', line 27

def rng
  @rng
end
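Because the constructor creates @rng once and #split shuffles with it on every call, the generator's state should advance each time; calling #split repeatedly on the same splitter with shuffle: true will therefore generally yield a different permutation than the first call. The sketch below reproduces that behavior with plain Ruby, under the assumption that #split consumes the generator exactly as Array#shuffle does.

```ruby
# One generator, created up front and reused, as KFold does with @rng.
rng = Random.new(1)
first_call = [*0...10].shuffle(random: rng)
second_call = [*0...10].shuffle(random: rng)   # drawn from the advanced state

# Rebuilding the generator from the same seed replays the first permutation.
fresh = Random.new(1)
replayed = [*0...10].shuffle(random: fresh)
p replayed == first_call   # → true
```

To get identical folds twice, create a new KFold with the same random_seed rather than calling #split again on the same instance.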

#shuffle ⇒ Boolean (readonly)

Return the flag indicating whether to shuffle the dataset.

Returns:

  • (Boolean)


# File 'lib/svmkit/model_selection/k_fold.rb', line 23

def shuffle
  @shuffle
end
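With shuffle false, #split cuts the indices in their original order, so each test fold is a contiguous block of rows. If the dataset is sorted by label, a fold can then contain a single class. A minimal illustration with toy labels; it uses each_slice, which matches the fold sizing of #split only because 8 samples divide evenly into 2 folds.

```ruby
labels = [0, 0, 0, 0, 1, 1, 1, 1]   # toy labels, sorted by class
ids = [*0...labels.size]
n_splits = 2
fold_size = labels.size / n_splits  # 8 samples split evenly into 2 folds of 4

# shuffle: false keeps the original order, so each fold is a contiguous block.
plain_folds = ids.each_slice(fold_size).to_a
p plain_folds.map { |fold| labels.values_at(*fold).uniq }   # → [[0], [1]]

# shuffle: true permutes the ids first, which usually mixes the classes.
shuffled_folds = ids.shuffle(random: Random.new(1)).each_slice(fold_size).to_a
p shuffled_folds.map { |fold| labels.values_at(*fold).uniq.sort }
```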

Instance Method Details

#split(x, y) ⇒ Array

Generate data indices for K-fold cross validation.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset to be used to generate data indices for K-fold cross validation.

  • y (Numo::Int32)

    (shape: [n_samples]) The labels to be used to generate data indices for stratified K-fold cross validation. This argument exists only to unify the interface across the K-fold splitters; KFold#split ignores it.

Returns:

  • (Array)

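    The data indices for constructing the training and testing datasets of each fold: an Array of n_splits pairs [train_ids, test_ids], where both elements are Arrays of Integer row indices into x.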
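The returned pairs should satisfy a few properties: within each fold the training and testing indices are disjoint and together cover every sample, and across folds every sample appears in exactly one test set. A hypothetical checker for those properties, applied to pairs worked out by hand from the split source for 5 samples, n_splits: 2, shuffle: false (the first 5 % 2 = 1 fold gets the extra sample, giving folds [0, 1, 2] and [3, 4]).

```ruby
# Hypothetical helper: checks the properties a K-fold split is expected to have.
def valid_kfold?(pairs, n_samples)
  all_ids = [*0...n_samples]
  # Within each fold, train and test are disjoint and together cover every sample.
  folds_ok = pairs.all? { |train_ids, test_ids|
    (train_ids & test_ids).empty? && (train_ids + test_ids).sort == all_ids
  }
  # Across folds, every sample lands in exactly one test set.
  tests_ok = pairs.flat_map { |_, test_ids| test_ids }.sort == all_ids
  folds_ok && tests_ok
end

pairs = [
  [[3, 4], [0, 1, 2]],
  [[0, 1, 2], [3, 4]]
]
p valid_kfold?(pairs, 5)   # → true
```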



# File 'lib/svmkit/model_selection/k_fold.rb', line 50

def split(x, y) # rubocop:disable Lint/UnusedMethodArgument
  # Initialize and check some variables.
  n_samples, = x.shape
  unless @n_splits.between?(2, n_samples)
    raise ArgumentError,
          'The value of n_splits must be not less than 2 and not more than the number of samples.'
  end
  # Splits dataset ids to each fold.
  dataset_ids = [*0...n_samples]
  dataset_ids.shuffle!(random: @rng) if @shuffle
  fold_sets = Array.new(@n_splits) do |n|
    n_fold_samples = n_samples / @n_splits
    n_fold_samples += 1 if n < n_samples % @n_splits
    dataset_ids.shift(n_fold_samples)
  end
  # Returns array consisting of the training and testing ids for each fold.
  Array.new(@n_splits) do |n|
    train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
    test_ids = fold_sets[n]
    [train_ids, test_ids]
  end
end
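The same indexing logic can be sketched without Numo, which makes the fold sizing easy to inspect: Array#shift peels folds off the front of the (optionally shuffled) index list, and the first n_samples % n_splits folds receive one extra sample, so fold sizes differ by at most one. This is a plain-Ruby port under the assumption that only n_samples matters (as in the method above, x contributes only its shape); kfold_split is a hypothetical name, not an SVMKit method.

```ruby
# Plain-Ruby port of the indexing logic in KFold#split (a sketch; SVMKit's
# version reads n_samples from x.shape and n_splits/shuffle/rng from the instance).
def kfold_split(n_samples, n_splits, shuffle: false, rng: Random.new)
  unless n_splits.between?(2, n_samples)
    raise ArgumentError, 'n_splits must be between 2 and the number of samples'
  end
  ids = [*0...n_samples]
  ids.shuffle!(random: rng) if shuffle
  folds = Array.new(n_splits) do |n|
    size = n_samples / n_splits
    size += 1 if n < n_samples % n_splits   # the first (n_samples % n_splits) folds get one extra
    ids.shift(size)
  end
  Array.new(n_splits) do |n|
    train_ids = folds.each_with_index.reject { |_, i| i == n }.flat_map(&:first)
    [train_ids, folds[n]]
  end
end

pairs = kfold_split(10, 3)
p pairs.map { |_, test_ids| test_ids.size }   # → [4, 3, 3]
p pairs.first                                  # → [[4, 5, 6, 7, 8, 9], [0, 1, 2, 3]]
```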