Class: LightGBM::Classifier

Inherits:
Model (which in turn inherits from Object)
Defined in:
lib/lightgbm/classifier.rb

Instance Attribute Summary

Attributes inherited from Model

#booster

Instance Method Summary

#fit, #predict, #predict_proba

Methods inherited from Model

#best_iteration, #feature_importances, #load_model, #save_model

Constructor Details

#initialize(num_leaves: 31, learning_rate: 0.1, n_estimators: 100, objective: nil, **options) ⇒ Classifier

Returns a new instance of Classifier.



(source lines 3–5)
# File 'lib/lightgbm/classifier.rb', line 3

# Builds a classifier configured with the given LightGBM hyperparameters.
#
# Every argument, including any extra options, is forwarded unchanged to
# Model#initialize; the classifier adds no state of its own at construction.
#
# @param num_leaves [Integer] defaults to 31
# @param learning_rate [Float] defaults to 0.1
# @param n_estimators [Integer] defaults to 100
# @param objective [String, nil] defaults to nil
# @param options [Hash] any further parameters, passed through as-is
def initialize(num_leaves: 31, learning_rate: 0.1, n_estimators: 100, objective: nil, **options)
  super(
    num_leaves: num_leaves,
    learning_rate: learning_rate,
    n_estimators: n_estimators,
    objective: objective,
    **options
  )
end

Instance Method Details

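#fit(x, y, eval_set: nil, eval_names: [], categorical_feature: "auto", early_stopping_rounds: nil, verbose: true) ⇒ nil

Trains the classifier on features x and labels y and stores the resulting booster. Always returns nil.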



(source lines 7–29)
# File 'lib/lightgbm/classifier.rb', line 7

# Trains a LightGBM booster on features +x+ and labels +y+ and stores it in @booster.
#
# If the model params do not already name an objective, the method picks one:
# - "multiclass" when +y+ holds more than two distinct values. In that case
#   +num_class+ is also set to that count.
# - "binary" otherwise.
#
# @param x training features, wrapped in a Dataset
# @param y [#uniq] training labels
# @param eval_set [Array, nil] validation sets as [features, labels] pairs; nil means none
# @param eval_names [Array] names for the validation sets (passed as +valid_names+)
# @param categorical_feature passed to the training Dataset only
# @param early_stopping_rounds [Integer, nil] forwarded to LightGBM.train
# @param verbose forwarded to LightGBM.train as +verbose_eval+
# @return [nil]
def fit(x, y, eval_set: nil, eval_names: [], categorical_feature: "auto", early_stopping_rounds: nil, verbose: true)
  class_count = y.uniq.size
  multiclass = class_count > 2

  # Work on a copy so per-fit settings (objective, num_class) never leak into @params.
  train_params = @params.dup
  train_params[:objective] ||= multiclass ? "multiclass" : "binary"
  train_params[:num_class] = class_count if multiclass

  training = Dataset.new(x, label: y, categorical_feature: categorical_feature, params: train_params)
  # Each validation Dataset references the training set; presumably so LightGBM
  # reuses the training set's feature binning. TODO: confirm.
  validation = Array(eval_set).map do |pair|
    Dataset.new(pair[0], label: pair[1], reference: training, params: train_params)
  end

  # @n_estimators is presumably set by Model#initialize.
  @booster = LightGBM.train(
    train_params,
    training,
    num_boost_round: @n_estimators,
    early_stopping_rounds: early_stopping_rounds,
    verbose_eval: verbose,
    valid_sets: validation,
    valid_names: eval_names
  )
  nil
end

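#predict(data, num_iteration: nil) ⇒ Object

Returns the predicted class index for each row of data: 0 or 1 for binary models, the index of the highest-scoring class for multiclass models.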



(source lines 31–42)
# File 'lib/lightgbm/classifier.rb', line 31

# Predicts a class index for each row of +data+.
#
# If the booster returns one row of scores per sample (multiclass), the index
# of the highest score is chosen (the first one wins on ties). Otherwise each
# single score is thresholded: values strictly above 0.5 map to 1, everything
# else maps to 0.
#
# @param data samples to score, passed to the booster as-is
# @param num_iteration [Integer, nil] forwarded to the booster's predict
# @return one class index per row
def predict(data, num_iteration: nil)
  scores = @booster.predict(data, num_iteration: num_iteration)

  # Binary output: a single score per row.
  return scores.map { |score| score > 0.5 ? 1 : 0 } unless scores.first.is_a?(Array)

  scores.map do |row|
    row.each_with_index.max_by { |score, _index| score }.last
  end
end

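#predict_proba(data, num_iteration: nil) ⇒ Object

Returns class probabilities for each row of data. Multiclass scores are returned unchanged, and each binary score p becomes the pair [1 - p, p].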



(source lines 44–53)
# File 'lib/lightgbm/classifier.rb', line 44

# Returns class probabilities for each row of +data+.
#
# Multiclass output (one row of per-class scores per sample) is returned
# unchanged. When the booster returns a single score +prob+ per row, it is
# expanded to the two-element row [1 - prob, prob].
#
# @param data samples to score, passed to the booster as-is
# @param num_iteration [Integer, nil] forwarded to the booster's predict
def predict_proba(data, num_iteration: nil)
  scores = @booster.predict(data, num_iteration: num_iteration)
  return scores if scores.first.is_a?(Array)

  scores.map { |prob| [1 - prob, prob] }
end