Class: EasyML::Core::Tuner

Inherits:
Object
  • Object
show all
Includes:
GlueGun::DSL
Defined in:
lib/easy_ml/core/tuner.rb,
lib/easy_ml/core/tuner/adapters.rb,
lib/easy_ml/core/tuner/adapters/base_adapter.rb,
lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb

Defined Under Namespace

Modules: Adapters

Instance Attribute Summary collapse

Instance Method Summary collapse

Instance Attribute Details

#results ⇒ Object

Returns the value of attribute results.



18
19
20
# File 'lib/easy_ml/core/tuner.rb', line 18

# Returns @results. #tune resets it to [] and appends one entry per scored
# trial. It is nil if it has never been assigned.
def results; @results; end

#study ⇒ Object

Returns the value of attribute study.



18
19
20
# File 'lib/easy_ml/core/tuner.rb', line 18

# Returns @study, the Optuna::Study created by the most recent #tune call.
# It is nil if it has never been assigned.
def study; @study; end

Instance Method Details

#loggers(_study, trial) ⇒ Object



38
39
40
41
42
# File 'lib/easy_ml/core/tuner.rb', line 38

# Optuna callback registered by #tune. It raises as soon as a finished trial
# reports the "FAIL" state, which is meant to stop the optimization.
# (Assumes Optuna propagates callback exceptions out of optimize; confirm.)
#
# @param _study [Object] the study (unused)
# @param trial [Object] finished trial; only +trial.state.name+ is read
# @raise [RuntimeError] when +trial.state.name+ is "FAIL"
# @return [nil] for any non-failed trial
def loggers(_study, trial)
  failed = trial.state.name == "FAIL"
  raise "Trial failed: Stopping optimization." if failed
end

#pick_adapter ⇒ Object



79
80
81
82
83
84
# File 'lib/easy_ml/core/tuner.rb', line 79

# Chooses the adapter class that knows how to tune the current model.
#
# @return [Class, nil] Adapters::XGBoostAdapter when the model is an
#   EasyML::Core::Models::XGBoost or an EasyML::Models::XGBoost; nil for any
#   other model type
def pick_adapter
  case model
  when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost then Adapters::XGBoostAdapter
  else nil # no adapter for this model type
  end
end

#set_defaults! ⇒ Object

Raises:

  • (ArgumentError)


94
95
96
97
98
99
100
101
102
# File 'lib/easy_ml/core/tuner.rb', line 94

# Validates the tuner configuration and fills in the defaults it can derive.
#
# * task    – copied from +model.task+ when not set explicitly
# * metrics – when nil or empty, replaced with +allowed_metrics+ of a fresh
#             EasyML::Core::Model built for +task+
#
# @raise [ArgumentError] if no task can be determined, or +objective+ is missing
# @return [Object, nil] the newly assigned metrics, or nil when metrics were
#   already present
def set_defaults!
  unless task.present?
    self.task = model.task
    task_resolved = task.present?
    raise ArgumentError, "EasyML::Core::Tuner requires task (regression or classification)" unless task_resolved
  end

  has_objective = objective.present?
  raise ArgumentError, "Objectives required for EasyML::Core::Tuner" unless has_objective

  metrics_missing = metrics.nil? || metrics.empty?
  return unless metrics_missing

  self.metrics = EasyML::Core::Model.new(task: task).allowed_metrics
end

#tune ⇒ Object



44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# File 'lib/easy_ml/core/tuner.rb', line 44

# Runs an Optuna study over +model+ and returns the parameters of the best
# trial.
#
# Side effects: applies #set_defaults!, sets +model.task+, and resets @study
# and @results (exposed by #study and #results). Each trial goes through the
# adapter chosen by #pick_adapter and is evaluated on the dataset's test
# split. The score recorded for a trial is:
#   * run_metrics[model.evaluator_metric] when the model has an evaluator
#     and evaluator_metric is set,
#   * run_metrics[:custom] when it has an evaluator but no evaluator_metric,
#   * run_metrics[objective.to_sym] otherwise.
#
# An error raised while running a trial is printed and swallowed, so that
# trial's block value is nil. What Optuna does with a nil objective is not
# visible here. Presumably the trial is marked FAIL and #loggers then raises;
# confirm.
#
# @return [Object] +@study.best_trial.params+
# @raise [ArgumentError] from #set_defaults!, or when no adapter supports the model
# @raise [RuntimeError] "Optuna study failed" when the study lacks best_trial
def tune
  set_defaults!

  @study = Optuna::Study.new
  @results = []
  model.task = task
  x_true, y_true = model.dataset.test(split_ys: true)
  tune_started_at = EST.now
  adapter_class = pick_adapter
  # pick_adapter returns nil for unsupported models. Without this guard the
  # failure would surface as an opaque NoMethodError on nil.new.
  raise ArgumentError, "EasyML::Core::Tuner has no tuning adapter for #{model.class}" if adapter_class.nil?

  adapter = adapter_class.new(model: model, config: config, tune_started_at: tune_started_at, y_true: y_true,
                              x_true: x_true)
  adapter.configure_callbacks

  @study.optimize(n_trials: n_trials, callbacks: [method(:loggers)]) do |trial|
    run_metrics = tune_once(trial, x_true, y_true, adapter)

    metric_key = if model.evaluator.present?
                   model.evaluator_metric.present? ? model.evaluator_metric : :custom
                 else
                   objective.to_sym
                 end
    result = run_metrics[metric_key]
    @results.push(result)
    result
  rescue StandardError => e
    # Best-effort: report the error and let this trial's objective be nil.
    puts "Optuna failed with: #{e.message}"
  end

  # NOTE(review): respond_to? only checks that Study defines best_trial. It
  # likely does not detect a study with no completed trials; confirm.
  raise "Optuna study failed" unless @study.respond_to?(:best_trial)

  @study.best_trial.params
end

#tune_once(trial, x_true, y_true, adapter) ⇒ Object



86
87
88
89
90
91
92
# File 'lib/easy_ml/core/tuner.rb', line 86

# Runs one tuning trial through +adapter+ and evaluates the trained model on
# the held-out data.
#
# @param trial [Object] the Optuna trial, passed straight to +adapter.run_trial+
# @param x_true [Object] held-out features (the test split in #tune); the model
#   predicts on these
# @param y_true [Object] held-out labels the predictions are scored against
# @param adapter [Object] tuning adapter; expected to yield the trained model
#   to the block
# @return [Object] whatever +adapter.run_trial+ returns. Presumably this is the
#   block's evaluation result (the metrics indexed by #tune); confirm.
def tune_once(trial, x_true, y_true, adapter)
  adapter.run_trial(trial) do |trained_model|
    # Predictions must come from the features (x_true), not the labels.
    y_pred = trained_model.predict(x_true)
    trained_model.metrics = metrics
    trained_model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
  end
end