Class: Spark::Mllib::LogisticRegressionModel
- Inherits: ClassificationModel
  - Object
  - ClassificationModel
  - Spark::Mllib::LogisticRegressionModel
- Defined in: lib/spark/mllib/classification/logistic_regression.rb
Overview
LogisticRegressionModel
A linear binary classification model derived from logistic regression.
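The decision rule can be read directly off the predict source listed under Instance Method Details. Writing w for weights, b for intercept and t for threshold (0.5 by default, per the constructor), it is roughly the following. When the threshold has been cleared (nil), predict returns the score s(x) itself instead of a 0/1 label.

```latex
m(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + b, \qquad
s(\mathbf{x}) = \frac{1}{1 + e^{-m(\mathbf{x})}}, \qquad
\hat{y}(\mathbf{x}) =
\begin{cases}
1 & \text{if } s(\mathbf{x}) > t \\
0 & \text{otherwise}
\end{cases}
```

Because the comparison is strict, a score exactly equal to t is classified as 0.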
Examples:
Spark::Mllib.import

# Dense vectors
# ($sc is an existing SparkContext)
data = [
  LabeledPoint.new(0.0, [0.0, 1.0]),
  LabeledPoint.new(1.0, [1.0, 0.0]),
]
lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
lrm.predict([1.0, 0.0])
# => 1
lrm.predict([0.0, 1.0])
# => 0

# After clear_threshold, predict returns the raw probability score
lrm.clear_threshold
lrm.predict([0.0, 1.0])
# => 0.123...

# Sparse vectors
data = [
  LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
  LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
  LabeledPoint.new(0.0, SparseVector.new(2, {0 => 1.0})),
  LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
]
lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
lrm.predict([0.0, 1.0])
# => 1
lrm.predict([1.0, 0.0])
# => 0
lrm.predict(SparseVector.new(2, {1 => 1.0}))
# => 1
lrm.predict(SparseVector.new(2, {0 => 1.0}))
# => 0
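The sparse examples pass a size plus an {index => value} hash, and predict only needs the dot product of the weights with that vector. The sketch below shows why only the stored entries matter. It assumes SparseVector.new(size, {index => value}) means "a vector of length size whose unlisted entries are zero"; sparse_dot is a hypothetical helper, not part of ruby-spark, and it is not how Spark::Mllib actually implements dot.

```ruby
# Hypothetical helper (not part of ruby-spark): computes weights . x, where x
# is described by its size and an {index => value} hash of stored entries,
# mirroring the SparseVector.new(size, {index => value}) calls above.
def sparse_dot(weights, size, entries)
  raise ArgumentError, "size mismatch" unless weights.size == size

  # Entries that are not stored are zero, so only stored entries contribute.
  entries.sum(0.0) do |index, value|
    raise IndexError, "index #{index} out of range" unless (0...size).cover?(index)
    weights[index] * value
  end
end

weights = [2.0, -3.0]
sparse_dot(weights, 2, {1 => 1.0})  # => -3.0
sparse_dot(weights, 2, {0 => 1.0})  # => 2.0
```

For a vector with few non-zero entries this touches only those entries, which is the usual reason to prefer a sparse representation.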

# LogisticRegressionWithLBFGS
data = [
  LabeledPoint.new(0.0, [0.0, 1.0]),
  LabeledPoint.new(1.0, [1.0, 0.0]),
]
lrm = LogisticRegressionWithLBFGS.train($sc.parallelize(data))
lrm.predict([1.0, 0.0])
# => 1
lrm.predict([0.0, 1.0])
# => 0
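Both trainers fit weights and an intercept by minimizing the logistic (log) loss. As a rough intuition for what training produces on the two-point dataset above, here is a minimal full-batch gradient descent sketch in plain Ruby. It is a simplified stand-in chosen for illustration: it is neither the SGD nor the L-BFGS procedure that LogisticRegressionWithSGD and LogisticRegressionWithLBFGS run, and train_logistic, its step size and its iteration count are hypothetical.

```ruby
# Hypothetical full-batch gradient descent for logistic regression (plain Ruby).
# Illustration only: NOT the algorithm used by LogisticRegressionWithSGD or
# LogisticRegressionWithLBFGS.
def sigmoid(margin)
  1.0 / (1.0 + Math.exp(-margin))
end

# data: array of [label, features] pairs, labels 0.0 or 1.0
def train_logistic(data, iterations: 100, step: 1.0)
  dim = data.first[1].size
  weights = Array.new(dim, 0.0)
  intercept = 0.0

  iterations.times do
    grad_w = Array.new(dim, 0.0)
    grad_b = 0.0
    data.each do |label, x|
      margin = weights.zip(x).sum { |w, xi| w * xi } + intercept
      # d(log loss)/d(margin) = sigmoid(margin) - label
      err = sigmoid(margin) - label
      x.each_with_index { |xi, i| grad_w[i] += err * xi }
      grad_b += err
    end
    # Average the gradient over the batch and step against it
    weights = weights.each_with_index.map { |w, i| w - step * grad_w[i] / data.size }
    intercept -= step * grad_b / data.size
  end

  [weights, intercept]
end

data = [[0.0, [0.0, 1.0]], [1.0, [1.0, 0.0]]]
weights, intercept = train_logistic(data)
```

On this symmetric dataset the first weight becomes positive, the second negative and the intercept stays at (numerically) zero, so the fitted model classifies [1.0, 0.0] as 1 and [0.0, 1.0] as 0, matching the predictions shown in the examples.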
Instance Attribute Summary
Attributes inherited from ClassificationModel
#intercept, #threshold, #weights
Instance Method Summary
- #initialize(*args) ⇒ LogisticRegressionModel
  constructor
  A new instance of LogisticRegressionModel.
- #predict(vector) ⇒ Object
  Predict the class label (1 or 0) for a single data point using the trained model, or return the raw probability score if the threshold has been cleared.
Methods inherited from ClassificationModel
Constructor Details
#initialize(*args) ⇒ LogisticRegressionModel
Returns a new instance of LogisticRegressionModel.
# File 'lib/spark/mllib/classification/logistic_regression.rb', line 62
def initialize(*args)
  super
  # Default decision threshold used by #predict
  @threshold = 0.5
end
Instance Method Details
#predict(vector) ⇒ Object
Predict the class label (1 or 0) for a single data point using the trained model, or return the raw probability score if the threshold has been cleared.
# File 'lib/spark/mllib/classification/logistic_regression.rb', line 69
def predict(vector)
  vector = Spark::Mllib::Vectors.to_vector(vector)
  # Linear margin w . x + b
  margin = weights.dot(vector) + intercept
  # Sigmoid maps the margin to a probability in (0, 1)
  score = 1.0 / (1.0 + Math.exp(-margin))

  # No threshold: return the raw probability
  if threshold.nil?
    return score
  end

  # Strict comparison: a score equal to the threshold yields 0
  if score > threshold
    1
  else
    0
  end
end
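To see the thresholding behaviour without a Spark cluster, the following is a minimal plain-Ruby mirror of predict that uses Arrays in place of Spark::Mllib vectors. PlainLogisticModel is hypothetical and not part of ruby-spark. Its clear_threshold assumes the inherited method simply sets the threshold to nil, which is consistent with the nil check in predict and with the clear_threshold example above.

```ruby
# Hypothetical plain-Ruby mirror of LogisticRegressionModel#predict, using
# Arrays instead of Spark::Mllib vectors. Not part of ruby-spark.
class PlainLogisticModel
  attr_reader :weights, :intercept
  attr_accessor :threshold

  def initialize(weights, intercept, threshold = 0.5)
    @weights = weights
    @intercept = intercept
    @threshold = threshold
  end

  # Assumed behaviour: afterwards predict returns the raw score
  def clear_threshold
    @threshold = nil
  end

  def predict(vector)
    raise ArgumentError, "dimension mismatch" unless vector.size == weights.size

    # Linear margin w . x + b
    margin = weights.zip(vector).sum { |w, x| w * x } + intercept
    # Sigmoid maps the margin to a probability in (0, 1)
    score = 1.0 / (1.0 + Math.exp(-margin))
    return score if threshold.nil?

    # Strict comparison: a score exactly equal to the threshold gives 0
    score > threshold ? 1 : 0
  end
end

model = PlainLogisticModel.new([2.0, -2.0], 0.0)
model.predict([1.0, 0.0])  # => 1
model.predict([0.0, 1.0])  # => 0
model.predict([0.0, 0.0])  # => 0 (score is exactly 0.5)
model.clear_threshold
model.predict([0.0, 1.0])  # => about 0.119
```

The last call returns 1 / (1 + e^2), the probability that [0.0, 1.0] belongs to class 1, rather than a label.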