Class: Spark::Mllib::LogisticRegressionWithLBFGS

Inherits:
ClassificationMethodBase
Defined in:
lib/spark/mllib/classification/logistic_regression.rb

Constant Summary

DEFAULT_OPTIONS =
{
  iterations: 100,
  initial_weights: nil,
  reg_param: 0.01,
  reg_type: 'l2',
  intercept: false,
  corrections: 10,
  tolerance: 0.0001
}
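A caller normally passes only the options it wants to override. The sketch below shows how a partial options hash would combine with these defaults, assuming plain `Hash#merge` semantics in which user-supplied keys win. Where the library actually applies the defaults is not shown on this page. The constant is copied so the snippet runs on its own, and `effective_options` is a hypothetical helper.

```ruby
# Copy of the defaults listed above, frozen so the sketch cannot mutate them.
DEFAULT_OPTIONS = {
  iterations: 100,
  initial_weights: nil,
  reg_param: 0.01,
  reg_type: 'l2',
  intercept: false,
  corrections: 10,
  tolerance: 0.0001
}.freeze

# Hypothetical helper. Assumption: user options override defaults key by key
# (Hash#merge semantics); unspecified keys keep their default values.
def effective_options(user_options)
  DEFAULT_OPTIONS.merge(user_options)
end

opts = effective_options(iterations: 50, intercept: true)
puts opts[:iterations] # prints 50
puts opts[:reg_type]   # prints l2
```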

Class Method Summary

Class Method Details

.train(rdd, options = {}) ⇒ LogisticRegressionModel

Train a logistic regression model on the given data using the limited-memory BFGS (L-BFGS) optimizer.

Arguments (every argument except rdd is a key of the options hash):

rdd

The training data, an RDD of LabeledPoint.

iterations

The number of iterations (default: 100).

initial_weights

The initial weights (default: nil).

reg_param

The regularization parameter (default: 0.01).

reg_type

The type of regularizer used to train the model (default: “l2”).

Allowed values:

  • “l1” for using L1 regularization

  • “l2” for using L2 regularization

  • nil for no regularization
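For reference, the penalty each reg_type adds to the logistic loss is sketched below using conventional definitions: λ‖w‖₁ for L1 and (λ/2)‖w‖₂² for L2, with λ = reg_param. The ½ factor on the L2 term follows Spark MLlib's `SquaredL2Updater`; treat the exact scaling as an assumption if you depend on it numerically. The method name `regularization_penalty` is hypothetical.

```ruby
# Hypothetical helper: the regularization term added to the loss for the
# given weights. Assumed scaling: L1 = reg_param * sum(|w|),
# L2 = 0.5 * reg_param * sum(w^2), nil = no penalty.
def regularization_penalty(weights, reg_param, reg_type)
  case reg_type
  when 'l1' then reg_param * weights.sum { |w| w.abs }
  when 'l2' then 0.5 * reg_param * weights.sum { |w| w * w }
  when nil  then 0.0
  else raise ArgumentError, "unknown reg_type: #{reg_type.inspect}"
  end
end

w = [3.0, -4.0]
puts regularization_penalty(w, 0.5, 'l1') # prints 3.5
puts regularization_penalty(w, 0.5, 'l2') # prints 6.25
```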

intercept

Whether to use the augmented representation of the training data, i.e. whether a bias (intercept) feature is added (default: false).
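With the augmented representation, each feature vector gains a constant bias feature, so the model learns one extra weight that acts as the intercept. The minimal sketch below assumes the constant 1.0 is appended at the end of the vector, as Spark MLlib's `MLUtils.appendBias` does; the helpers `augment` and `margin` are hypothetical.

```ruby
# Hypothetical helper: append a constant bias feature of 1.0.
def augment(features)
  features + [1.0]
end

# Dot product of weights and features. With augmented features, the last
# weight plays the role of the intercept:
#   w . [x, 1] == w[0...-1] . x + w[-1]
def margin(weights, features)
  weights.zip(features).sum { |w, x| w * x }
end

x = [2.0, 3.0]
w_aug = [0.5, -1.0, 0.25] # two feature weights plus the bias weight
puts margin(w_aug, augment(x)) # 0.5*2 - 1*3 + 0.25, prints -1.75
```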

corrections

The number of corrections used in the L-BFGS update (default: 10).
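In L-BFGS, the number of corrections m is how many recent (s, y) pairs the optimizer keeps, where s is the change in weights and y the change in gradient between iterations; these pairs implicitly approximate the inverse Hessian. The sketch below shows the textbook two-loop recursion over such a bounded history. It illustrates the general algorithm, not the code Spark runs on the JVM side, and the class name `LbfgsHistory` is hypothetical.

```ruby
# Hypothetical bounded history of L-BFGS correction pairs.
class LbfgsHistory
  def initialize(corrections)
    @m = corrections
    @pairs = [] # [s, y] pairs, oldest first
  end

  # Store a new pair, discarding the oldest once more than m are held.
  def push(s, y)
    @pairs << [s, y]
    @pairs.shift while @pairs.size > @m
  end

  def size
    @pairs.size
  end

  # Two-loop recursion: returns the search direction -H * grad, where H is
  # the inverse-Hessian approximation implied by the stored pairs.
  def direction(grad)
    q = grad.dup
    coeffs = []
    # First loop: newest pair to oldest.
    @pairs.reverse_each do |s, y|
      rho = 1.0 / dot(y, s)
      alpha = rho * dot(s, q)
      q = axpy(-alpha, y, q)
      coeffs.unshift([rho, alpha]) # keep coeffs aligned with @pairs order
    end
    # Initial scaling gamma = (s . y) / (y . y) from the newest pair.
    gamma = if @pairs.empty?
              1.0
            else
              s_last, y_last = @pairs.last
              dot(s_last, y_last) / dot(y_last, y_last)
            end
    r = q.map { |v| gamma * v }
    # Second loop: oldest pair to newest.
    @pairs.each_with_index do |(s, y), i|
      rho, alpha = coeffs[i]
      beta = rho * dot(y, r)
      r = axpy(alpha - beta, s, r)
    end
    r.map { |v| -v }
  end

  private

  def dot(a, b)
    a.zip(b).sum { |ai, bi| ai * bi }
  end

  # Element-wise a * x + y.
  def axpy(a, x, y)
    x.zip(y).map { |xi, yi| a * xi + yi }
  end
end

history = LbfgsHistory.new(10)
history.push([1.0], [2.0]) # curvature 2 along this direction
p history.direction([4.0])  # Newton-like step 4 / 2, prints [-2.0]
```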

tolerance

The convergence tolerance for the L-BFGS iterations (default: 0.0001).
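Together, iterations and tolerance bound the optimization: it stops at the iteration cap or once progress falls below the tolerance. The sketch below uses a relative-decrease-in-loss test as the stopping rule. This is an assumed, commonly used criterion; the exact test applied by Spark's L-BFGS may differ. The `optimize` method and the toy step block are hypothetical stand-ins for the real optimizer.

```ruby
# Hypothetical driver: run `step` (which maps the current loss to the new
# loss) until the relative decrease in loss drops below `tolerance` or
# `iterations` steps have been taken. Returns [final_loss, iterations_used].
def optimize(initial_loss, iterations:, tolerance:, &step)
  loss = initial_loss
  iterations.times do |i|
    new_loss = step.call(loss)
    # Relative improvement; the max guards against division by zero.
    rel = (loss - new_loss).abs / [loss.abs, 1e-12].max
    loss = new_loss
    return [loss, i + 1] if rel < tolerance
  end
  [loss, iterations]
end

# Toy step: the loss approaches 1.0, halving its distance each iteration.
_loss, used = optimize(2.0, iterations: 100, tolerance: 0.0001) { |l| 1.0 + (l - 1.0) * 0.5 }
puts used # prints 14
```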

# File 'lib/spark/mllib/classification/logistic_regression.rb', line 214

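def self.train(rdd, options = {})
  # Bare `super` forwards the same arguments (rdd, options) to
  # ClassificationMethodBase.train.
  super

  # Call the JVM-side trainer, coercing numeric options to the types it
  # expects; it returns the learned weights and the intercept.
  weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLogisticRegressionModelWithLBFGS', rdd,
                                     options[:iterations].to_i,
                                     options[:initial_weights],
                                     options[:reg_param].to_f,
                                     options[:reg_type],
                                     options[:intercept],
                                     options[:corrections].to_i,
                                     options[:tolerance].to_f)

  # Wrap the raw result in a model object.
  LogisticRegressionModel.new(weights, intercept)
end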
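The method returns a LogisticRegressionModel built from the learned weights and intercept. For binary labels, a logistic model conventionally scores a point as σ(w·x + b) and predicts 1 when that probability exceeds a threshold such as 0.5. The plain-Ruby sketch below shows that rule. It is not the library's LogisticRegressionModel, whose threshold handling may differ, and the class name `TinyLogisticModel` is hypothetical.

```ruby
# Hypothetical stand-in for a trained binary logistic model.
class TinyLogisticModel
  def initialize(weights, intercept, threshold: 0.5)
    @weights = weights
    @intercept = intercept
    @threshold = threshold
  end

  # Probability of the positive class: sigmoid(w . x + b).
  def probability(features)
    margin = @weights.zip(features).sum { |w, x| w * x } + @intercept
    1.0 / (1.0 + Math.exp(-margin))
  end

  # Class label 1 or 0 according to the threshold.
  def predict(features)
    probability(features) > @threshold ? 1 : 0
  end
end

model = TinyLogisticModel.new([2.0, -1.0], 0.0)
puts model.predict([1.0, 0.0]) # margin 2.0, probability ~0.88, prints 1
puts model.predict([0.0, 1.0]) # margin -1.0, probability ~0.27, prints 0
```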