Class: Rumale::Clustering::GaussianMixture
- Inherits:
- Object
- Includes:
- Base::BaseEstimator, Base::ClusterAnalyzer
- Defined in:
- lib/rumale/clustering/gaussian_mixture.rb
Overview
GaussianMixture is a class that implements cluster analysis with a Gaussian mixture model. The current implementation represents each cluster's covariance by only the diagonal elements of its covariance matrix (the per-feature variances) rather than the full matrix.
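With a diagonal covariance, a cluster's multivariate normal density factorizes into a product of one-dimensional normal densities, one per feature, so each cluster is described by a mean vector and a vector of variances. The stdlib-only Ruby sketch below shows this for the log-density. The helper name `diag_gaussian_log_pdf` is hypothetical and is not part of Rumale's API.

```ruby
# Log-density of sample x under a Gaussian with diagonal covariance.
# Hypothetical helper, not Rumale's implementation; mean and var are
# per-feature Arrays of the same length as x.
def diag_gaussian_log_pdf(x, mean, var)
  x.each_index.sum do |j|
    diff = x[j] - mean[j]
    # Each feature contributes an independent 1-D normal log-density term.
    -0.5 * (Math.log(2.0 * Math::PI * var[j]) + diff * diff / var[j])
  end
end

# Two features: the joint log-density equals the sum of the two
# one-dimensional log-densities.
joint = diag_gaussian_log_pdf([1.0, -2.0], [0.0, 0.0], [1.0, 4.0])
split = diag_gaussian_log_pdf([1.0], [0.0], [1.0]) +
        diag_gaussian_log_pdf([-2.0], [0.0], [4.0])
p joint
p split
```

Because no off-diagonal terms appear, the cost of evaluating the density grows linearly with the number of features instead of requiring a matrix inverse and determinant.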
Instance Attribute Summary
- #covariances ⇒ Numo::DFloat (readonly)
  Return the diagonal elements of the covariance matrix of each cluster.
- #means ⇒ Numo::DFloat (readonly)
  Return the mean of each cluster.
- #n_iter ⇒ Integer (readonly)
  Return the number of iterations to convergence.
- #weights ⇒ Numo::DFloat (readonly)
  Return the weight of each cluster.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #fit(x) ⇒ GaussianMixture
  Analyze clusters with the given training data.
- #fit_predict(x) ⇒ Numo::Int32
  Analyze clusters and assign samples to clusters.
- #initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil) ⇒ GaussianMixture (constructor)
  Create a new cluster analyzer with a Gaussian mixture model.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
- #predict(x) ⇒ Numo::Int32
  Predict cluster labels for samples.
Methods included from Base::ClusterAnalyzer
Constructor Details
#initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil) ⇒ GaussianMixture
Create a new cluster analyzer with a Gaussian mixture model.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 44
    def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil)
      check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
      check_params_float(tol: tol)
      check_params_string(init: init)
      check_params_type_or_nil(Integer, random_seed: random_seed)
      check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
      @params = {}
      @params[:n_clusters] = n_clusters
      @params[:init] = init == 'random' ? 'random' : 'k-means++'
      @params[:max_iter] = max_iter
      @params[:tol] = tol
      @params[:reg_covar] = reg_covar
      @params[:random_seed] = random_seed
      @params[:random_seed] ||= srand
      @n_iter = nil
      @weights = nil
      @means = nil
      @covariances = nil
    end
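Two details of the constructor are easy to miss: any `init` string other than `'random'` is silently replaced with `'k-means++'`, and a `nil` `random_seed` is filled in with the value returned by `Kernel#srand` (which reseeds Ruby's global generator and returns the previous seed). The plain-Ruby sketch below isolates just that normalization so it can be tried without Rumale. `normalize_gmm_params` is a hypothetical name, and the sketch deliberately omits the `check_params_*` validations.

```ruby
# Hypothetical helper mirroring how the constructor normalizes `init` and
# `random_seed`; it is not part of Rumale and skips all parameter checks.
def normalize_gmm_params(init: 'k-means++', random_seed: nil)
  params = {}
  # Only 'random' is kept as-is; every other string becomes 'k-means++'.
  params[:init] = init == 'random' ? 'random' : 'k-means++'
  params[:random_seed] = random_seed
  # Kernel#srand reseeds the global generator and returns the previous seed,
  # so a nil seed is replaced by some Integer.
  params[:random_seed] ||= srand
  params
end

p normalize_gmm_params(init: 'kmeans')                 # misspelled value falls back
p normalize_gmm_params(init: 'random', random_seed: 1) # both values kept
```

A consequence worth noting: a typo such as `'kmeans'` raises no error and quietly selects k-means++ initialization.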
Instance Attribute Details
#covariances ⇒ Numo::DFloat (readonly)
Return the diagonal elements of the covariance matrix of each cluster.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 34
    def covariances
      @covariances
    end
#means ⇒ Numo::DFloat (readonly)
Return the mean of each cluster.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 30
    def means
      @means
    end
#n_iter ⇒ Integer (readonly)
Return the number of iterations to convergence.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 22
    def n_iter
      @n_iter
    end
#weights ⇒ Numo::DFloat (readonly)
Return the weight of each cluster.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 26
    def weights
      @weights
    end
Instance Method Details
#fit(x) ⇒ GaussianMixture
Analyze clusters with the given training data.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 70
    def fit(x, _y = nil)
      check_sample_array(x)
      n_samples = x.shape[0]
      memberships = init_memberships(x)
      @params[:max_iter].times do |t|
        @n_iter = t
        @weights = calc_weights(n_samples, memberships)
        @means = calc_means(x, memberships)
        @covariances = calc_diag_covariances(x, @means, memberships) + @params[:reg_covar]
        new_memberships = calc_memberships(x, @weights, @means, @covariances)
        error = (memberships - new_memberships).abs.max
        break if error <= @params[:tol]
        memberships = new_memberships.dup
      end
      self
    end
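The loop in `#fit` alternates between computing mixture parameters (weights, means, diagonal variances plus `reg_covar`) from the current memberships and recomputing memberships from those parameters, stopping once the largest absolute change in any membership is at most `tol`. The private helpers it calls (`init_memberships`, `calc_weights`, `calc_means`, `calc_diag_covariances`, `calc_memberships`) are not shown on this page, so the stdlib-only sketch below is an illustration of the same update order on plain Arrays, not Rumale's code. Its helper names are hypothetical, initial memberships come from caller-supplied hard labels instead of `'k-means++'` or `'random'` initialization, and it assumes every cluster keeps a nonzero total membership (there is no guard against empty clusters).

```ruby
# Sketch of the #fit update order for a diagonal-covariance Gaussian mixture,
# written with plain Arrays. Not Rumale's code; helper names are hypothetical.

# Log-density of one sample under a Gaussian with diagonal covariance.
def diag_log_pdf(x, mean, var)
  x.each_index.sum { |j| -0.5 * (Math.log(2.0 * Math::PI * var[j]) + (x[j] - mean[j])**2 / var[j]) }
end

# x: Array of samples (each an Array of features); init_labels: one hard
# cluster index per sample, used in place of 'k-means++' or 'random' init.
def em_diag_gmm(x, init_labels, n_clusters, max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6)
  n_samples = x.size
  n_features = x.first.size
  # Memberships: one row per sample, one column per cluster.
  memberships = init_labels.map { |l| Array.new(n_clusters) { |k| k == l ? 1.0 : 0.0 } }
  weights = means = covariances = n_iter = nil
  max_iter.times do |t|
    n_iter = t # zero-based, like @n_iter = t in #fit
    # Parameter step: soft cluster sizes, weights, means, diagonal variances.
    nk = (0...n_clusters).map { |k| memberships.sum { |row| row[k] } }
    weights = nk.map { |s| s / n_samples }
    means = (0...n_clusters).map do |k|
      (0...n_features).map { |j| x.each_index.sum { |i| memberships[i][k] * x[i][j] } / nk[k] }
    end
    covariances = (0...n_clusters).map do |k|
      (0...n_features).map do |j|
        x.each_index.sum { |i| memberships[i][k] * (x[i][j] - means[k][j])**2 } / nk[k] + reg_covar
      end
    end
    # Membership step: posterior probabilities, normalized in log space for stability.
    new_memberships = x.map do |xi|
      logs = (0...n_clusters).map { |k| Math.log(weights[k]) + diag_log_pdf(xi, means[k], covariances[k]) }
      top = logs.max
      probs = logs.map { |v| Math.exp(v - top) }
      total = probs.sum
      probs.map { |q| q / total }
    end
    # Convergence test: largest absolute change in any membership.
    error = memberships.zip(new_memberships).map { |a, b| a.zip(b).map { |u, v| (u - v).abs }.max }.max
    break if error <= tol
    memberships = new_memberships
  end
  { weights: weights, means: means, covariances: covariances, n_iter: n_iter }
end

x = [[0.0, 0.1], [0.2, -0.1], [-0.1, 0.0], [5.0, 5.1], [5.2, 4.9], [4.9, 5.0]]
result = em_diag_gmm(x, [0, 0, 0, 1, 1, 1], 2)
p result[:weights]
p result[:n_iter]
```

On this well-separated toy data the recomputed memberships equal the initial hard labels, so the loop stops during its first pass and `n_iter` is 0. As with `@n_iter = t` in `#fit`, the counter holds the zero-based index of the last iteration run, and it equals `max_iter - 1` when the tolerance is never reached.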
#fit_predict(x) ⇒ Numo::Int32
Analyze clusters and assign samples to clusters.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 101
    def fit_predict(x)
      check_sample_array(x)
      fit(x).predict(x)
    end
#marshal_dump ⇒ Hash
Dump marshal data.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 108
    def marshal_dump
      { params: @params,
        n_iter: @n_iter,
        weights: @weights,
        means: @means,
        covariances: @covariances }
    end
#marshal_load(obj) ⇒ nil
Load marshal data.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 118
    def marshal_load(obj)
      @params = obj[:params]
      @n_iter = obj[:n_iter]
      @weights = obj[:weights]
      @means = obj[:means]
      @covariances = obj[:covariances]
      nil
    end
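`marshal_dump` returns a Hash holding the hyperparameters and fitted attributes, and `marshal_load` copies them back. This is the hook pair Ruby's `Marshal.dump` and `Marshal.load` use, so a fitted model can be saved and restored. The sketch below shows the same pattern on a small stand-in class. `TinyModel` and its fields are hypothetical, and plain Arrays take the place of Numo::DFloat.

```ruby
# Hypothetical stand-in class using the same marshal_dump/marshal_load pattern.
class TinyModel
  attr_reader :params, :weights, :means

  def initialize(n_clusters: 2)
    @params = { n_clusters: n_clusters }
    @weights = nil
    @means = nil
  end

  # Stand-in for fitting: store parameters so there is state to persist.
  def fit!(weights, means)
    @weights = weights
    @means = means
    self
  end

  # Called by Marshal.dump; the returned Hash is what gets serialized.
  def marshal_dump
    { params: @params, weights: @weights, means: @means }
  end

  # Called by Marshal.load on a freshly allocated object (initialize is not run).
  def marshal_load(obj)
    @params = obj[:params]
    @weights = obj[:weights]
    @means = obj[:means]
    nil
  end
end

model = TinyModel.new.fit!([0.5, 0.5], [[0.0, 0.0], [5.0, 5.0]])
copy = Marshal.load(Marshal.dump(model))
p copy.means
```

Only what the Hash contains survives the round trip; for GaussianMixture that is `@params`, `@n_iter`, `@weights`, `@means`, and `@covariances`.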
#predict(x) ⇒ Numo::Int32
Predict cluster labels for samples.
    # File 'lib/rumale/clustering/gaussian_mixture.rb', line 91
    def predict(x)
      check_sample_array(x)
      memberships = calc_memberships(x, @weights, @means, @covariances)
      assign_cluster(memberships)
    end
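`predict` recomputes each sample's memberships under the fitted weights, means, and covariances and then passes them to `assign_cluster`, whose body is not shown on this page. Assuming `assign_cluster` picks the cluster with the largest membership in each row, the final step amounts to a row-wise argmax. The sketch below shows that on plain Arrays with a hypothetical helper name.

```ruby
# Hypothetical row-wise argmax: for each sample, return the index of the
# cluster with the largest membership. Array#index returns the first match,
# so ties resolve to the lowest cluster index.
def argmax_labels(memberships)
  memberships.map { |row| row.index(row.max) }
end

memberships = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.2, 0.5, 0.3]]
p argmax_labels(memberships)
```

Because only the largest entry in each row matters, the hard labels discard the soft membership information that the mixture model computes.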