Class: Spark::Mllib::GaussianMixture

Inherits: Object
Defined in:
lib/spark/mllib/clustering/gaussian_mixture.rb

Class Method Summary

Class Method Details

.train(rdd, k, convergence_tol: 0.001, max_iterations: 100, seed: nil) ⇒ Object



(source lines 66–78)
# File 'lib/spark/mllib/clustering/gaussian_mixture.rb', line 66

# Trains a Gaussian mixture model by delegating to the
# 'trainGaussianMixtureModel' entry point through the Spark.jb bridge.
#
# @param rdd [Object] training data passed straight through to the bridge call
#   (presumably an RDD of vectors — confirm against RubyMLLibAPI)
# @param k [Integer] number of mixture components; exactly k
#   MultivariateGaussian objects are built from the returned means/sigmas
# @param convergence_tol [Float] convergence tolerance forwarded to the bridge (default 0.001)
# @param max_iterations [Integer] iteration cap forwarded to the bridge (default 100)
# @param seed [Integer, nil] random seed, converted with Spark.jb.to_long before the call
# @return [GaussianMixtureModel] built from the returned weights and the k Gaussians
def self.train(rdd, k, convergence_tol: 0.001, max_iterations: 100, seed: nil)
  # The bridge answers with a triple: mixture weights, per-component means,
  # and per-component covariance matrices (still in bridge representation).
  result = Spark.jb.call(
    RubyMLLibAPI.new,
    'trainGaussianMixtureModel',
    rdd, k, convergence_tol, max_iterations,
    Spark.jb.to_long(seed)
  )
  weights, means, sigmas = result

  # Convert every mean first, then every covariance, in place.
  to_ruby = ->(java_obj) { Spark.jb.java_to_ruby(java_obj) }
  means.map!(&to_ruby)
  sigmas.map!(&to_ruby)

  # Index-based on purpose: always yields k components, even if the bridge
  # returned a different number of means/sigmas.
  gaussians = Array.new(k) { |idx| MultivariateGaussian.new(means[idx], sigmas[idx]) }

  GaussianMixtureModel.new(weights, gaussians)
end