Class: SVMKit::Tree::DecisionTreeClassifier
- Inherits: Object
- Includes: Base::BaseEstimator, Base::Classifier
- Defined in: lib/svmkit/tree/decision_tree_classifier.rb
Overview
DecisionTreeClassifier is a class that implements a decision tree for classification.
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #feature_importances ⇒ Numo::DFloat (readonly)
  Return the importance for each feature.
- #leaf_labels ⇒ Numo::Int32 (readonly)
  Return the labels assigned to each leaf.
- #rng ⇒ Random (readonly)
  Return the random generator used for random sampling of candidate features.
- #tree ⇒ OpenStruct (readonly)
  Return the learned tree.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #apply(x) ⇒ Numo::Int32
  Return the index of the leaf that each sample reached.
- #fit(x, y) ⇒ DecisionTreeClassifier
  Fit the model with the given training data.
- #initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier (constructor)
  Create a new classifier with the decision tree algorithm.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
- #predict_proba(x) ⇒ Numo::DFloat
  Predict class probabilities for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier
Create a new classifier with the decision tree algorithm.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 55
def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
  @params = {}
  @params[:criterion] = criterion
  @params[:max_depth] = max_depth
  @params[:max_leaf_nodes] = max_leaf_nodes
  @params[:min_samples_leaf] = min_samples_leaf
  @params[:max_features] = max_features
  @params[:random_seed] = random_seed
  @params[:random_seed] ||= srand
  @rng = Random.new(@params[:random_seed])
  @tree = nil
  @classes = nil
  @feature_importances = nil
  @n_leaves = nil
  @leaf_labels = nil
end
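The criterion parameter selects the impurity measure used to score candidate splits. The sketch below is a minimal plain-Ruby illustration using the standard textbook definitions of Gini impurity and entropy. Treating 'entropy' as the alternative criterion is an assumption (only 'gini' appears on this page), and the helper names are hypothetical, not SVMKit methods.

```ruby
# Hypothetical helpers illustrating split criteria; the formulas are the
# standard definitions of Gini impurity and entropy, not SVMKit's own code.

# Fraction of samples that carry each distinct label.
def class_fractions(labels)
  n = labels.size.to_f
  labels.tally.values.map { |count| count / n }
end

# Gini impurity: 1 - sum over classes of p_k^2.
def gini(labels)
  1.0 - class_fractions(labels).sum { |frac| frac * frac }
end

# Entropy: -(sum over classes of p_k * log2(p_k)).
def entropy(labels)
  -(class_fractions(labels).sum { |frac| frac * Math.log2(frac) })
end

# Sample-weighted impurity reduction achieved by splitting parent into left and right.
def impurity_decrease(parent, left, right, measure)
  n = parent.size.to_f
  send(measure, parent) -
    (left.size / n) * send(measure, left) -
    (right.size / n) * send(measure, right)
end

labels = [0, 0, 1, 1]
puts gini(labels)                                     # 0.5
puts entropy(labels)                                  # 1.0
puts impurity_decrease(labels, [0, 0], [1, 1], :gini) # 0.5
```

In this example both children are pure, so the impurity decrease equals the parent's impurity; a tree grower would prefer this split over any split that leaves mixed children.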
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 25
def classes
  @classes
end
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 29
def feature_importances
  @feature_importances
end
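This page does not show how eval_importance fills feature_importances. A common definition, assumed here, is mean decrease in impurity: each split credits its feature with the sample-weighted impurity reduction it achieves, and the totals are normalized to sum to 1. The sketch uses hypothetical OpenStruct node fields (leaf, feature_id, n_samples, impurity, left, right) that may differ from SVMKit's internal layout.

```ruby
require 'ostruct'

# Assumed importance rule (mean decrease in impurity), not SVMKit's eval_importance.
# Hypothetical node fields: leaf, feature_id, n_samples, impurity, left, right.
def accumulate_importance(node, importance)
  return if node.leaf

  # Sample-weighted impurity reduction produced by this split.
  decrease = node.n_samples * node.impurity -
             node.left.n_samples * node.left.impurity -
             node.right.n_samples * node.right.impurity
  importance[node.feature_id] += decrease
  accumulate_importance(node.left, importance)
  accumulate_importance(node.right, importance)
end

# One value per feature, normalized to sum to 1 when the tree has any split.
def feature_importances(root, n_features)
  importance = Array.new(n_features, 0.0)
  accumulate_importance(root, importance)
  total = importance.sum
  total.zero? ? importance : importance.map { |v| v / total }
end

# Example tree: feature 1 splits the root, feature 0 splits its right child.
def example_tree
  leaf = ->(n) { OpenStruct.new(leaf: true, n_samples: n, impurity: 0.0) }
  right = OpenStruct.new(leaf: false, feature_id: 0, n_samples: 2, impurity: 0.5,
                         left: leaf.call(1), right: leaf.call(1))
  OpenStruct.new(leaf: false, feature_id: 1, n_samples: 4, impurity: 0.5,
                 left: leaf.call(2), right: right)
end

p feature_importances(example_tree, 3) # [0.5, 0.5, 0.0]
```

Features that never appear in a split keep an importance of 0.0, as feature 2 does here.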
#leaf_labels ⇒ Numo::Int32 (readonly)
Return the labels assigned to each leaf.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 41
def leaf_labels
  @leaf_labels
end
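This page does not show how leaf_labels is filled during fitting. A conventional choice, assumed here, is to label each leaf with the most frequent training label among the samples that reach it. The helper below is hypothetical; leaf_ids plays the role of the leaf indices that apply returns for the training samples.

```ruby
# Hypothetical rule: each leaf is labelled with the most frequent training label
# among the samples that reached it (ties go to the smaller label).
# leaf_ids[i] is the leaf index of training sample i, as apply would return.
def majority_leaf_labels(leaf_ids, y, n_leaves)
  Array.new(n_leaves) do |leaf|
    members = leaf_ids.each_index.select { |i| leaf_ids[i] == leaf }.map { |i| y[i] }
    # tally counts each label; max_by returns nil for a leaf no sample reached.
    best = members.tally.max_by { |label, count| [count, -label] }
    best&.first
  end
end

p majority_leaf_labels([0, 0, 1, 1, 1], [3, 3, 5, 5, 7], 2) # [3, 5]
```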
#rng ⇒ Random (readonly)
Return the random generator used for random sampling of candidate features.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 37
def rng
  @rng
end
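The constructor builds rng as Random.new(random_seed), so fixing random_seed fixes every random draw the model makes. The sketch below assumes those draws select max_features candidate features per split via Array#sample; the helper name candidate_features is hypothetical.

```ruby
# Hypothetical helper: draw max_features distinct candidate feature indices.
# Assumption: feature subsampling uses Array#sample with the model's rng.
def candidate_features(n_features, max_features, rng)
  (0...n_features).to_a.sample(max_features, random: rng)
end

# Two generators built from the same seed yield identical draw sequences.
rng_a = Random.new(42)
rng_b = Random.new(42)
draws_a = Array.new(3) { candidate_features(10, 4, rng_a) }
draws_b = Array.new(3) { candidate_features(10, 4, rng_b) }
puts draws_a == draws_b # true
```

Passing a generator through the random: keyword, rather than relying on the global Kernel#rand state, keeps one model's draws independent of other code that uses random numbers.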
#tree ⇒ OpenStruct (readonly)
Return the learned tree.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 33
def tree
  @tree
end
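The learned tree is returned as nested OpenStruct nodes. The field names below (leaf, feature_id, threshold, left, right, leaf_id) are hypothetical stand-ins for SVMKit's internal layout. The sketch shows how such a structure can be walked to measure the two quantities that max_leaf_nodes and max_depth constrain, assuming depth counts the splits on the longest root-to-leaf path.

```ruby
require 'ostruct'

# Hypothetical node layout: internal nodes hold feature_id, threshold, left and
# right; leaves hold leaf: true and a leaf_id.
def example_tree
  OpenStruct.new(
    leaf: false, feature_id: 0, threshold: 0.5,
    left: OpenStruct.new(leaf: true, leaf_id: 0),
    right: OpenStruct.new(
      leaf: false, feature_id: 1, threshold: 2.0,
      left: OpenStruct.new(leaf: true, leaf_id: 1),
      right: OpenStruct.new(leaf: true, leaf_id: 2)
    )
  )
end

# Number of leaves (the quantity max_leaf_nodes limits).
def leaf_count(node)
  node.leaf ? 1 : leaf_count(node.left) + leaf_count(node.right)
end

# Splits on the longest root-to-leaf path (the quantity max_depth limits).
def depth(node)
  node.leaf ? 0 : 1 + [depth(node.left), depth(node.right)].max
end

tree = example_tree
puts leaf_count(tree) # 3
puts depth(tree)      # 2
```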
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 109
def apply(x)
  Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
end
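apply sends each row of x down the tree through the private helper apply_at_node, whose body is not shown on this page. The sketch below is a hypothetical plain-Ruby stand-in that uses arrays instead of Numo arrays and assumes a sample goes left when its feature value is less than or equal to the node's threshold.

```ruby
require 'ostruct'

# Hypothetical stand-in for apply_at_node. Assumption: a sample goes left
# when its value for the node's feature is <= the node's threshold.
def apply_at_node(node, sample)
  until node.leaf
    node = sample[node.feature_id] <= node.threshold ? node.left : node.right
  end
  node.leaf_id
end

# Plain-array version of apply: one leaf index per row of x.
def apply(tree, x)
  x.map { |sample| apply_at_node(tree, sample) }
end

# Hypothetical two-split tree with leaves 0, 1 and 2.
def example_tree
  OpenStruct.new(
    leaf: false, feature_id: 0, threshold: 0.5,
    left: OpenStruct.new(leaf: true, leaf_id: 0),
    right: OpenStruct.new(
      leaf: false, feature_id: 1, threshold: 2.0,
      left: OpenStruct.new(leaf: true, leaf_id: 1),
      right: OpenStruct.new(leaf: true, leaf_id: 2)
    )
  )
end

p apply(example_tree, [[0.2, 9.0], [0.9, 1.0], [0.9, 3.0]]) # [0, 1, 2]
```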
#fit(x, y) ⇒ DecisionTreeClassifier
Fit the model with the given training data.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 78
def fit(x, y)
  n_samples, n_features = x.shape
  @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
  @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
  build_tree(x, y)
  eval_importance(n_samples, n_features)
  self
end
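Before building the tree, fit normalizes max_features and records the sorted distinct labels. The sketch below restates those lines with plain Ruby values in place of Numo arrays; the helper names are hypothetical. Because of the is_a?(Integer) check, any non-Integer value, including a Float such as 0.5, falls back to using all features rather than being read as a fraction.

```ruby
# Restates the max_features handling at the top of fit with plain Ruby values.
# Hypothetical helper names; Numo arrays are replaced by Arrays.
def resolve_max_features(max_features, n_features)
  # Anything that is not an Integer (nil, Float, ...) means "use all features".
  max_features = n_features unless max_features.is_a?(Integer)
  # Clamp into the valid range 1..n_features.
  [[1, max_features].max, n_features].min
end

# Sorted distinct labels, as fit stores in classes.
def class_labels(y)
  y.uniq.sort
end

puts resolve_max_features(nil, 5) # 5
puts resolve_max_features(0, 5)   # 1
puts resolve_max_features(10, 5)  # 5
puts resolve_max_features(0.5, 5) # 5
p class_labels([2, 0, 2, 1])      # [0, 1, 2]
```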
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 115
def marshal_dump
  { params: @params,
    classes: @classes,
    tree: @tree,
    feature_importances: @feature_importances,
    leaf_labels: @leaf_labels,
    rng: @rng }
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 123
def marshal_load(obj)
  @params = obj[:params]
  @classes = obj[:classes]
  @tree = obj[:tree]
  @feature_importances = obj[:feature_importances]
  @leaf_labels = obj[:leaf_labels]
  @rng = obj[:rng]
  nil
end
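marshal_dump and marshal_load let Ruby's Marshal serialize a fitted model by converting its state to a Hash and back. The class below is a hypothetical toy that follows the same pattern for two fields, so the round trip runs without SVMKit or Numo. Since the generator is part of the dumped state, the restored copy continues the same random sequence.

```ruby
# Hypothetical toy class following the marshal_dump / marshal_load pattern.
class ToyModel
  attr_reader :params, :rng

  def initialize(random_seed: 1)
    @params = { random_seed: random_seed }
    @rng = Random.new(random_seed)
  end

  # Called by Marshal.dump; the returned Hash is what gets serialized.
  def marshal_dump
    { params: @params, rng: @rng }
  end

  # Called by Marshal.load on a freshly allocated instance.
  def marshal_load(obj)
    @params = obj[:params]
    @rng = obj[:rng]
    nil
  end
end

model = ToyModel.new(random_seed: 7)
copy = Marshal.load(Marshal.dump(model))
puts copy.params == model.params     # true
puts copy.rng.rand == model.rng.rand # true
```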
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 92
def predict(x)
  @leaf_labels[apply(x)]
end
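predict indexes leaf_labels with the leaf indices returned by apply. Indexing a Numo array with an integer index array gathers the selected elements; the plain-Ruby equivalent is Array#values_at. The helper name below is hypothetical and the label values are illustrative.

```ruby
# Plain-Ruby analogue of @leaf_labels[apply(x)]: gather one label per leaf index.
def predict_from_leaves(leaf_labels, leaf_ids)
  leaf_labels.values_at(*leaf_ids)
end

leaf_labels = [3, 5, 3] # label of leaf 0, 1 and 2 (example values)
leaf_ids = [0, 1, 2, 1] # leaf reached by each of four samples
p predict_from_leaves(leaf_labels, leaf_ids) # [3, 5, 3, 5]
```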
#predict_proba(x) ⇒ Numo::DFloat
Predict class probabilities for samples.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 100
def predict_proba(x)
  probs = Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })]
  probs[true, @classes]
end
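In probs[true, @classes], true keeps every row and @classes selects columns by index, so the class labels themselves act as column indices. The sketch below only illustrates what that indexing does to a 2-D array, using nested Ruby arrays; how predict_at_node lays out each row is not shown on this page, and the helper names are hypothetical. It also shows how a label can be recovered from the selected columns by taking the column with the largest probability.

```ruby
# Plain-Ruby analogue of probs[true, classes]: keep, in every row, only the
# columns whose indices are listed in classes.
def select_class_columns(probs, classes)
  probs.map { |row| row.values_at(*classes) }
end

# Map each row to the class whose column holds the highest probability.
def argmax_labels(probs, classes)
  probs.map { |row| classes[row.each_with_index.max_by { |pr, _i| pr }.last] }
end

probs = [[0.1, 0.0, 0.9], [0.6, 0.0, 0.4]]
classes = [0, 2]
selected = select_class_columns(probs, classes)
p selected                         # [[0.1, 0.9], [0.6, 0.4]]
p argmax_labels(selected, classes) # [2, 0]
```

Because the labels double as column indices, this indexing is only meaningful when class labels are non-negative integers within the width of each row.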