Class: Eps::BaseEstimator

Inherits:
Object
  • Object
show all
Defined in:
lib/eps/base_estimator.rb

Direct Known Subclasses

LightGBM, LinearRegression, NaiveBayes

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(data = nil, y = nil, **options) ⇒ BaseEstimator

Returns a new instance of BaseEstimator.



3
4
5
6
7
8
9
# File 'lib/eps/base_estimator.rb', line 3

# Sets up a new, untrained estimator and optionally trains it right away.
#
# A copy of the keyword options is kept in @options, the trained flag is
# cleared, and the text encoder table starts out empty. When `data` is
# truthy, `train` is invoked with the same data, target and options.
#
# @param data [Object, nil] training data; training is skipped when nil/false
# @param y [Object, nil] passed through to `train` as its second argument
# @param options [Hash] keyword options, stored and forwarded to `train`
def initialize(data = nil, y = nil, **options)
  @text_encoders = {}
  @trained = false
  @options = options.dup
  return unless data

  # TODO better pattern - don't pass most options to train
  train(data, y, **options)
end

Class Method Details

.load_pmml(pmml) ⇒ Object



28
29
30
31
32
33
# File 'lib/eps/base_estimator.rb', line 28

# Builds an estimator from an existing PMML document instead of training.
#
# A fresh instance is created with no training data, its @evaluator is set
# to the result of `PMML.load(pmml)`, and the PMML text is cached in @pmml.
#
# @param pmml [Object] PMML input; if it responds to `to_xml`, the result of
#   that call is cached, otherwise the object itself is cached
# @return [Object] the newly created instance
def self.load_pmml(pmml)
  new.tap do |estimator|
    estimator.instance_variable_set("@evaluator", PMML.load(pmml))

    # cache data
    cached = if pmml.respond_to?(:to_xml)
      pmml.to_xml
    else
      pmml
    end
    estimator.instance_variable_set("@pmml", cached)
  end
end

Instance Method Details

#evaluate(data, y = nil, target: nil, weight: nil) ⇒ Object



19
20
21
22
# File 'lib/eps/base_estimator.rb', line 19

# Scores the model's predictions on the given data.
#
# The inputs are normalised through `prep_data` (falling back to @target
# when no explicit target is given), predictions are made on the prepared
# data, and the result of `Eps.metrics` is returned.
#
# @param data [Object] evaluation data, passed to `prep_data`
# @param y [Object, nil] passed to `prep_data` as its second argument
# @param target [Object, nil] overrides @target when truthy
# @param weight [Object, nil] passed to `prep_data` as its fourth argument
# @return [Object] whatever `Eps.metrics` returns
def evaluate(data, y = nil, target: nil, weight: nil)
  # the second value returned by prep_data is not needed here
  dataset, _target = prep_data(data, y, target || @target, weight)

  actual = dataset.label
  predicted = predict(dataset)
  Eps.metrics(actual, predicted, weight: dataset.weight)
end

#predict(data) ⇒ Object



11
12
13
# File 'lib/eps/base_estimator.rb', line 11

# Runs `_predict` on the given data with its second argument set to false.
#
# NOTE(review): the flag presumably selects probabilities vs. plain
# predictions (predict_probability passes true) - confirm against `_predict`.
#
# @param data [Object] input forwarded unchanged to `_predict`
# @return [Object] whatever `_predict` returns
def predict(data)
  probabilities = false
  _predict(data, probabilities)
end

#predict_probability(data) ⇒ Object



15
16
17
# File 'lib/eps/base_estimator.rb', line 15

# Runs `_predict` on the given data with its second argument set to true.
#
# NOTE(review): the flag presumably asks `_predict` for probabilities
# rather than plain predictions - confirm against `_predict`.
#
# @param data [Object] input forwarded unchanged to `_predict`
# @return [Object] whatever `_predict` returns
def predict_probability(data)
  probabilities = true
  _predict(data, probabilities)
end

#summary(extended: false) ⇒ Object



35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# File 'lib/eps/base_estimator.rb', line 35

# Builds a human-readable summary of a trained model.
#
# When a validation set is present, the text opens with a validation metric
# line: RMSE when the target type is "numeric", accuracy (as a percentage)
# for any other target type. The text returned by `_summary` follows.
#
# @param extended [Boolean] forwarded unchanged to `_summary`
# @return [String] the assembled summary text
# @raise [RuntimeError] if the model is not marked as trained
def summary(extended: false)
  raise "Summary not available for loaded models" unless @trained

  report = String.new("")

  if @validation_set
    actual = @validation_set.label
    predicted = predict(@validation_set)

    label, value =
      case @target_type
      when "numeric"
        rmse = Metrics.rmse(actual, predicted, weight: @validation_set.weight)
        whole = rmse.round
        # values that round to 1000+ are shown as integers, smaller ones
        # with three significant digits
        ["RMSE", whole >= 1000 ? whole.to_s : format("%.3g", rmse)]
      else
        accuracy = Metrics.accuracy(actual, predicted, weight: @validation_set.weight)
        ["accuracy", format("%.1f%%", (100 * accuracy).round(1))]
      end

    report << format("Validation %s: %s\n\n", label, value)
  end

  # String#<< returns the receiver, so this is also the return value
  report << _summary(extended: extended)
end

#to_pmml ⇒ Object



24
25
26
# File 'lib/eps/base_estimator.rb', line 24

# Returns the PMML representation of this model, generating it on first use.
#
# The value is memoized in @pmml: a truthy cached value is returned as-is,
# otherwise `PMML.generate(self)` is called and its result stored.
#
# @return [Object] the cached or freshly generated PMML
def to_pmml
  return @pmml if @pmml

  @pmml = PMML.generate(self)
end