Class: EasyML::Core::Tuner::Adapters::XGBoostAdapter

Inherits:
BaseAdapter (ancestry: Object → BaseAdapter → XGBoostAdapter)
Includes:
GlueGun::DSL
Defined in:
lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb

Instance Method Summary

Methods inherited from BaseAdapter

#deep_merge_defaults, #run_trial, #suggest_parameter, #suggest_parameters

Instance Method Details

#configure_callbacks ⇒ Object



28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# File 'lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb', line 28

# Adjusts the model's Weights & Biases XGBoost callback (if one is present)
# for a tuning run:
#   * suffixes its project name with the tuning start time, formatted as
#     YYYY_MM_DD_HH_MM_SS, so each tuning session gets its own project name;
#   * replaces its custom loggers with a single lambda that predicts on
#     x_true/y_true with the current booster, evaluates the predictions via
#     model.evaluate, and sends the resulting metrics to Wandb.log.
#
# Does nothing if the model exposes no callbacks, or none of them is exactly
# a Wandb::XGBoostCallback (subclasses are not matched).
#
# @return [Object] whatever model.customize_callbacks returns
def configure_callbacks
  model.customize_callbacks do |callbacks|
    # `next` leaves only this block. A `return` here would try to return from
    # configure_callbacks itself, raising LocalJumpError if customize_callbacks
    # invokes the block after this method has already returned.
    next unless callbacks.present?

    wandb_callback = callbacks.find { |cb| cb.instance_of?(Wandb::XGBoostCallback) }
    next unless wandb_callback

    timestamp = tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")
    wandb_callback.project_name = "#{wandb_callback.project_name}_#{timestamp}"
    wandb_callback.custom_loggers = [
      # Receives (booster, epoch, history); presumably invoked by the callback
      # after each boosting round — confirm against Wandb::XGBoostCallback.
      # NOTE(review): x_true/y_true are re-preprocessed on every invocation.
      lambda do |booster, _epoch, _hist|
        dtrain = model.send(:preprocess, x_true, y_true)
        y_pred = booster.predict(dtrain)
        metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
        Wandb.log(metrics)
      end
    ]
  end
end

#defaults ⇒ Object



10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# File 'lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb', line 10

# Default hyperparameter search ranges for XGBoost tuning.
#
# Each entry maps a parameter name to its bounds (:min / :max). The
# learning_rate entry also carries log: true, presumably requesting
# log-scale sampling from the tuner — confirm in BaseAdapter#suggest_parameter.
#
# A brand-new Hash is built on every call, so callers may merge into or
# mutate the result without affecting later calls.
#
# @return [Hash{Symbol => Hash}] parameter name => range spec
def defaults
  learning_rate_range = { min: 0.001, max: 0.1, log: true }
  estimator_range = { min: 100, max: 1_000 }
  depth_range = { min: 2, max: 20 }

  {
    learning_rate: learning_rate_range,
    n_estimators: estimator_range,
    max_depth: depth_range
  }
end