Class: Rumale::Ensemble::ExtraTreesClassifier
- Inherits: RandomForestClassifier
  - Object
  - RandomForestClassifier
  - Rumale::Ensemble::ExtraTreesClassifier
- Defined in: lib/rumale/ensemble/extra_trees_classifier.rb
Overview
ExtraTreesClassifier implements extremely randomized trees for classification. The algorithm is similar to random forest, with two differences: it does not apply the bagging procedure (each tree is fitted on the full training set), and it selects the threshold for splitting the feature space at random.
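To make the random-threshold idea concrete, here is a minimal plain-Ruby sketch of how candidate splits might be drawn at one node. It is not Rumale's implementation: `random_split` and its arguments are hypothetical names, and samples are plain arrays rather than Numo arrays. All samples passed in are used (no bootstrap resampling), and each threshold is drawn uniformly between the feature's minimum and maximum instead of being optimized.

```ruby
# Illustrative sketch only; `random_split` is a hypothetical helper, not part of Rumale.
# samples:     Array of feature vectors (Array<Array<Float>>)
# feature_ids: candidate feature indices to consider at this node
# rng:         a Random instance, so that results are reproducible
def random_split(samples, feature_ids, rng)
  feature_ids.map do |fid|
    values = samples.map { |s| s[fid] }
    lo, hi = values.minmax
    # Draw the threshold uniformly from [lo, hi) instead of searching for the best cut.
    # A constant feature cannot be split, so its single value is used as-is.
    threshold = lo == hi ? lo : rng.rand(lo.to_f...hi.to_f)
    left, right = samples.partition { |s| s[fid] <= threshold }
    { feature_id: fid, threshold: threshold, left: left, right: right }
  end
end

samples = [[0.1, 5.0], [0.4, 3.0], [0.9, 4.0], [0.7, 1.0]]
candidates = random_split(samples, [0, 1], Random.new(42))
candidates.each do |c|
  puts "feature #{c[:feature_id]}: threshold=#{c[:threshold].round(3)} " \
       "left=#{c[:left].size} right=#{c[:right].size}"
end
```

A real tree would score each candidate with the chosen criterion (for example Gini impurity) and keep the best one; only the threshold itself is left to chance.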
Reference
- P. Geurts, D. Ernst, and L. Wehenkel, “Extremely randomized trees,” Machine Learning, vol. 63 (1), pp. 3–42, 2006.
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #estimators ⇒ Array<ExtraTreeClassifier> (readonly)
  Return the set of estimators.
- #feature_importances ⇒ Numo::DFloat (readonly)
  Return the importance for each feature.
- #rng ⇒ Random (readonly)
  Return the random generator for random selection of feature index.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #apply(x) ⇒ Numo::Int32
  Return the index of the leaf that each sample reached.
- #fit(x, y) ⇒ ExtraTreesClassifier
  Fit the model with given training data.
- #initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, n_jobs: nil, random_seed: nil) ⇒ ExtraTreesClassifier (constructor)
  Create a new classifier with extremely randomized trees.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
- #predict_proba(x) ⇒ Numo::DFloat
  Predict probability for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, n_jobs: nil, random_seed: nil) ⇒ ExtraTreesClassifier
Create a new classifier with extremely randomized trees.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 56
def initialize(n_estimators: 10,
               criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
               max_features: nil, n_jobs: nil, random_seed: nil)
  check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                    max_features: max_features, n_jobs: n_jobs, random_seed: random_seed)
  check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
  check_params_string(criterion: criterion)
  check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
                        max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
                        max_features: max_features)
  super
end
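The constructor validates its keyword arguments before delegating to the parent class. The sketch below shows one way such checks could behave, using plain Ruby only. The helper names (`ensure_type_or_nil`, `ensure_integer`, `ensure_positive`, `validate_tree_options`) are hypothetical stand-ins; the real `check_params_*` helpers are defined elsewhere in Rumale and may raise different errors or treat edge cases such as zero differently.

```ruby
# Hypothetical stand-ins for the check_params_* helpers called in the constructor.
# The real helpers are defined elsewhere in Rumale and may differ in behavior.

# Accept nil (meaning "not set") or an instance of the given type.
def ensure_type_or_nil(type, **params)
  params.each do |name, value|
    next if value.nil? || value.is_a?(type)

    raise TypeError, "Expected #{name} to be #{type} or nil, got #{value.class}"
  end
end

# Require an Integer; nil is not allowed here.
def ensure_integer(**params)
  params.each do |name, value|
    raise TypeError, "Expected #{name} to be Integer, got #{value.class}" unless value.is_a?(Integer)
  end
end

# Skip nil values and reject anything that is not strictly positive
# (treating zero as invalid is an assumption of this sketch).
def ensure_positive(**params)
  params.compact.each do |name, value|
    raise ArgumentError, "Expected #{name} to be positive, got #{value}" unless value.positive?
  end
end

# Run the checks in the same order as the constructor: types first, then ranges.
def validate_tree_options(n_estimators: 10, max_depth: nil, max_features: nil)
  ensure_type_or_nil(Integer, max_depth: max_depth, max_features: max_features)
  ensure_integer(n_estimators: n_estimators)
  ensure_positive(n_estimators: n_estimators, max_depth: max_depth, max_features: max_features)
  true
end

validate_tree_options(n_estimators: 20, max_depth: 5) # passes silently

begin
  validate_tree_options(max_depth: 2.5) # Float is neither Integer nor nil
rescue TypeError => e
  puts e.message
end
```

Checking types before ranges means a call such as `value.positive?` is only reached for values already known to be integers.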
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 29
def classes
  @classes
end
#estimators ⇒ Array<ExtraTreeClassifier> (readonly)
Return the set of estimators.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 25
def estimators
  @estimators
end
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 33
def feature_importances
  @feature_importances
end
#rng ⇒ Random (readonly)
Return the random generator for random selection of feature index.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 37
def rng
  @rng
end
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 123
def apply(x)
  check_sample_array(x)
  super
end
#fit(x, y) ⇒ ExtraTreesClassifier
Fit the model with given training data.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 74
def fit(x, y)
  check_sample_array(x)
  check_label_array(y)
  check_sample_label_size(x, y)
  # Initialize some variables.
  n_features = x.shape[1]
  @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
  @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
  sub_rng = @rng.dup
  # Construct trees.
  rng_seeds = Array.new(@params[:n_estimators]) { sub_rng.rand(Rumale::Values.int_max) }
  @estimators = if enable_parallel?
                  parallel_map(@params[:n_estimators]) { |n| plant_tree(rng_seeds[n]).fit(x, y) }
                else
                  Array.new(@params[:n_estimators]) { |n| plant_tree(rng_seeds[n]).fit(x, y) }
                end
  @feature_importances =
    if enable_parallel?
      parallel_map(@params[:n_estimators]) { |n| @estimators[n].feature_importances }.reduce(&:+)
    else
      @estimators.map(&:feature_importances).reduce(&:+)
    end
  @feature_importances /= @feature_importances.sum
  self
end
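Three bookkeeping steps in `fit` can be restated in plain Ruby to show what they compute: resolving `max_features`, deriving one seed per tree, and normalizing the summed feature importances. This is a sketch with hypothetical helper names (`resolve_max_features`, `derive_seeds`, `normalize_importances`); it uses plain arrays instead of Numo arrays, and the `upper` bound is an assumed placeholder for `Rumale::Values.int_max`, whose actual value is not shown on this page.

```ruby
# Plain-Ruby restatement of three steps in #fit; helper names are hypothetical.

# Default to floor(sqrt(n_features)) when max_features is not an Integer,
# then clamp the result to the range 1..n_features.
def resolve_max_features(max_features, n_features)
  max_features = Math.sqrt(n_features).to_i unless max_features.is_a?(Integer)
  [[1, max_features].max, n_features].min
end

# Draw one integer seed per tree from a copy of the ensemble RNG, so the
# ensemble's own generator state is left untouched.
# `upper` is an assumed placeholder, not the value Rumale uses.
def derive_seeds(rng, n_estimators, upper = 2**31 - 1)
  sub_rng = rng.dup
  Array.new(n_estimators) { sub_rng.rand(upper) }
end

# Sum per-tree importances element-wise and rescale so they add up to 1.
def normalize_importances(per_tree)
  summed = per_tree.transpose.map(&:sum)
  total = summed.sum
  summed.map { |v| v / total }
end

puts resolve_max_features(nil, 10) # => 3
puts resolve_max_features(50, 10)  # => 10
puts resolve_max_features(0, 10)   # => 1

seeds = derive_seeds(Random.new(1), 3)
puts seeds == derive_seeds(Random.new(1), 3) # => true

puts normalize_importances([[1.0, 3.0], [1.0, 3.0]]).inspect # => [0.25, 0.75]
```

Because the seeds are generated up front and indexed by tree number, the sequential and parallel branches of `fit` hand each tree the same seed.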
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 130
def marshal_dump
  super
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 136
def marshal_load(obj)
  super
end
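`marshal_dump` and `marshal_load` are the hooks that Ruby's `Marshal.dump` and `Marshal.load` call on an object, which is what lets a fitted model be saved and restored. The toy class below illustrates the round trip with a Hash-based state, matching the `Hash` return type documented above. `ToyModel` and its fields are hypothetical; the actual keys Rumale stores are determined by the parent class and are not shown on this page.

```ruby
# Toy class showing Ruby's marshal_dump / marshal_load hooks.
# `ToyModel` is hypothetical and only mirrors the Hash-based shape documented above.
class ToyModel
  attr_reader :params, :classes

  def initialize(n_estimators: 10)
    @params = { n_estimators: n_estimators }
    @classes = nil
  end

  # Stand-in for training: remember the sorted unique labels.
  def fit(labels)
    @classes = labels.uniq.sort
    self
  end

  # Called by Marshal.dump; returns the state to serialize.
  def marshal_dump
    { params: @params, classes: @classes }
  end

  # Called by Marshal.load on a freshly allocated instance (initialize is not run).
  def marshal_load(obj)
    @params = obj[:params]
    @classes = obj[:classes]
    nil
  end
end

model = ToyModel.new(n_estimators: 5).fit([2, 0, 1, 0])
restored = Marshal.load(Marshal.dump(model))
puts restored.classes.inspect # => [0, 1, 2]
puts restored.params[:n_estimators] # => 5
```

The same `Marshal.dump` / `Marshal.load` pair can be written to and read from a file opened in binary mode to persist a model between runs.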
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 105
def predict(x)
  check_sample_array(x)
  super
end
#predict_proba(x) ⇒ Numo::DFloat
Predict probability for samples.
# File 'lib/rumale/ensemble/extra_trees_classifier.rb', line 114
def predict_proba(x)
  check_sample_array(x)
  super
end
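Both `predict` and `predict_proba` delegate to the parent class via `super`, so this page does not show how per-tree outputs are combined. One common approach for tree ensembles is soft voting: average the class probabilities from each tree, then pick the most probable class. The sketch below assumes that approach purely for illustration; it is not confirmed to be what Rumale does, and `average_probabilities` and `predict_labels` are hypothetical names operating on plain arrays.

```ruby
# Sketch of soft voting; assumes averaging, which this page does not confirm for Rumale.
# tree_probas: one Array per tree, each holding per-sample class-probability rows.
def average_probabilities(tree_probas)
  n_trees = tree_probas.size.to_f
  n_samples = tree_probas.first.size
  Array.new(n_samples) do |i|
    # Collect sample i's row from every tree, then average column by column.
    rows = tree_probas.map { |tree| tree[i] }
    rows.transpose.map { |col| col.sum / n_trees }
  end
end

# Map each averaged row to the label of its largest probability.
def predict_labels(probas, classes)
  probas.map { |row| classes[row.index(row.max)] }
end

tree_probas = [
  [[1.0, 0.0], [0.5, 0.5]], # tree 1: two samples, two classes
  [[0.5, 0.5], [0.0, 1.0]]  # tree 2
]
probas = average_probabilities(tree_probas)
puts probas.inspect                         # => [[0.75, 0.25], [0.25, 0.75]]
puts predict_labels(probas, [3, 7]).inspect # => [3, 7]
```

The `classes` argument plays the role of the `#classes` attribute: column `j` of the probability matrix corresponds to the `j`-th sorted class label.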