Class: Rumale::ModelSelection::StratifiedKFold
- Inherits:
- Object
- Object
- Rumale::ModelSelection::StratifiedKFold
- Includes:
- Base::Splitter
- Defined in:
- lib/rumale/model_selection/stratified_k_fold.rb
Overview
StratifiedKFold is a class that generates the sets of data indices for stratified K-fold cross-validation. Samples are assigned to folds class by class, so each fold contains approximately the same proportion of each class as the full dataset.
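The stratification idea can be illustrated with a minimal plain-Ruby sketch. The helper `stratified_folds` below is hypothetical and is not Rumale's implementation. It assumes a simple round-robin assignment within each class, without shuffling, only to show why every fold ends up with a near-equal share of each label.

```ruby
# Hypothetical helper, not part of Rumale: assigns sample indices to folds
# class by class so every fold gets a near-equal share of each label.
def stratified_folds(labels, n_splits)
  folds = Array.new(n_splits) { [] }
  # Group sample indices by their label, then deal each group out round-robin.
  labels.each_index.group_by { |i| labels[i] }.each_value do |ids|
    ids.each_with_index { |id, pos| folds[pos % n_splits] << id }
  end
  folds.map(&:sort)
end

labels = %w[a a a a a a b b b]
folds = stratified_folds(labels, 3)

folds.each_with_index do |test_ids, fold_id|
  # Everything outside the test fold is used for training.
  train_ids = (0...labels.size).to_a - test_ids
  counts = test_ids.map { |i| labels[i] }.tally
  puts "fold #{fold_id}: test=#{test_ids.inspect} train=#{train_ids.inspect} classes=#{counts}"
end
```

With six samples of one class and three of the other, each of the three test folds in this sketch receives two samples of the first class and one of the second.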
Instance Attribute Summary
- #n_splits ⇒ Integer (readonly): Return the number of folds.
- #rng ⇒ Random (readonly): Return the random generator for shuffling the dataset.
- #shuffle ⇒ Boolean (readonly): Return the flag indicating whether to shuffle the dataset.
Instance Method Summary
- #initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ StratifiedKFold (constructor): Create a new data splitter for stratified K-fold cross-validation.
- #split(x, y) ⇒ Array: Generate data indices for stratified K-fold cross-validation.
Constructor Details
#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ StratifiedKFold
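Create a new data splitter for stratified K-fold cross-validation. n_splits must be a positive Integer, shuffle must be a Boolean, and random_seed must be an Integer or nil; when random_seed is nil, a seed is obtained from srand.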
# File 'lib/rumale/model_selection/stratified_k_fold.rb', line 38

def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  check_params_integer(n_splits: n_splits)
  check_params_boolean(shuffle: shuffle)
  check_params_type_or_nil(Integer, random_seed: random_seed)
  check_params_positive(n_splits: n_splits)
  @n_splits = n_splits
  @shuffle = shuffle
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
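The effect of random_seed can be shown with Ruby's standard library alone. The snippet below is an illustration, not Rumale code. It assumes only that shuffling draws from a `Random` built from the seed, as the constructor above suggests, and shows that equal seeds give equal orderings and that `srand` supplies an Integer fallback seed.

```ruby
ids = (0...10).to_a

# Two generators built from the same seed produce the same shuffle.
first  = ids.shuffle(random: Random.new(42))
second = ids.shuffle(random: Random.new(42))
same_seed_matches = (first == second)

# Kernel#srand with no argument reseeds the global generator and returns the
# previous seed, which is the value the constructor falls back to.
fallback_seed = srand

puts "same seed, same order: #{same_seed_matches}"
puts "fallback seed is an Integer: #{fallback_seed.is_a?(Integer)}"
```

Passing an explicit random_seed is therefore the way to make shuffled folds reproducible across runs. With nil, the seed differs from process to process.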
Instance Attribute Details
#n_splits ⇒ Integer (readonly)
Return the number of folds.
# File 'lib/rumale/model_selection/stratified_k_fold.rb', line 23

def n_splits
  @n_splits
end
#rng ⇒ Random (readonly)
Return the random generator for shuffling the dataset.
# File 'lib/rumale/model_selection/stratified_k_fold.rb', line 31

def rng
  @rng
end
#shuffle ⇒ Boolean (readonly)
Return the flag indicating whether to shuffle the dataset.
# File 'lib/rumale/model_selection/stratified_k_fold.rb', line 27

def shuffle
  @shuffle
end
Instance Method Details
#split(x, y) ⇒ Array
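Generate data indices for stratified K-fold cross-validation. Returns an array with one entry per fold, each consisting of the training and testing indices for that fold. Raises ArgumentError if n_splits is less than 2 or greater than the number of samples in any class.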
# File 'lib/rumale/model_selection/stratified_k_fold.rb', line 58

def split(x, y)
  check_sample_array(x)
  check_label_array(y)
  check_sample_label_size(x, y)
  # Check the number of samples in each class.
  unless valid_n_splits?(y)
    raise ArgumentError,
          'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
  end
  # Splits dataset ids of each class to each fold.
  sub_rng = @rng.dup
  fold_sets_each_class = y.to_a.uniq.map { |label| fold_sets(y, label, sub_rng) }
  # Returns array consisting of the training and testing ids for each fold.
  Array.new(@n_splits) { |fold_id| train_test_sets(fold_sets_each_class, fold_id) }
end
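Two details of the method can be sketched in plain Ruby. The helper `n_splits_ok?` is a hypothetical stand-in for the private `valid_n_splits?` check, inferred only from the error message and not taken from the library source. The second part shows what drawing from `rng.dup` implies: the original generator is never advanced, so repeated draws from fresh copies give the same order.

```ruby
# Hypothetical stand-in for the private valid_n_splits? check, inferred only
# from the error message: 2 <= n_splits <= size of the smallest class.
def n_splits_ok?(labels, n_splits)
  n_splits >= 2 && labels.tally.values.all? { |count| n_splits <= count }
end

labels = [0, 0, 0, 0, 1, 1, 1]
ok_for_three = n_splits_ok?(labels, 3) # smallest class has 3 samples
ok_for_four  = n_splits_ok?(labels, 4) # class 1 has only 3 samples

# Shuffling with a copy leaves the original generator's state untouched,
# so a second copy starts from the same state and yields the same order.
rng = Random.new(7)
ids = labels.each_index.to_a
first_call  = ids.shuffle(random: rng.dup)
second_call = ids.shuffle(random: rng.dup)

puts "n_splits=3 accepted: #{ok_for_three}"
puts "n_splits=4 accepted: #{ok_for_four}"
puts "repeated shuffles match: #{first_call == second_call}"
```

If the library behaves as the listing suggests, calling split twice on the same splitter with the same labels should return the same folds even when shuffle is true.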