Class: EasyML::Core::Tuner::Adapters::BaseAdapter

Inherits:
  • Object
Includes:
GlueGun::DSL
Defined in:
lib/easy_ml/core/tuner/adapters/base_adapter.rb

Direct Known Subclasses

XGBoostAdapter

Instance Method Summary

Instance Method Details

#configure_callbacks ⇒ Object



26
27
28
# File 'lib/easy_ml/core/tuner/adapters/base_adapter.rb', line 26

# Abstract hook: concrete adapters must override this method.
#
# The base implementation exists only to fail loudly when a subclass
# forgets to provide its own version.
#
# @raise [RuntimeError] always, when called on the base implementation
def configure_callbacks
  raise "Subclasses of Tuner::Adapters::BaseAdapter must define #configure_callbacks"
end

#deep_merge_defaults(config) ⇒ Object



37
38
39
40
41
42
43
44
45
# File 'lib/easy_ml/core/tuner/adapters/base_adapter.rb', line 37

# Recursively overlays +config+ on top of #defaults.
#
# Wherever both sides hold a Hash for the same key, the two hashes are
# merged recursively. Otherwise the value from +config+ wins. Neither
# +defaults+ nor +config+ is mutated; a new Hash is returned (nested
# hashes that only exist on one side are shared, not copied).
#
# @param config [Hash] user-supplied overrides
# @return [Hash] defaults deep-merged with config
def deep_merge_defaults(config)
  # A self-referencing merger makes Hash#merge recurse into nested hashes
  # using only the standard library. The previous
  # ActiveSupport#deep_merge block had an unreachable Hash branch, because
  # deep_merge only invokes its block for non-hash conflicts.
  merger = lambda do |_key, default_value, config_value|
    if default_value.is_a?(Hash) && config_value.is_a?(Hash)
      default_value.merge(config_value, &merger)
    else
      config_value
    end
  end

  defaults.merge(config, &merger)
end

#defaults ⇒ Object



8
9
10
# File 'lib/easy_ml/core/tuner/adapters/base_adapter.rb', line 8

# Search-space defaults keyed by parameter name.
#
# suggest_parameters iterates over these keys. The base adapter declares
# no tunable parameters, so it returns a fresh, empty, unfrozen Hash on
# every call; subclasses are expected to override it.
#
# @return [Hash] an empty hash in the base class
def defaults
  Hash.new
end

#run_trial(trial) {|model| ... } ⇒ Object

Yields:

  • (model)


19
20
21
22
23
24
# File 'lib/easy_ml/core/tuner/adapters/base_adapter.rb', line 19

# Runs one tuning trial against the adapter's model.
#
# Builds a trial-local configuration (a clone of +config+ merged with
# #defaults), asks +trial+ for hyperparameter values, applies them to the
# model, fits the model, and then hands the fitted model to the caller's
# block.
#
# @param trial [Object] the tuning trial used to suggest parameter values
# @yieldparam model [Object] the fitted model
# @return [Object] whatever the block returns
def run_trial(trial)
  trial_config = deep_merge_defaults(config.clone)
  suggest_parameters(trial, trial_config)
  fitted = model.tap(&:fit)
  yield(fitted)
end

#suggest_parameter(trial, param_name, config) ⇒ Object



47
48
49
50
51
52
53
54
55
56
57
58
# File 'lib/easy_ml/core/tuner/adapters/base_adapter.rb', line 47

# Asks +trial+ for a value of a single parameter.
#
# Reads +:min+, +:max+ and +:log+ from +config[param_name]+. When +:log+ is
# truthy, the value is sampled with +suggest_loguniform+; otherwise
# +suggest_uniform+ is used. The parameter name is passed to the trial as
# a String.
#
# NOTE(review): in Python Optuna, suggest_uniform/suggest_loguniform are
# deprecated in favour of suggest_float(name, low, high, log:) — confirm
# what the trial object used here supports.
#
# @param trial [Object] trial responding to suggest_uniform/suggest_loguniform
# @param param_name [Symbol, String] key of the parameter in +config+
# @param config [Hash] merged configuration holding the parameter's range
# @return [Object] the value suggested by the trial
def suggest_parameter(trial, param_name, config)
  range = config[param_name]
  lower, upper, logarithmic = %i[min max log].map { |key| range[key] }
  name = param_name.to_s

  return trial.suggest_loguniform(name, lower, upper) if logarithmic

  trial.suggest_uniform(name, lower, upper)
end

#suggest_parameters(trial, config) ⇒ Object



30
31
32
33
34
35
# File 'lib/easy_ml/core/tuner/adapters/base_adapter.rb', line 30

# Samples every parameter declared in #defaults and writes each value onto
# the model's hyperparameters through its +name=+ setter.
#
# Only keys present in #defaults are tuned; extra keys in +config+ are
# left untouched.
#
# @param trial [Object] trial used by suggest_parameter
# @param config [Hash] merged configuration with per-parameter ranges
# @return [Array] the list of parameter names that were tuned
def suggest_parameters(trial, config)
  tunable_names = defaults.keys
  tunable_names.each do |name|
    suggested = suggest_parameter(trial, name, config)
    setter = "#{name}="
    model.hyperparameters.send(setter, suggested)
  end
end