Class: Rumale::Tree::BaseDecisionTree
- Inherits: Object (hierarchy: Object > Rumale::Tree::BaseDecisionTree)
- Includes: Base::BaseEstimator
- Defined in: lib/rumale/tree/base_decision_tree.rb
Overview
BaseDecisionTree is an abstract class for implementing decision tree-based estimators. This class is intended for internal use.
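The abstract-base pattern described above can be sketched in plain Ruby. This is a minimal illustration, not Rumale's code: the class names `ToyBaseTree` and `ToyVarianceTree` and the `impurity` hook are hypothetical. The base class stores shared hyperparameters in a `@params` hash (as `BaseDecisionTree#initialize` does), and a concrete subclass supplies the estimator-specific piece.

```ruby
# Hypothetical sketch of an abstract tree base class; not Rumale's implementation.
class ToyBaseTree
  attr_reader :params

  def initialize(max_depth: nil, min_samples_leaf: 1)
    # Shared hyperparameters are kept in a single hash.
    @params = { max_depth: max_depth, min_samples_leaf: min_samples_leaf }
    # No tree exists until a concrete subclass fits one.
    @tree = nil
  end

  # Hook that concrete subclasses must override (hypothetical name).
  def impurity(_y)
    raise NotImplementedError, "#{self.class} must implement #impurity"
  end
end

# Hypothetical concrete subclass using variance as the impurity measure.
class ToyVarianceTree < ToyBaseTree
  def impurity(y)
    mean = y.sum.fdiv(y.size)
    y.sum { |v| (v - mean)**2 } / y.size
  end
end

p ToyVarianceTree.new(max_depth: 3).impurity([1.0, 2.0, 3.0])
```

Calling `impurity` on the base class raises `NotImplementedError`, which is one common Ruby idiom for marking a class as abstract.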
Direct Known Subclasses
Instance Attribute Summary
Attributes included from Base::BaseEstimator
Instance Method Summary
- #apply(x) ⇒ Numo::Int32
  Return the index of the leaf that each sample reached.
- #initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ BaseDecisionTree (constructor)
  Initialize a decision tree-based estimator.
Constructor Details
#initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ BaseDecisionTree
Initialize a decision tree-based estimator.
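```ruby
# File 'lib/rumale/tree/base_decision_tree.rb', line 26
def initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
  # Store every hyperparameter in a single hash.
  @params = {}
  @params[:criterion] = criterion
  @params[:max_depth] = max_depth
  @params[:max_leaf_nodes] = max_leaf_nodes
  @params[:min_samples_leaf] = min_samples_leaf
  @params[:max_features] = max_features
  @params[:random_seed] = random_seed
  # With no argument, srand reseeds Ruby's global RNG and returns the previous
  # seed, so a nil random_seed is replaced by an Integer seed.
  @params[:random_seed] ||= srand
  # Model state; remains nil until the estimator is fitted.
  @tree = nil
  @feature_importances = nil
  @n_leaves = nil
  # Per-estimator random number generator built from the stored seed.
  @rng = Random.new(@params[:random_seed])
end
```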
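The `@params[:random_seed] ||= srand` line is easy to misread, so here is a small standalone sketch of its effect. The helper name `resolve_seed` is hypothetical; only `Kernel#srand` and `Random.new` are standard Ruby. It shows that a `nil` seed becomes an Integer, that any explicit seed (including `0`, which is truthy in Ruby) is kept, and that two `Random` instances built from the same seed produce the same draws.

```ruby
# Hypothetical helper mirroring `@params[:random_seed] ||= srand`.
def resolve_seed(random_seed)
  # Kernel#srand with no argument reseeds Ruby's global RNG and returns the
  # previous seed, so a nil seed becomes an Integer; a non-nil seed is kept.
  random_seed || srand
end

seed = resolve_seed(nil)
rng_a = Random.new(seed)
rng_b = Random.new(seed)
# Generators built from the same seed produce the same sequence of draws.
draws_a = Array.new(5) { rng_a.rand }
draws_b = Array.new(5) { rng_b.rand }
puts draws_a == draws_b
```

Because the resolved seed is written back into `@params`, the hash records the seed that was actually used, so the same `@rng` stream can be recreated by passing that value as `random_seed`.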
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
```ruby
# File 'lib/rumale/tree/base_decision_tree.rb', line 45
def apply(x)
  check_sample_array(x)
  Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
end
```
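`#apply` passes each row (`x[n, true]` selects row `n` of a Numo array) to `apply_at_node`, whose body is not shown on this page. The sketch below is one plausible version of that traversal, written with plain Ruby arrays so it runs without Numo. Its assumptions are loudly hypothetical: the `Node` struct and its field names, and the rule that a sample goes left when its feature value is `<=` the threshold.

```ruby
# Hypothetical node layout; the field names are assumptions for illustration.
Node = Struct.new(:leaf, :leaf_id, :feature_id, :threshold, :left, :right, keyword_init: true)

# Recursively descend from a node to a leaf and return that leaf's index.
# Assumption: feature value <= threshold goes to the left child.
def apply_at_node(node, sample)
  return node.leaf_id if node.leaf

  if sample[node.feature_id] <= node.threshold
    apply_at_node(node.left, sample)
  else
    apply_at_node(node.right, sample)
  end
end

# Plain-Array analogue of #apply: one leaf index per row.
def apply(tree, x)
  x.map { |row| apply_at_node(tree, row) }
end

# Split on feature 0 at 0.5; the right branch splits again on feature 1 at 2.0.
tree = Node.new(leaf: false, feature_id: 0, threshold: 0.5,
                left: Node.new(leaf: true, leaf_id: 1),
                right: Node.new(leaf: false, feature_id: 1, threshold: 2.0,
                                left: Node.new(leaf: true, leaf_id: 3),
                                right: Node.new(leaf: true, leaf_id: 4)))
p apply(tree, [[0.2, 9.0], [0.8, 1.0], [0.8, 3.0]]) # prints [1, 3, 4]
```

Two samples that land in the same leaf receive the same index, which is what makes the output of `#apply` usable as a leaf-membership encoding.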