Class: Spark::Mllib::LogisticRegressionModel

Inherits:
ClassificationModel
Defined in:
lib/spark/mllib/classification/logistic_regression.rb

Overview

LogisticRegressionModel

A linear binary classification model derived from logistic regression.

Examples:

Spark::Mllib.import

# Dense vectors
data = [
  LabeledPoint.new(0.0, [0.0, 1.0]),
  LabeledPoint.new(1.0, [1.0, 0.0]),
]
lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))

lrm.predict([1.0, 0.0])
# => 1
lrm.predict([0.0, 1.0])
# => 0

lrm.clear_threshold
lrm.predict([0.0, 1.0])
# => 0.123...

# Sparse vectors
data = [
  LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
  LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
  LabeledPoint.new(0.0, SparseVector.new(2, {0 => 1.0})),
  LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
]
lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))

lrm.predict([0.0, 1.0])
# => 1
lrm.predict([1.0, 0.0])
# => 0
lrm.predict(SparseVector.new(2, {1 => 1.0}))
# => 1
lrm.predict(SparseVector.new(2, {0 => 1.0}))
# => 0

# LogisticRegressionWithLBFGS
data = [
  LabeledPoint.new(0.0, [0.0, 1.0]),
  LabeledPoint.new(1.0, [1.0, 0.0]),
]
lrm = LogisticRegressionWithLBFGS.train($sc.parallelize(data))

lrm.predict([1.0, 0.0])
# => 1
lrm.predict([0.0, 1.0])
# => 0

Instance Attribute Summary

Attributes inherited from ClassificationModel

#intercept, #threshold, #weights

Instance Method Summary

Methods inherited from ClassificationModel

#clear_threshold

Constructor Details

#initialize(*args) ⇒ LogisticRegressionModel

Returns a new instance of LogisticRegressionModel.



# File 'lib/spark/mllib/classification/logistic_regression.rb', line 62

def initialize(*args)
  super
  @threshold = 0.5
end

Instance Method Details

#predict(vector) ⇒ Object

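Predict values for a single data point or an RDD of points using the trained model. While a threshold is set (0.5 by default), the result is 1 if the logistic score is strictly greater than the threshold and 0 otherwise; after clear_threshold, the raw score is returned instead.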



# File 'lib/spark/mllib/classification/logistic_regression.rb', line 69

def predict(vector)
  vector = Spark::Mllib::Vectors.to_vector(vector)
  margin = weights.dot(vector) + intercept
  score = 1.0 / (1.0 + Math.exp(-margin))

  if threshold.nil?
    return score
  end

  if score > threshold
    1
  else
    0
  end
end
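
The scoring rule in predict can be tried without a Spark context. The following is a minimal sketch in plain Ruby, assuming dense features passed as Arrays; logistic_predict is a hypothetical helper written for illustration and is not part of Spark::Mllib. It mirrors the steps in the listing above: compute the margin, apply the logistic function, then compare against the threshold.

```ruby
# Hypothetical standalone helper (not part of Spark::Mllib).
# weights and features are plain Arrays of Floats with equal length.
def logistic_predict(weights, intercept, features, threshold = 0.5)
  # Margin: dot product of weights and features, plus the intercept.
  dot = weights.zip(features).inject(0.0) { |acc, (w, x)| acc + w * x }
  margin = dot + intercept

  # Logistic (sigmoid) function maps the margin into (0, 1).
  score = 1.0 / (1.0 + Math.exp(-margin))

  # A nil threshold corresponds to the state after clear_threshold:
  # the raw score is returned.
  return score if threshold.nil?

  # Strict comparison: a score exactly equal to the threshold yields 0.
  score > threshold ? 1 : 0
end

logistic_predict([2.0, -2.0], 0.0, [1.0, 0.0])       # margin  2.0, class 1
logistic_predict([2.0, -2.0], 0.0, [0.0, 1.0])       # margin -2.0, class 0
logistic_predict([2.0, -2.0], 0.0, [0.0, 1.0], nil)  # raw score below 0.5
```

The weights used here are made up for the example; a trained model exposes its own values through the inherited weights and intercept attributes.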