Class: Spark::Mllib::NaiveBayesModel

Inherits:
Object
Defined in:
lib/spark/mllib/classification/naive_bayes.rb

Overview

NaiveBayesModel

Model for Naive Bayes classifiers.

Contains two parameters, where C is the number of classes and D is the number of features:

pi: vector of log class priors (dimension C)
theta: matrix of log class-conditional probabilities (C × D)
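To make the two parameters concrete, here is a minimal plain-Ruby sketch of how a multinomial Naive Bayes model can arrive at pi and theta from training data. It assumes additive (Laplace) smoothing with a smoothing constant of 1.0, in the style of Spark MLlib's multinomial estimator. The helper `estimate_nb` and its `smoothing` parameter are hypothetical and not part of ruby-spark, which delegates training to MLlib.

```ruby
# Hypothetical helper (not part of ruby-spark): estimates log class priors (pi)
# and log class-conditional probabilities (theta) for multinomial Naive Bayes,
# assuming additive smoothing.
# points: array of [label, feature_array] pairs.
def estimate_nb(points, smoothing = 1.0)
  labels = points.map(&:first).uniq.sort
  num_features = points.first[1].size
  n = points.size.to_f

  pi = []
  theta = []
  labels.each do |label|
    rows = points.select { |l, _| l == label }.map(&:last)
    # Log prior: log((n_c + smoothing) / (n + C * smoothing))
    pi << Math.log((rows.size + smoothing) / (n + labels.size * smoothing))

    # Per-feature totals for this class
    sums = Array.new(num_features, 0.0)
    rows.each { |x| x.each_with_index { |v, j| sums[j] += v } }
    total = sums.sum
    # Log conditional: log((sum_j + smoothing) / (total + D * smoothing))
    theta << sums.map { |s| Math.log((s + smoothing) / (total + num_features * smoothing)) }
  end

  [labels, pi, theta]
end

# Same points as the dense-vector example, as [label, features] pairs
data = [[0.0, [0.0, 0.0]], [0.0, [0.0, 1.0]], [1.0, [1.0, 0.0]]]
labels, pi, theta = estimate_nb(data)
# labels => [0.0, 1.0]
# pi     => [log(3/5), log(2/5)]
# theta  => [[log(1/3), log(2/3)], [log(2/3), log(1/3)]]
```

With these values, class 0 scores higher for [0.0, 1.0] and class 1 scores higher for [1.0, 0.0], which agrees with the predictions shown in the examples.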

Examples:

Spark::Mllib.import

# Dense vectors
data = [
  LabeledPoint.new(0.0, [0.0, 0.0]),
  LabeledPoint.new(0.0, [0.0, 1.0]),
  LabeledPoint.new(1.0, [1.0, 0.0])
]
model = NaiveBayes.train($sc.parallelize(data))

model.predict([0.0, 1.0])
# => 0.0
model.predict([1.0, 0.0])
# => 1.0

# Sparse vectors
data = [
  LabeledPoint.new(0.0, SparseVector.new(2, {1 => 0.0})),
  LabeledPoint.new(0.0, SparseVector.new(2, {1 => 1.0})),
  LabeledPoint.new(1.0, SparseVector.new(2, {0 => 1.0}))
]
model = NaiveBayes.train($sc.parallelize(data))

model.predict(SparseVector.new(2, {1 => 1.0}))
# => 0.0
model.predict(SparseVector.new(2, {0 => 1.0}))
# => 1.0


Constructor Details

#initialize(labels, pi, theta) ⇒ NaiveBayesModel

Returns a new instance of NaiveBayesModel.



# File 'lib/spark/mllib/classification/naive_bayes.rb', line 47

def initialize(labels, pi, theta)
  @labels = labels
  @pi = pi
  @theta = theta
end

Instance Attribute Details

#labels ⇒ Object (readonly)

Returns the class labels, in the same order as the entries of pi; #predict returns one of these values.



# File 'lib/spark/mllib/classification/naive_bayes.rb', line 45

def labels
  @labels
end

#pi ⇒ Object (readonly)

Returns the vector of log class priors (dimension C).



# File 'lib/spark/mllib/classification/naive_bayes.rb', line 45

def pi
  @pi
end

#theta ⇒ Object (readonly)

Returns the matrix of log class-conditional probabilities (C × D).



# File 'lib/spark/mllib/classification/naive_bayes.rb', line 45

def theta
  @theta
end
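Because pi and theta are stored as logarithms, exponentiating them recovers probabilities. The plain-Ruby sketch below is a sanity check under the assumption of a multinomial model: the exponentiated entries of pi should sum to 1 across classes, and each row of the exponentiated theta should sum to 1 across features. The hard-coded arrays are hypothetical values for a two-class, two-feature model, and the helper `distribution?` is illustrative, not a ruby-spark method.

```ruby
# Hypothetical check: true if the exponentiated log-probabilities sum to ~1.
def distribution?(log_probs, tol = 1e-9)
  (log_probs.sum { |lp| Math.exp(lp) } - 1.0).abs < tol
end

# Hypothetical stand-ins for pi (length C) and theta (C x D)
pi = [Math.log(3.0 / 5), Math.log(2.0 / 5)]
theta = [
  [Math.log(1.0 / 3), Math.log(2.0 / 3)],
  [Math.log(2.0 / 3), Math.log(1.0 / 3)]
]

priors_ok = distribution?(pi)
rows_ok = theta.all? { |row| distribution?(row) }
puts priors_ok && rows_ok # prints "true"
```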

Instance Method Details

#predict(vector) ⇒ Object

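Predicts the label of a single data point using the trained model. The argument is converted with Spark::Mllib::Vectors.to_vector, so it may be an Array or a vector (dense or sparse, as in the examples).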



# File 'lib/spark/mllib/classification/naive_bayes.rb', line 55

def predict(vector)
  vector = Spark::Mllib::Vectors.to_vector(vector)
  array = (vector.dot(theta) + pi).to_a
  index = array.index(array.max)
  labels[index]
end
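The method scores every class by combining the feature vector with theta and adding the log prior pi, then returns the label whose score is highest. A dependency-free Ruby sketch of the same argmax follows, using nested arrays instead of ruby-spark vector and matrix types. It assumes theta is laid out as C rows of D entries, as described in the overview; `nb_predict` is a hypothetical name, and the parameter values are illustrative.

```ruby
# Hypothetical stand-alone version of NaiveBayesModel#predict.
# labels: C class labels; pi: C log priors; theta: C x D log likelihoods.
def nb_predict(labels, pi, theta, x)
  scores = theta.each_with_index.map do |row, c|
    # log P(c) + sum_j x_j * log P(feature j | c)
    pi[c] + row.zip(x).sum { |lp, v| lp * v }
  end
  # Index of the highest score; ties go to the first such class.
  labels[scores.index(scores.max)]
end

labels = [0.0, 1.0]
pi = [Math.log(3.0 / 5), Math.log(2.0 / 5)]
theta = [
  [Math.log(1.0 / 3), Math.log(2.0 / 3)],
  [Math.log(2.0 / 3), Math.log(1.0 / 3)]
]

nb_predict(labels, pi, theta, [0.0, 1.0]) # => 0.0
nb_predict(labels, pi, theta, [1.0, 0.0]) # => 1.0
```

Working in log space turns the product of per-feature probabilities into a sum, which avoids floating-point underflow when there are many features.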