Class: Rumale::Tree::BaseDecisionTree

Inherits:
Object
Includes:
Base::BaseEstimator
Defined in:
lib/rumale/tree/base_decision_tree.rb

Overview

BaseDecisionTree is an abstract class for implementing decision tree-based estimators. This class is used internally.

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Instance Method Summary

Constructor Details

#initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ BaseDecisionTree

Initialize a decision tree-based estimator.

Parameters:

  • criterion (String) (defaults to: nil)

    The function used to evaluate the quality of a splitting point.

  • max_depth (Integer) (defaults to: nil)

    The maximum depth of the tree. If nil is given, the decision tree grows without a depth limit.

  • max_leaf_nodes (Integer) (defaults to: nil)

    The maximum number of leaves in the decision tree. If nil is given, the number of leaves is not limited.

  • min_samples_leaf (Integer) (defaults to: 1)

    The minimum number of samples at a leaf node.

  • max_features (Integer) (defaults to: nil)

    The number of features to consider when searching for the optimal split point. If nil is given, the split process considers all features.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random number generator. It is used to randomly determine the order of features when deciding the splitting point.
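To show how criterion, max_features, min_samples_leaf, and random_seed can interact during a split search, here is a minimal plain-Ruby sketch. It assumes Gini impurity as the criterion (one common choice, not necessarily the one a given subclass uses), midpoints between distinct values as candidate thresholds, and a seeded shuffle of feature indices truncated to max_features. The helpers gini and best_split are hypothetical and are not part of Rumale; the library's own split search is implemented differently.

```ruby
# Hypothetical sketch, not Rumale's implementation.

# Gini impurity of a list of class labels: 1 - sum(p_k^2).
def gini(labels)
  return 0.0 if labels.empty?
  n = labels.size.to_f
  1.0 - labels.tally.values.sum { |c| (c / n)**2 }
end

# Search for the split with the lowest weighted child impurity.
def best_split(x, y, max_features: nil, min_samples_leaf: 1, random_seed: 1)
  rng = Random.new(random_seed)
  n_features = x.first.size
  # Seeded shuffle: the feature order is random but reproducible.
  features = (0...n_features).to_a.shuffle(random: rng)
  # Only the first max_features features are considered (all if nil).
  features = features.first(max_features) if max_features
  best = nil
  features.each do |f|
    values = x.map { |row| row[f] }.uniq.sort
    # Candidate thresholds are midpoints between consecutive distinct values.
    values.each_cons(2) do |a, b|
      threshold = (a + b) / 2.0
      go_left = y.each_index.select { |i| x[i][f] <= threshold }
      left = go_left.map { |i| y[i] }
      right = (y.each_index.to_a - go_left).map { |i| y[i] }
      # Reject splits that leave too few samples in either child.
      next if left.size < min_samples_leaf || right.size < min_samples_leaf
      score = (left.size * gini(left) + right.size * gini(right)) / y.size.to_f
      if best.nil? || score < best[:impurity]
        best = { feature: f, threshold: threshold, impurity: score }
      end
    end
  end
  best
end

# Feature 0 separates the two classes perfectly; feature 1 does not.
x = [[1.0, 7.0], [2.0, 3.0], [3.0, 7.0], [4.0, 3.0]]
y = [0, 0, 1, 1]
p best_split(x, y)
p best_split(x, y, max_features: 1, random_seed: 7)
```

Because the shuffle draws from a generator built from random_seed, the same seed always yields the same feature subset, which is what makes restricted searches (max_features smaller than the feature count) reproducible.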



# File 'lib/rumale/tree/base_decision_tree.rb', line 26

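def initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
  # Hyperparameters are stored in the params hash exposed through Base::BaseEstimator#params.
  @params = {}
  @params[:criterion] = criterion
  @params[:max_depth] = max_depth
  @params[:max_leaf_nodes] = max_leaf_nodes
  @params[:min_samples_leaf] = min_samples_leaf
  @params[:max_features] = max_features
  @params[:random_seed] = random_seed
  # Without an explicit seed, Kernel#srand reseeds the global generator and returns the previous seed.
  @params[:random_seed] ||= srand
  # Fitted state; these remain nil until the tree is trained.
  @tree = nil
  @feature_importances = nil
  @n_leaves = nil
  # Dedicated generator used to randomize the feature order during split search.
  @rng = Random.new(@params[:random_seed])
end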
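The line `@params[:random_seed] ||= srand` relies on a property of Kernel#srand: called without an argument, it reseeds Ruby's global generator and returns the previous seed, so a nil seed is replaced by some Integer that can be stored and reused. The sketch below isolates that idiom in a hypothetical helper, resolve_seed, and checks that two generators created from the same seed agree.

```ruby
# Hypothetical helper isolating the seed-defaulting idiom from the constructor.
def resolve_seed(random_seed)
  # Kernel#srand with no argument reseeds the global generator
  # and returns the previous seed (an Integer).
  random_seed || srand
end

seed = resolve_seed(nil)
rng_a = Random.new(seed)
rng_b = Random.new(seed)
# Generators built from the same seed produce identical sequences.
puts rng_a.rand(100) == rng_b.rand(100) # prints true
```

One side effect of this design is that omitting random_seed also reseeds the global generator used by Kernel#rand; storing the resolved seed in @params in turn lets a fitted estimator report the seed it actually used.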

Instance Method Details

#apply(x) ⇒ Numo::Int32

Return the index of the leaf that each sample reached.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to pass down the tree.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) The index of the leaf that each sample reached.



# File 'lib/rumale/tree/base_decision_tree.rb', line 45

def apply(x)
  check_sample_array(x)
  # Route each row of x from the root to a leaf and collect the leaf indices.
  Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
end
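The private helper apply_at_node is not shown on this page. The plain-Ruby sketch below assumes a common convention: each internal node holds a feature index and a threshold, samples whose feature value is at or below the threshold go to the left child, and each leaf carries a leaf_id. The Node struct and its field names are assumptions made for this illustration, not Rumale's actual node class.

```ruby
# Hypothetical node structure; field names are assumptions for this sketch.
Node = Struct.new(:leaf, :leaf_id, :feature_id, :threshold, :left, :right, keyword_init: true)

# Walk from the given node down to a leaf and return that leaf's index.
def apply_at_node(node, sample)
  return node.leaf_id if node.leaf
  # Samples whose feature value is at or below the threshold go left.
  if sample[node.feature_id] <= node.threshold
    apply_at_node(node.left, sample)
  else
    apply_at_node(node.right, sample)
  end
end

# A depth-one tree (a "stump") that splits on feature 0 at 2.5.
left_leaf = Node.new(leaf: true, leaf_id: 1)
right_leaf = Node.new(leaf: true, leaf_id: 2)
root = Node.new(leaf: false, feature_id: 0, threshold: 2.5, left: left_leaf, right: right_leaf)

samples = [[1.0, 0.0], [3.0, 0.0], [2.5, 9.0]]
leaf_ids = samples.map { |s| apply_at_node(root, s) }
p leaf_ids # => [1, 2, 1]
```

The third sample lands in the left leaf because the comparison is inclusive (2.5 <= 2.5); apply performs the same walk once per row and packs the results into a Numo::Int32 array.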