66
67
68
69
70
71
72
73
74
75
76
77
78
|
# File 'lib/spark/mllib/clustering/gaussian_mixture.rb', line 66
# Trains a Gaussian mixture model by delegating to the
# 'trainGaussianMixtureModel' entry point via Spark.jb.call, then wraps the
# returned components into Ruby-side model objects.
#
# @param rdd the training data, forwarded unchanged to Spark.jb.call
# @param k [Integer] number of mixture components; exactly k
#   MultivariateGaussian objects are built (indices 0...k)
# @param convergence_tol [Float] forwarded to the training call (default 0.001)
# @param max_iterations [Integer] forwarded to the training call (default 100)
# @param seed optional seed; converted with Spark.jb.to_long before forwarding
# @return [GaussianMixtureModel] built from the returned weights and the
#   per-component MultivariateGaussian objects
def self.train(rdd, k, convergence_tol: 0.001, max_iterations: 100, seed: nil)
  raw_result = Spark.jb.call(
    RubyMLLibAPI.new,
    'trainGaussianMixtureModel',
    rdd,
    k,
    convergence_tol,
    max_iterations,
    Spark.jb.to_long(seed)
  )
  weights, java_means, java_sigmas = raw_result

  # Means are converted first, then covariance matrices, mirroring the
  # order in which the bridge conversions are performed.
  ruby_means = java_means.map { |mean| Spark.jb.java_to_ruby(mean) }
  ruby_sigmas = java_sigmas.map { |cov| Spark.jb.java_to_ruby(cov) }

  # One Gaussian per requested component, paired by index.
  components = Array.new(k) { |idx| MultivariateGaussian.new(ruby_means[idx], ruby_sigmas[idx]) }

  GaussianMixtureModel.new(weights, components)
end
|