Class: Spark::Mllib::RegressionMethodBase
- Inherits: Object
  - Object
    - Spark::Mllib::RegressionMethodBase
- Defined in: lib/spark/mllib/regression/common.rb
Overview
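RegressionMethodBase
Parent class for the regression methods. Its .train method does the option handling shared by all of them: it converts string keys to symbols, fills in defaults from the subclass's DEFAULT_OPTIONS, checks that the RDD contains LabeledPoint objects, and sets default initial weights.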
Direct Known Subclasses
ClassificationMethodBase, LassoWithSGD, LinearRegressionWithSGD, RidgeRegressionWithSGD
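The subclasses listed above inherit .train from this class, and .train reads its defaults through `self::DEFAULT_OPTIONS`. In a class method, `self` is the class that received the call, so the lookup resolves to the constant defined on the concrete subclass. The following pure-Ruby sketch shows only that mechanism. The class names and option keys are hypothetical and are not the real ruby-spark classes or defaults.

```ruby
# Hypothetical stand-ins, NOT the ruby-spark classes: they only show how
# `self::DEFAULT_OPTIONS` inside an inherited class method picks up the
# constant of whichever subclass receives the call.
class HypotheticalRegressionBase
  def self.defaults
    # `self` is the receiving class, so the lookup starts at the subclass.
    self::DEFAULT_OPTIONS
  end
end

class HypotheticalSGDMethod < HypotheticalRegressionBase
  # Made-up option keys, for illustration only.
  DEFAULT_OPTIONS = { iterations: 100, step: 1.0 }.freeze
end

class HypotheticalRidgeMethod < HypotheticalRegressionBase
  DEFAULT_OPTIONS = { iterations: 100, step: 1.0, reg_param: 0.01 }.freeze
end

p HypotheticalSGDMethod.defaults
p HypotheticalRidgeMethod.defaults
```

In this sketch, calling `HypotheticalRegressionBase.defaults` directly raises NameError because the base class defines no DEFAULT_OPTIONS of its own. The method is meant to be called through a subclass.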
Class Method Summary
Class Method Details
.train(rdd, options) ⇒ Object
    # File 'lib/spark/mllib/regression/common.rb', line 48

    def self.train(rdd, options)
      # Convert string keys to symbols
      options.symbolize_keys!

      # Reverse merge: defaults fill only the keys the user did not set
      self::DEFAULT_OPTIONS.each do |key, value|
        if options.has_key?(key)
          # value from user
        else
          options[key] = value
        end
      end

      # Validation
      first = rdd.first
      unless first.is_a?(LabeledPoint)
        raise Spark::MllibError, "RDD should contains LabeledPoint, got #{first.class}"
      end

      # Initial weights are optional for the user (but required by Spark)
      options[:initial_weights] = Vectors.to_vector(options[:initial_weights] || [0.0] * first.features.size)
    end
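The option handling in .train has four steps: symbolize keys, reverse-merge the defaults, validate the first record, and default the weights to zeros. These steps can be reproduced without Spark. The following sketch is an illustration under several assumptions and is not the library code:

- `Point` stands in for LabeledPoint.
- `DEFAULTS` and its keys are invented.
- The first record is passed in directly instead of being fetched with `rdd.first`.
- The weights stay a plain Array rather than going through `Vectors.to_vector`.

The sketch also swaps `symbolize_keys!` for the core method `Hash#transform_keys`. It expresses the reverse merge as `DEFAULTS.merge(opts)`, which keeps every caller-supplied value and adds only the missing keys.

```ruby
# Dependency-free sketch of the option handling in `.train`.
# Assumptions: `Point` is a hypothetical stand-in for LabeledPoint, and
# DEFAULTS uses made-up keys, not the real DEFAULT_OPTIONS of any subclass.
Point = Struct.new(:label, :features)

DEFAULTS = { iterations: 100, step: 1.0 }.freeze

def prepare_options(first_row, options)
  # String keys to symbols (returns a new hash; the caller's hash is untouched).
  opts = options.transform_keys(&:to_sym)

  # Reverse merge: user values win, defaults only fill missing keys.
  opts = DEFAULTS.merge(opts)

  # Validation: the data must consist of labeled points.
  unless first_row.is_a?(Point)
    raise ArgumentError, "data should contain Point, got #{first_row.class}"
  end

  # Default initial weights: one 0.0 per feature (plain Array, not a Spark vector).
  opts[:initial_weights] ||= [0.0] * first_row.features.size
  opts
end

opts = prepare_options(Point.new(1.0, [0.5, 2.0, 3.0]), { "step" => 0.1 })
p opts
```

One design difference: the original mutates the caller's options hash in place through `symbolize_keys!` and `options[key] = value`. This sketch builds a new hash instead, so the caller's arguments are left unchanged.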