Class: Rumale::Tree::DecisionTreeClassifier
- Inherits:
- BaseDecisionTree
  - Object
    - BaseDecisionTree
      - Rumale::Tree::DecisionTreeClassifier
- Includes:
- Base::Classifier
- Defined in:
- lib/rumale/tree/decision_tree_classifier.rb
Overview
DecisionTreeClassifier is a class that implements a decision tree for classification.
Direct Known Subclasses
Instance Attribute Summary
-
#classes ⇒ Numo::Int32
readonly
Return the class labels.
-
#feature_importances ⇒ Numo::DFloat
readonly
Return the importance for each feature.
-
#leaf_labels ⇒ Numo::Int32
readonly
Return the labels assigned to each leaf.
-
#rng ⇒ Random
readonly
Return the random generator used to select feature indices.
-
#tree ⇒ Node
readonly
Return the learned tree.
Attributes included from Base::BaseEstimator
Instance Method Summary
-
#fit(x, y) ⇒ DecisionTreeClassifier
Fit the model with given training data.
-
#initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier
constructor
Create a new classifier that uses the decision tree algorithm.
-
#marshal_dump ⇒ Hash
Dump marshal data.
-
#marshal_load(obj) ⇒ nil
Load marshal data.
-
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
-
#predict_proba(x) ⇒ Numo::DFloat
Predict probability for samples.
Methods included from Base::Classifier
Methods inherited from BaseDecisionTree
Constructor Details
#initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier
Create a new classifier that uses the decision tree algorithm.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 54
    def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                   random_seed: nil)
      check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                        max_features: max_features, random_seed: random_seed)
      check_params_integer(min_samples_leaf: min_samples_leaf)
      check_params_string(criterion: criterion)
      check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                            min_samples_leaf: min_samples_leaf, max_features: max_features)
      super
      @leaf_labels = nil
    end
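The default `criterion` is `'gini'`. As a rough illustration of what that criterion measures, the sketch below computes the Gini impurity of a list of labels in plain Ruby. It is not Rumale's implementation, which is not shown on this page, and the helper name `gini_impurity` is hypothetical.

```ruby
# Gini impurity of a list of class labels: 1 - sum_k p_k^2,
# where p_k is the fraction of samples that carry label k.
# Hypothetical helper for illustration; not part of Rumale.
def gini_impurity(labels)
  return 0.0 if labels.empty?

  n = labels.size.to_f
  1.0 - labels.group_by(&:itself).values.sum { |group| (group.size / n)**2 }
end

pure = gini_impurity([1, 1, 1, 1])   # every sample has the same label
mixed = gini_impurity([0, 0, 1, 1])  # two classes, evenly split
```

A node whose samples all share one label has impurity 0, and the value grows as the labels become more evenly mixed. With Rumale installed, a call that follows the signature above would look like `Rumale::Tree::DecisionTreeClassifier.new(criterion: 'gini', max_depth: 3, random_seed: 1)`.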
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 24
    def classes
      @classes
    end
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 28
    def feature_importances
      @feature_importances
    end
#leaf_labels ⇒ Numo::Int32 (readonly)
Return the labels assigned to each leaf.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 40
    def leaf_labels
      @leaf_labels
    end
#rng ⇒ Random (readonly)
Return the random generator used to select feature indices.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 36
    def rng
      @rng
    end
#tree ⇒ Node (readonly)
Return the learned tree.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 32
    def tree
      @tree
    end
Instance Method Details
#fit(x, y) ⇒ DecisionTreeClassifier
Fit the model with given training data.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 71
    def fit(x, y)
      check_sample_array(x)
      check_label_array(y)
      check_sample_label_size(x, y)
      n_samples, n_features = x.shape
      @params[:max_features] = n_features if @params[:max_features].nil?
      @params[:max_features] = [@params[:max_features], n_features].min
      uniq_y = y.to_a.uniq.sort
      @classes = Numo::Int32.asarray(uniq_y)
      @n_leaves = 0
      @leaf_labels = []
      @sub_rng = @rng.dup
      build_tree(x, y.map { |v| uniq_y.index(v) })
      eval_importance(n_samples, n_features)
      @leaf_labels = Numo::Int32[*@leaf_labels]
      self
    end
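Two preprocessing steps in `#fit` are easy to miss: `max_features` is defaulted and then capped at the number of features, and the labels are re-encoded as indices into the sorted unique label list before the tree is built. The sketch below replays both steps on plain Ruby arrays instead of Numo arrays. The helper name `resolve_max_features` and the sample values are assumptions made for illustration.

```ruby
# Plain-Ruby stand-in for the label vector passed to #fit
# (the real method receives a Numo array).
y = [3, 7, 3, 9, 7]

# Sorted unique labels; #fit stores these as the `classes` attribute.
uniq_y = y.uniq.sort

# Replace each label with its position in uniq_y,
# mirroring y.map { |v| uniq_y.index(v) } in #fit.
encoded_y = y.map { |v| uniq_y.index(v) }

# Hypothetical helper: nil falls back to n_features,
# and any value is capped at n_features.
def resolve_max_features(max_features, n_features)
  [max_features || n_features, n_features].min
end
```

Under these assumptions the tree is built on labels in the range `0...uniq_y.size`, whatever values the original labels take.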
#marshal_dump ⇒ Hash
Dump marshal data.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 109
    def marshal_dump
      { params: @params,
        classes: @classes,
        tree: @tree,
        feature_importances: @feature_importances,
        leaf_labels: @leaf_labels,
        rng: @rng }
    end
#marshal_load(obj) ⇒ nil
Load marshal data.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 120
    def marshal_load(obj)
      @params = obj[:params]
      @classes = obj[:classes]
      @tree = obj[:tree]
      @feature_importances = obj[:feature_importances]
      @leaf_labels = obj[:leaf_labels]
      @rng = obj[:rng]
      nil
    end
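`#marshal_dump` and `#marshal_load` are the hooks that Ruby's `Marshal` module calls when an object is serialized and restored. The sketch below shows the same protocol on a hypothetical `ToyModel` class, so it runs without Rumale installed. The class and its fields are assumptions made for illustration.

```ruby
# Hypothetical toy class, used only to show the marshal_dump / marshal_load protocol.
class ToyModel
  attr_reader :params, :classes

  def initialize(params, classes)
    @params = params
    @classes = classes
  end

  # Marshal.dump serializes whatever this method returns.
  def marshal_dump
    { params: @params, classes: @classes }
  end

  # Marshal.load allocates a bare object (without calling initialize)
  # and hands it the dumped hash.
  def marshal_load(obj)
    @params = obj[:params]
    @classes = obj[:classes]
    nil
  end
end

model = ToyModel.new({ max_depth: 3 }, [0, 1, 2])
restored = Marshal.load(Marshal.dump(model))
```

A fitted classifier should round-trip the same way, for example `Marshal.load(Marshal.dump(estimator))`, where `estimator` is a placeholder name for a fitted instance.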
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 93
    def predict(x)
      check_sample_array(x)
      @leaf_labels[apply(x)].dup
    end
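The body of `#predict` is a single index lookup. Assuming the inherited `apply(x)` returns the id of the leaf that each sample reaches (its definition is not shown on this page), the lookup can be sketched in plain Ruby as follows. All values are made up for illustration.

```ruby
# Label stored for each leaf, indexed by leaf id (stand-in for @leaf_labels).
leaf_labels = [2, 0, 1]

# Leaf id reached by each sample (assumed shape of the result of apply(x)).
leaf_ids = [1, 1, 0, 2]

# One predicted label per sample, looked up by leaf id.
predicted = leaf_ids.map { |leaf_id| leaf_labels[leaf_id] }
```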
#predict_proba(x) ⇒ Numo::DFloat
Predict probability for samples.
    # File 'lib/rumale/tree/decision_tree_classifier.rb', line 102
    def predict_proba(x)
      check_sample_array(x)
      Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_proba_at_node(@tree, x[n, true]) })]
    end
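`#predict_proba` builds one row of probabilities per sample by calling `predict_proba_at_node` on the root of the tree. That helper is not shown on this page, so the sketch below is only a guess at the general idea: walk from the root to a leaf and return the class probabilities stored there. The `ToyNode` layout, the `<=` split rule, and the name `proba_at_node` are all assumptions.

```ruby
# Hypothetical node layout; Rumale's Node class is not shown on this page.
ToyNode = Struct.new(:leaf, :feature_id, :threshold, :left, :right, :probs, keyword_init: true)

# Walk from the given node down to a leaf and return the probabilities stored there.
def proba_at_node(node, sample)
  return node.probs if node.leaf

  # Assumed split rule: go left when the feature value is at most the threshold.
  child = sample[node.feature_id] <= node.threshold ? node.left : node.right
  proba_at_node(child, sample)
end

# A one-split tree over feature 0.
tree = ToyNode.new(
  leaf: false, feature_id: 0, threshold: 0.5,
  left: ToyNode.new(leaf: true, probs: [0.9, 0.1]),
  right: ToyNode.new(leaf: true, probs: [0.2, 0.8])
)

samples = [[0.1, 3.0], [0.7, 1.0]]

# One row per sample, as in Array.new(x.shape[0]) { |n| ... } above.
probabilities = samples.map { |sample| proba_at_node(tree, sample) }
```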