Module: EasyML::Core::ModelCore

Included in:
Model, Model
Defined in:
lib/easy_ml/core/model_core.rb

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Instance Attribute Details

#dataset ⇒ Object

Returns the value of attribute dataset.



7
8
9
# File 'lib/easy_ml/core/model_core.rb', line 7

# Reader for the dataset attribute.
#
# @return [Object, nil] the current value of @dataset (nil when it was never assigned)
def dataset
  @dataset
end

Class Method Details

.included(base) ⇒ Object



9
10
11
12
13
14
15
16
17
18
19
20
21
22
# File 'lib/easy_ml/core/model_core.rb', line 9

# Hook invoked with the class that includes this module.
#
# Mixes GlueGun::DSL into the includer, extends it with CarrierWave::Mount,
# mounts the model uploader on +:file+, and then registers the task
# validations, the three custom validators and the pre-validation callback
# that stores the model file once the model has been fit.
#
# @param base [Module] the including class
# @return [Object] the result of the +class_eval+ block
def self.included(base)
  base.include(GlueGun::DSL)
  base.extend(CarrierWave::Mount)
  base.mount_uploader(:file, EasyML::Core::Uploaders::ModelUploader)

  base.class_eval do
    validates :task, inclusion: { in: %w[regression classification] }
    validates :task, presence: true

    # Custom validators, registered in the same order as before.
    %i[dataset_is_a_dataset? validate_any_metrics? validate_metrics_for_task].each do |check|
      validate check
    end

    # Only persist the model file when there is a trained model to persist.
    before_validation :save_model_file, if: -> { fit? }
  end
end

Instance Method Details

#allowed_metrics ⇒ Object



81
82
83
84
85
86
87
88
89
90
91
92
# File 'lib/easy_ml/core/model_core.rb', line 81

# Lists the metric names accepted for the model's current task.
#
# @return [Array<String>] metric names for +regression+ or +classification+;
#   an empty array when the task is not present or is any other value
def allowed_metrics
  return [] unless task.present?

  metrics_by_task = {
    regression: %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score],
    classification: %w[accuracy_score precision_score recall_score f1_score auc roc_auc]
  }

  # Unknown tasks fall back to "no metrics allowed".
  metrics_by_task.fetch(task.to_sym) { [] }
end

#cleanup ⇒ Object



100
101
102
103
104
# File 'lib/easy_ml/core/model_core.rb', line 100

# Runs an EasyML::FileRotate cleanup over the CarrierWave directory and the
# model directory, passing +files_to_keep+ to the rotator and
# +extension_allowlist+ to its cleanup call.
# NOTE(review): presumably this prunes every file not in +files_to_keep+ —
# confirm against EasyML::FileRotate.
#
# @return [Array] the two directories that were processed
def cleanup
  directories = [carrierwave_dir, model_dir]

  directories.each do |directory|
    rotator = EasyML::FileRotate.new(directory, files_to_keep)
    rotator.cleanup(extension_allowlist)
  end
end

#cleanup! ⇒ Object



94
95
96
97
98
# File 'lib/easy_ml/core/model_core.rb', line 94

# Same traversal as #cleanup, but the rotator is given an empty keep-list.
# NOTE(review): presumably this removes every allow-listed file in both
# directories — confirm against EasyML::FileRotate.
#
# @return [Array] the two directories that were processed
def cleanup!
  [carrierwave_dir, model_dir].each do |directory|
    rotator = EasyML::FileRotate.new(directory, [])
    rotator.cleanup(extension_allowlist)
  end
end

#decode_labels(ys, col: nil) ⇒ Object



51
52
53
# File 'lib/easy_ml/core/model_core.rb', line 51

# Delegates label decoding to the dataset.
#
# @param ys [Object] the values to decode, passed through unchanged
# @param col [Object, nil] forwarded to the dataset as the +col:+ keyword
# @return [Object] whatever +dataset.decode_labels+ returns
def decode_labels(ys, col: nil)
  options = { col: col }
  dataset.decode_labels(ys, **options)
end

#evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil) ⇒ Object



55
56
57
58
59
# File 'lib/easy_ml/core/model_core.rb', line 55

