Class: Spark::Mllib::SVMWithSGD

Inherits:
ClassificationMethodBase
Defined in:
lib/spark/mllib/classification/svm.rb

Constant Summary

DEFAULT_OPTIONS =
{
  iterations: 100,
  step: 1.0,
  reg_param: 0.01,
  mini_batch_fraction: 1.0,
  initial_weights: nil,
  reg_type: 'l2',
  intercept: false,
  validate: true
}
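The hash above lists the defaults that `.train` falls back on. This page does not show where the defaults are merged into the caller's options; presumably this happens in `ClassificationMethodBase.train`, which is reached through `super`. The sketch below assumes that user values override defaults key by key, as a plain `Hash#merge` would. `resolve_options` is a hypothetical helper and is not part of ruby-spark.

```ruby
# Defaults copied from SVMWithSGD::DEFAULT_OPTIONS above.
DEFAULT_OPTIONS = {
  iterations: 100,
  step: 1.0,
  reg_param: 0.01,
  mini_batch_fraction: 1.0,
  initial_weights: nil,
  reg_type: 'l2',
  intercept: false,
  validate: true
}.freeze

# Hypothetical helper: assumes user-supplied values take precedence
# over the defaults, key by key (a plain Hash#merge).
def resolve_options(user_options = {})
  DEFAULT_OPTIONS.merge(user_options)
end

opts = resolve_options(step: 0.5, reg_type: nil)
opts[:step]       # => 0.5 (user value)
opts[:reg_type]   # => nil (explicit nil, i.e. no regularization)
opts[:iterations] # => 100 (default)
```

Under this assumption, an explicit `nil` survives the merge. That matters for `reg_type`, where `nil` means "no regularization" rather than "use the default".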

Class Method Details

.train(rdd, options = {}) ⇒ Object

Train a support vector machine on the given data using stochastic gradient descent (SGD) and return the resulting SVMModel.
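Training a linear SVM this way means minimizing the hinge loss plus a regularization term. As a plain-Ruby illustration (not ruby-spark code), the hypothetical method `hinge_loss_and_gradient` below computes the hinge loss and one subgradient for a single example. It assumes labels are 0.0/1.0 and rescales them to -1/+1, as Spark MLlib's `HingeGradient` does.

```ruby
# Hinge loss and subgradient for one example (features x, label in {0, 1}).
# Labels are rescaled to y in {-1, +1}; the loss is max(0, 1 - y * w.x).
def hinge_loss_and_gradient(weights, features, label)
  y = 2.0 * label - 1.0
  margin = weights.zip(features).sum { |w, x| w * x }
  if y * margin < 1.0
    # Inside the margin or misclassified: the subgradient is -y * x.
    [1.0 - y * margin, features.map { |x| -y * x }]
  else
    # Correct side with margin >= 1: no loss and a zero gradient.
    [0.0, Array.new(features.size, 0.0)]
  end
end

loss, grad = hinge_loss_and_gradient([0.5, -0.25], [2.0, 1.0], 1.0)
# margin = 0.5 * 2.0 - 0.25 * 1.0 = 0.75, so loss = 0.25, grad = [-2.0, -1.0]
```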

rdd

The training data, an RDD of LabeledPoint.

iterations

The number of iterations (default: 100).

step

The step size used in SGD (default: 1.0).

reg_param

The regularization parameter (default: 0.01).

mini_batch_fraction

Fraction of data to be used for each SGD iteration (default: 1.0).

initial_weights

The initial weights (default: nil).

reg_type

The type of regularizer used for training the model (default: “l2”).

Allowed values:

  • “l1” for using L1 regularization

  • “l2” for using L2 regularization

  • nil for no regularization

intercept

Boolean parameter which indicates whether to use the augmented representation of the training data, i.e. whether to fit a bias (intercept) term (default: false).

validate

Boolean parameter which indicates whether the algorithm should validate the data before training (default: true).
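The options above correspond to a standard mini-batch SGD loop. The plain-Ruby sketch below illustrates that loop. It is not the ruby-spark or Spark MLlib implementation. It assumes the scheme used by Spark MLlib's `GradientDescent` and its updaters: a step size decaying as `step / sqrt(i)`, L2 shrinkage, and L1 soft-thresholding. It also treats `intercept: true` as appending a constant 1.0 feature. The method `train_svm_sgd_sketch` and its `seed` argument are hypothetical, and `validate` is omitted.

```ruby
# Hypothetical plain-Ruby sketch of mini-batch SGD for a linear SVM.
# data: array of [label, features] pairs with labels in {0.0, 1.0}.
# Returns [weights, intercept]. This is not the ruby-spark implementation.
def train_svm_sgd_sketch(data, iterations: 100, step: 1.0, reg_param: 0.01,
                         mini_batch_fraction: 1.0, initial_weights: nil,
                         reg_type: 'l2', intercept: false, seed: 42)
  n_features = data.first[1].size
  weights = (initial_weights || Array.new(n_features, 0.0)).map(&:to_f)
  weights << 0.0 if intercept # bias weight starts at zero
  # Augmented representation: a constant 1.0 feature makes the last weight
  # act as the bias (the bias is regularized too in this simplified sketch).
  samples = data.map { |label, x| [label.to_f, intercept ? x + [1.0] : x] }
  rng = Random.new(seed)

  (1..iterations).each do |i|
    # Bernoulli sampling: each example joins the mini-batch with
    # probability mini_batch_fraction (1.0 uses the whole data set).
    batch = samples.select { mini_batch_fraction >= 1.0 || rng.rand < mini_batch_fraction }
    next if batch.empty?

    # Average hinge-loss subgradient over the mini-batch.
    grad = Array.new(weights.size, 0.0)
    batch.each do |label, x|
      y = 2.0 * label - 1.0
      margin = weights.zip(x).sum { |w, v| w * v }
      next unless y * margin < 1.0
      x.each_with_index { |v, j| grad[j] -= y * v }
    end
    grad.map! { |g| g / batch.size }

    eta = step / Math.sqrt(i) # decaying step size
    shrink = eta * reg_param
    weights = weights.each_with_index.map do |w, j|
      case reg_type
      when 'l2' then w * (1.0 - shrink) - eta * grad[j] # shrink, then step
      when 'l1' # gradient step, then soft-threshold towards zero
        v = w - eta * grad[j]
        (v <=> 0.0) * [v.abs - shrink, 0.0].max
      when nil then w - eta * grad[j] # no regularization
      else raise ArgumentError, "unknown reg_type: #{reg_type.inspect}"
      end
    end
  end

  intercept ? [weights[0...-1], weights[-1]] : [weights, 0.0]
end
```

The keyword arguments mirror the keys of `DEFAULT_OPTIONS`, so each option can be traced to the line of the loop it affects.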



# File 'lib/spark/mllib/classification/svm.rb', line 125

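def self.train(rdd, options={})
  # Bare `super` forwards rdd and options to ClassificationMethodBase.train
  super

  # Call RubyMLLibAPI#trainSVMModelWithSGD through Spark's Java bridge
  # (Spark.jb); it returns the learned weights and intercept
  weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainSVMModelWithSGD', rdd,
                                     options[:iterations].to_i,
                                     options[:step].to_f,
                                     options[:reg_param].to_f,
                                     options[:mini_batch_fraction].to_f,
                                     options[:initial_weights],
                                     options[:reg_type],
                                     options[:intercept],
                                     options[:validate])

  # Wrap the learned parameters in an SVMModel
  SVMModel.new(weights, intercept)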
end
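`SVMModel.new(weights, intercept)` wraps the learned parameters, but this page does not show how ruby-spark's `SVMModel` predicts. The hypothetical class below mimics Spark MLlib's `SVMModel`. It computes the margin w·x + b and, with the default threshold of 0.0, returns 1.0 when the margin exceeds the threshold and 0.0 otherwise. When the threshold is cleared (set to `nil` here), it returns the raw margin.

```ruby
# Hypothetical stand-in for a trained linear SVM (not ruby-spark's SVMModel).
class LinearSVMSketch
  attr_accessor :threshold

  def initialize(weights, intercept, threshold: 0.0)
    @weights = weights
    @intercept = intercept
    @threshold = threshold
  end

  # Raw score w.x + b.
  def margin(features)
    @weights.zip(features).sum { |w, x| w * x } + @intercept
  end

  # 1.0 or 0.0 when a threshold is set, otherwise the raw margin.
  def predict(features)
    m = margin(features)
    return m if @threshold.nil?
    m > @threshold ? 1.0 : 0.0
  end
end

model = LinearSVMSketch.new([2.0, -1.0], 0.5)
model.predict([1.0, 1.0]) # margin 1.5, so 1.0
model.predict([0.0, 1.0]) # margin -0.5, so 0.0
model.threshold = nil
model.predict([0.0, 1.0]) # raw margin -0.5
```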