Class: SVMKit::Ensemble::RandomForestClassifier
- Inherits:
- Object
- Object
- SVMKit::Ensemble::RandomForestClassifier
- Includes:
- Base::BaseEstimator, Base::Classifier
- Defined in:
- lib/svmkit/ensemble/random_forest_classifier.rb
Overview
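RandomForestClassifier implements a random forest for classification: an ensemble of decision trees, each fit on a bootstrap sample of the training data, whose predictions are combined by voting.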
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #estimators ⇒ Array<DecisionTreeClassifier> (readonly)
  Return the set of estimators.
- #feature_importances ⇒ Numo::DFloat (readonly)
  Return the importance for each feature.
- #rng ⇒ Random (readonly)
  Return the random generator used to draw the bootstrap samples.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #apply(x) ⇒ Numo::Int32
  Return the index of the leaf that each sample reached.
- #fit(x, y) ⇒ RandomForestClassifier
  Fit the model with given training data.
- #initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ RandomForestClassifier (constructor)
  Create a new classifier with random forest.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
- #predict_proba(x) ⇒ Numo::DFloat
  Predict probability for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ RandomForestClassifier
Create a new classifier with random forest.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 51
def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
               max_features: nil, random_seed: nil)
  @params = {}
  @params[:n_estimators] = n_estimators
  @params[:criterion] = criterion
  @params[:max_depth] = max_depth
  @params[:max_leaf_nodes] = max_leaf_nodes
  @params[:min_samples_leaf] = min_samples_leaf
  @params[:max_features] = max_features
  @params[:random_seed] = random_seed
  @params[:random_seed] ||= srand
  @rng = Random.new(@params[:random_seed])
  @estimators = nil
  @classes = nil
  @feature_importances = nil
end
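The constructor stores the seed in `@params[:random_seed]`, falling back to `srand` when none is given, and builds `@rng` from it. The following is a minimal sketch of that pattern in plain Ruby; `build_rng` is a hypothetical helper for illustration, not part of SVMKit.

```ruby
# Hypothetical helper mirroring the constructor's seed handling.
def build_rng(random_seed = nil)
  # Kernel#srand reseeds the global generator and returns the previous seed,
  # which provides an Integer to record when the caller supplied none.
  seed = random_seed || srand
  [seed, Random.new(seed)]
end

_seed_a, rng_a = build_rng(42)
_seed_b, rng_b = build_rng(42)

# Two generators built from the same seed yield the same sequence.
draws_a = Array.new(5) { rng_a.rand(0...100) }
draws_b = Array.new(5) { rng_b.rand(0...100) }
puts draws_a == draws_b # prints true
```

Passing an explicit `random_seed` is therefore what makes the random draws repeatable between runs.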
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 28
def classes
  @classes
end
#estimators ⇒ Array<DecisionTreeClassifier> (readonly)
Return the set of estimators.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 24
def estimators
  @estimators
end
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 32
def feature_importances
  @feature_importances
end
#rng ⇒ Random (readonly)
Return the random generator used to draw the bootstrap samples.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 36
def rng
  @rng
end
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 138
def apply(x)
  Numo::Int32[*Array.new(@params[:n_estimators]) { |n| @estimators[n].apply(x) }].transpose
end
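`apply` collects one vector of leaf indices per tree and transposes the result, so each row of the output corresponds to a sample and each column to a tree. A small sketch with plain Arrays standing in for Numo arrays; `stack_leaf_indices` and the leaf values are made up for illustration.

```ruby
# Hypothetical helper: per_tree_leaves has shape [n_estimators, n_samples].
def stack_leaf_indices(per_tree_leaves)
  # Transpose to [n_samples, n_estimators], matching the layout apply returns.
  per_tree_leaves.transpose
end

# Two trees, three samples (illustrative leaf indices).
leaves = stack_leaf_indices([[4, 1, 4], [2, 2, 5]])
p leaves # => [[4, 2], [1, 2], [4, 5]]
```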
#fit(x, y) ⇒ RandomForestClassifier
Fit the model with given training data.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 73
def fit(x, y)
  # Initialize some variables.
  n_samples, n_features = x.shape
  @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
  @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
  # Construct forest.
  @estimators = Array.new(@params[:n_estimators]) do |_n|
    tree = Tree::DecisionTreeClassifier.new(
      criterion: @params[:criterion], max_depth: @params[:max_depth],
      max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
      max_features: @params[:max_features], random_seed: @params[:random_seed]
    )
    bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
    tree.fit(x[bootstrap_ids, true], y[bootstrap_ids])
  end
  # Calculate feature importances.
  @feature_importances = Numo::DFloat.zeros(n_features)
  @estimators.each { |tree| @feature_importances += tree.feature_importances }
  @feature_importances /= @feature_importances.sum
  self
end
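Three steps in `fit` are easy to check in isolation: the clamping of `max_features`, the bootstrap draw, and the normalization of the summed feature importances. The sketch below reproduces them with plain Arrays in place of Numo arrays; the helper names are hypothetical and do not exist in SVMKit.

```ruby
# Clamp max_features as the listing does: default to n_features, keep at
# least 1, then cap at floor(sqrt(n_features)).
def effective_max_features(max_features, n_features)
  max_features = n_features unless max_features.is_a?(Integer)
  [[1, max_features].max, Math.sqrt(n_features).to_i].min
end

# Draw n_samples row indices with replacement (a bootstrap sample).
def bootstrap_ids(n_samples, rng)
  Array.new(n_samples) { rng.rand(0...n_samples) }
end

# Sum per-tree importances feature by feature and rescale so they add up to 1.
def normalized_importances(per_tree)
  totals = per_tree.transpose.map(&:sum)
  sum = totals.sum
  totals.map { |v| v / sum }
end

p effective_max_features(nil, 16) # => 4
p effective_max_features(2, 16)   # => 2
p bootstrap_ids(6, Random.new(1)).size # => 6
p normalized_importances([[0.5, 0.5, 0.0], [0.25, 0.25, 0.5]]) # => [0.375, 0.375, 0.25]
```

Read this way, the listing never lets a tree consider more than floor(sqrt(n_features)) features per split, whatever value is passed to the constructor.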
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 144
def marshal_dump
  { params: @params, estimators: @estimators, classes: @classes,
    feature_importances: @feature_importances, rng: @rng }
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 151
def marshal_load(obj)
  @params = obj[:params]
  @estimators = obj[:estimators]
  @classes = obj[:classes]
  @feature_importances = obj[:feature_importances]
  @rng = obj[:rng]
  nil
end
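`marshal_dump` and `marshal_load` are the hooks Ruby's `Marshal` module calls when an object is serialized and restored. The round trip can be sketched with a small stand-in class; `TinyModel` is hypothetical and only carries the parameters and the generator.

```ruby
# Hypothetical stand-in class; not part of SVMKit.
class TinyModel
  attr_reader :params, :rng

  def initialize(seed)
    @params = { random_seed: seed }
    @rng = Random.new(seed)
  end

  # Marshal.dump calls this to decide what gets serialized.
  def marshal_dump
    { params: @params, rng: @rng }
  end

  # Marshal.load calls this on a freshly allocated object to restore its state.
  def marshal_load(obj)
    @params = obj[:params]
    @rng = obj[:rng]
    nil
  end
end

model = TinyModel.new(42)
model.rng.rand # advance the generator before saving
copy = Marshal.load(Marshal.dump(model))

# The restored generator continues from the same state as the original.
puts copy.rng.rand == model.rng.rand # prints true
```

With the real classifier, the same `Marshal.dump` and `Marshal.load` calls would presumably be used to save and reload a fitted model.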
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 100
def predict(x)
  n_samples, = x.shape
  n_classes = @classes.size
  classes_arr = @classes.to_a
  ballot_box = Numo::DFloat.zeros(n_samples, n_classes)
  @estimators.each do |tree|
    predicted = tree.predict(x)
    n_samples.times do |n|
      class_id = classes_arr.index(predicted[n])
      ballot_box[n, class_id] += 1.0 unless class_id.nil?
    end
  end
  Numo::Int32[*Array.new(n_samples) { |n| @classes[ballot_box[n, true].max_index] }]
end
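`predict` is a majority vote: every tree casts one vote per sample, and the label with the most votes wins. A sketch of the same ballot-box logic with plain Arrays; `majority_vote` is a hypothetical helper, and its tie-breaking rule (first label wins) is a property of this sketch, not a confirmed behavior of Numo's `max_index`.

```ruby
# Hypothetical helper: per_tree_preds has shape [n_estimators, n_samples],
# classes is the sorted list of labels seen during fitting.
def majority_vote(per_tree_preds, classes)
  n_samples = per_tree_preds.first.size
  ballot_box = Array.new(n_samples) { Array.new(classes.size, 0.0) }
  per_tree_preds.each do |predicted|
    predicted.each_with_index do |label, n|
      class_id = classes.index(label)
      # Ignore labels that are not in the class list, as the listing does.
      ballot_box[n][class_id] += 1.0 unless class_id.nil?
    end
  end
  # Pick the label with the most votes; in this sketch ties go to the first label.
  ballot_box.map { |votes| classes[votes.index(votes.max)] }
end

# Three trees voting on three samples.
p majority_vote([[0, 1, 2], [0, 2, 2], [1, 1, 2]], [0, 1, 2]) # => [0, 1, 2]
```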
#predict_proba(x) ⇒ Numo::DFloat
Predict probability for samples.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 119
def predict_proba(x)
  n_samples, = x.shape
  n_classes = @classes.size
  classes_arr = @classes.to_a
  ballot_box = Numo::DFloat.zeros(n_samples, n_classes)
  @estimators.each do |tree|
    probs = tree.predict_proba(x)
    tree.classes.size.times do |n|
      class_id = classes_arr.index(tree.classes[n])
      ballot_box[true, class_id] += probs[true, n] unless class_id.nil?
    end
  end
  (ballot_box.transpose / ballot_box.sum(axis: 1)).transpose
end
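`predict_proba` sums each tree's class probabilities into the forest-wide class columns and then divides every row by its total. The column lookup matters because a tree fit on a bootstrap sample may not have seen every class. A sketch with plain Arrays; `average_proba` and the input layout are assumptions made for illustration.

```ruby
# Hypothetical helper. Each entry of tree_outputs pairs a tree's own class
# list with its [n_samples, n_tree_classes] probability rows.
def average_proba(tree_outputs, classes)
  n_samples = tree_outputs.first[:probs].size
  ballot_box = Array.new(n_samples) { Array.new(classes.size, 0.0) }
  tree_outputs.each do |out|
    out[:classes].each_with_index do |label, k|
      class_id = classes.index(label)
      next if class_id.nil?

      # Add the tree's column k to the forest-wide column for that label.
      n_samples.times { |n| ballot_box[n][class_id] += out[:probs][n][k] }
    end
  end
  # Normalize each row so the probabilities sum to 1.
  ballot_box.map do |row|
    total = row.sum
    row.map { |v| v / total }
  end
end

trees = [
  { classes: [0, 1], probs: [[1.0, 0.0], [0.5, 0.5]] },
  { classes: [1, 2], probs: [[0.0, 1.0], [0.5, 0.5]] } # this tree never saw class 0
]
p average_proba(trees, [0, 1, 2]) # => [[0.5, 0.0, 0.5], [0.25, 0.5, 0.25]]
```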