# Scores predictions by delegating to EasyML::Core::ModelEvaluator.
#
# @param y_pred [Object, nil] forwarded as +y_pred:+
# @param y_true [Object, nil] forwarded as +y_true:+
# @param x_true [Object, nil] forwarded as +x_true:+
# @param evaluator [Object, nil] evaluator to use; when nil or false the
#   model's own +evaluator+ is used instead
# @return [Object] whatever +ModelEvaluator.evaluate+ returns
def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
  EasyML::Core::ModelEvaluator.evaluate(
    model: self,
    y_pred: y_pred,
    y_true: y_true,
    x_true: x_true,
    evaluator: evaluator || self.evaluator
  )
end

#fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil) ⇒ Object



24
25
26
27
28
29
30
31
32
# File 'lib/easy_ml/core/model_core.rb', line 24

# Trains the model and marks it as fit.
#
# When +x_train+ is given, the four arguments are handed to +train+.
# When it is nil, the dataset is refreshed and +train_in_batches+ is used.
#
# @param x_train [Object, nil] training features; nil selects the dataset path
# @param y_train [Object, nil] training targets
# @param x_valid [Object, nil] validation features
# @param y_valid [Object, nil] validation targets
# @return [true]
def fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil)
  explicit_data = !x_train.nil?

  if explicit_data
    train(x_train, y_train, x_valid, y_valid)
  else
    dataset.refresh!
    train_in_batches
  end

  @is_fit = true
end

#fit? ⇒ Boolean

Returns:

  • (Boolean)


106
107
108
# File 'lib/easy_ml/core/model_core.rb', line 106

# Reports whether the model has been trained.
#
# The comparison is against the literal +true+, so any other value of
# @is_fit (including nil when it was never set) yields false.
#
# @return [Boolean]
def fit?
  @is_fit == true
end

#get_params ⇒ Object



77
78
79
# File 'lib/easy_ml/core/model_core.rb', line 77

# Returns the hyperparameters converted with +to_h+.
#
# @return [Hash] the result of +@hyperparameters.to_h+; when @hyperparameters
#   is nil this is an empty Hash (NilClass#to_h)
def get_params
  @hyperparameters.to_h
end

#load ⇒ Object

Raises:

  • (NotImplementedError)


38
39
40
# File 'lib/easy_ml/core/model_core.rb', line 38

# Abstract hook: concrete models are expected to override this.
#
# @raise [NotImplementedError] always
def load
  message = "Subclasses must implement load method"
  raise NotImplementedError, message
end

#predict(xs) ⇒ Object

Raises:

  • (NotImplementedError)


34
35
36
# File 'lib/easy_ml/core/model_core.rb', line 34

# Abstract hook: concrete models are expected to override this.
#
# @param xs [Object] unused here
# @raise [NotImplementedError] always
def predict(xs)
  message = "Subclasses must implement predict method"
  raise NotImplementedError, message
end

#save ⇒ Object



46
47
48
49
# File 'lib/easy_ml/core/model_core.rb', line 46

# Saves the record and then the model file.
#
# The parent +save+ is invoked only when a super method exists and the
# direct superclass of the receiver's class defines +save+. Its result is
# discarded: the return value is always that of +save_model_file+.
#
# @return [Object] whatever +save_model_file+ returns
def save
  # `defined?(super)` must be evaluated inside this method body.
  parent_saves = defined?(super) && self.class.superclass.method_defined?(:save)
  super if parent_saves

  save_model_file
end

#save_model_file ⇒ Object



61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# File 'lib/easy_ml/core/model_core.rb', line 61

# Writes the trained model to "<model_dir>/<version>.json", assigns that
# file to the mounted +file+ attribute, stores it, and then runs #cleanup.
#
# @raise [RuntimeError] when the model has not been fit
# @return [Object] whatever +cleanup+ returns
def save_model_file
  raise "No trained model! Need to train model before saving (call model.fit)" unless fit?

  target = File.join(model_dir, "#{version}.json")
  ensure_directory_exists(File.dirname(target))

  # The concrete model writes its serialized form to the target path.
  _save_model_file(target)

  # Assign while the handle is open; it is closed when the block exits.
  File.open(target) { |handle| self.file = handle }
  file.store!

  cleanup
end