Class: Spark::Mllib::RegressionMethodBase

Inherits:
Object
Defined in:
lib/spark/mllib/regression/common.rb

Overview

RegressionMethodBase

Parent for regression methods

Class Method Summary

Class Method Details

.train(rdd, options) ⇒ Object



Source (lines 48–69):
# File 'lib/spark/mllib/regression/common.rb', line 48

# Shared option normalization and input validation for regression training.
#
# Reads defaults from +self::DEFAULT_OPTIONS+, so it must be invoked on a
# class that defines that constant. (Presumably each concrete regression
# method calls it, e.g. via +super+, before handing off to Spark; confirm
# against the subclasses.)
#
# The +options+ hash is modified IN PLACE: keys are symbolized, missing
# defaults are filled in and +:initial_weights+ is normalized to a vector.
# Callers holding a reference to the same hash observe these changes.
#
# @param rdd [#first] RDD-like object; only its first element is inspected
# @param options [Hash] user-supplied training options (String or Symbol keys)
# @return [Object] the normalized +options[:initial_weights]+ value
# @raise [Spark::MllibError] if the first element is nil (e.g. empty RDD)
#   or is not a LabeledPoint
def self.train(rdd, options)
  # Normalize String keys to Symbols so all lookups below use symbols.
  options.symbolize_keys!

  # Reverse merge: user-provided values win, defaults only fill the gaps.
  # Done in place (rather than Hash#merge) so the caller's hash is updated.
  self::DEFAULT_OPTIONS.each do |key, value|
    options[key] = value unless options.key?(key)
  end

  # Validation is based on the first element only.
  first = rdd.first
  if first.nil?
    # Previously surfaced as the confusing "got NilClass"; same error class.
    raise Spark::MllibError, "RDD is empty or its first element is nil, expected LabeledPoint"
  end
  unless first.is_a?(LabeledPoint)
    raise Spark::MllibError, "RDD should contains LabeledPoint, got #{first.class}"
  end

  # Initial weights are optional for the user but always passed on:
  # default to a zero vector with one entry per feature of the first point.
  options[:initial_weights] = Vectors.to_vector(options[:initial_weights] || [0.0] * first.features.size)
end