Class: Rumale::Clustering::GaussianMixture

Inherits:
Object
Includes:
Base::BaseEstimator, Base::ClusterAnalyzer
Defined in:
lib/rumale/clustering/gaussian_mixture.rb

Overview

GaussianMixture is a class that implements cluster analysis with a Gaussian mixture model. The current implementation represents each cluster's covariance by only the diagonal elements of its covariance matrix, not the full matrix, so features are modeled as uncorrelated within each cluster.
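To make the diagonal restriction concrete, here is a minimal plain-Ruby sketch (no Numo; `diag_gaussian_log_pdf` is a hypothetical helper, not part of Rumale's API) of the log-density of one sample under one component with a diagonal covariance. Because the covariance is diagonal, the density factorizes over features and each cluster needs only n_features variances instead of n_features × n_features covariance entries.

```ruby
# Sketch only: log N(x | mean, diag(vars)) for one sample x.
# With a diagonal covariance the density factorizes over features,
# so the log-density is a sum of independent one-dimensional terms.
def diag_gaussian_log_pdf(x, mean, vars)
  x.each_index.sum do |j|
    diff = x[j] - mean[j]
    -0.5 * (Math.log(2.0 * Math::PI * vars[j]) + diff * diff / vars[j])
  end
end

# Two uncorrelated features, evaluated at the mean with unit variances.
puts diag_gaussian_log_pdf([0.0, 0.0], [0.0, 0.0], [1.0, 1.0])
```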

Examples:

analyzer = Rumale::Clustering::GaussianMixture.new(n_clusters: 10, max_iter: 50)
cluster_labels = analyzer.fit_predict(samples)

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Instance Method Summary

Methods included from Base::ClusterAnalyzer

#score

Constructor Details

#initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil) ⇒ GaussianMixture

Create a new cluster analyzer with a Gaussian mixture model.

Parameters:

  • n_clusters (Integer) (defaults to: 8)

    The number of clusters.

  • init (String) (defaults to: 'k-means++')

    The initialization method for centroids (‘random’ or ‘k-means++’). Any other value falls back to ‘k-means++’.

  • max_iter (Integer) (defaults to: 50)

    The maximum number of iterations.

  • tol (Float) (defaults to: 1.0e-4)

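    The tolerance for the termination criterion: fitting stops once the largest absolute change in any membership between iterations is at most tol.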

  • reg_covar (Float) (defaults to: 1.0e-6)

    The non-negative regularization added to the diagonal of each covariance matrix.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random generator. If nil, the value returned by srand is used instead.
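The reg_covar term matters most when a feature has zero variance within a cluster, for example a constant column; #fit adds it to every diagonal covariance entry. The plain-Ruby sketch below (the helpers `variance` and `log_normal` are hypothetical, not Rumale API) shows how a zero variance breaks the Gaussian log-density and how the small regularizer keeps it finite.

```ruby
# Population variance of one feature column.
def variance(values)
  mean = values.sum / values.size.to_f
  values.sum { |v| (v - mean)**2 } / values.size
end

# Log-density of a scalar v under N(mean, var).
def log_normal(v, mean, var)
  -0.5 * (Math.log(2.0 * Math::PI * var) + (v - mean)**2 / var)
end

reg_covar = 1.0e-6                      # same default as the constructor
constant_feature = [3.0, 3.0, 3.0, 3.0] # a feature that never varies
raw_var = variance(constant_feature)    # 0.0
safe_var = raw_var + reg_covar          # mirrors "+ @params[:reg_covar]" in #fit

puts log_normal(3.0, 3.0, raw_var)  # NaN: log(0) and 0/0 break the density
puts log_normal(3.0, 3.0, safe_var) # a finite value
```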



# File 'lib/rumale/clustering/gaussian_mixture.rb', line 44

def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil)
  check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
  check_params_float(tol: tol)
  check_params_string(init: init)
  check_params_type_or_nil(Integer, random_seed: random_seed)
  check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
  @params = {}
  @params[:n_clusters] = n_clusters
  @params[:init] = init == 'random' ? 'random' : 'k-means++'
  @params[:max_iter] = max_iter
  @params[:tol] = tol
  @params[:reg_covar] = reg_covar
  @params[:random_seed] = random_seed
  @params[:random_seed] ||= srand
  @n_iter = nil
  @weights = nil
  @means = nil
  @covariances = nil
end
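Two constructor rules are easy to miss when reading only the parameter list: an unrecognized init value silently becomes 'k-means++', and a nil random_seed is replaced by the value that Kernel#srand returns. The sketch below isolates just those two rules in a hypothetical `normalize_params` method; it is an illustration, not Rumale code.

```ruby
# Sketch of two normalization rules from the constructor above.
def normalize_params(init:, random_seed: nil)
  params = {}
  # Anything other than 'random' falls back to 'k-means++'.
  params[:init] = init == 'random' ? 'random' : 'k-means++'
  params[:random_seed] = random_seed
  # Kernel#srand reseeds the global generator and returns the previous seed,
  # so a nil seed ends up as some Integer chosen at construction time.
  params[:random_seed] ||= srand
  params
end

p normalize_params(init: 'random', random_seed: 1)
p normalize_params(init: 'kmeans++') # misspelled value, falls back to 'k-means++'
```

Passing an explicit random_seed is therefore the way to make the stored seed predictable across runs.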

Instance Attribute Details

#covariances ⇒ Numo::DFloat (readonly)

Return the diagonal elements of the covariance matrix of each cluster.

Returns:

  • (Numo::DFloat)

    (shape: [n_clusters, n_features])



# File 'lib/rumale/clustering/gaussian_mixture.rb', line 34

def covariances
  @covariances
end

#means ⇒ Numo::DFloat (readonly)

Return the mean of each cluster.

Returns:

  • (Numo::DFloat)

    (shape: [n_clusters, n_features])



# File 'lib/rumale/clustering/gaussian_mixture.rb', line 30

def means
  @means
end

#n_iter ⇒ Integer (readonly)

Return the number of iterations to convergence.

Returns:

  • (Integer)


# File 'lib/rumale/clustering/gaussian_mixture.rb', line 22

def n_iter
  @n_iter
end

#weights ⇒ Numo::DFloat (readonly)

Return the weight of each cluster.

Returns:

  • (Numo::DFloat)

    (shape: [n_clusters])



# File 'lib/rumale/clustering/gaussian_mixture.rb', line 26

def weights
  @weights
end
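The weights, means and covariances attributes are refreshed on every iteration of #fit by the private helpers calc_weights, calc_means and calc_diag_covariances, which this page does not show. Assuming they implement the standard EM M-step for a diagonal-covariance mixture, the plain-Ruby sketch below (hypothetical `m_step`, arrays instead of Numo::DFloat) computes all three from a membership matrix; #fit then adds reg_covar to the covariances.

```ruby
# Sketch of a standard EM M-step for a diagonal-covariance mixture.
# x:           n_samples x n_features (array of arrays of Floats)
# memberships: n_samples x n_clusters responsibilities (rows sum to 1)
# Returns [weights, means, covariances] with the shapes documented above.
def m_step(x, memberships)
  n_samples = x.size
  n_features = x.first.size
  n_clusters = memberships.first.size
  weights = []
  means = []
  covariances = []
  n_clusters.times do |k|
    resp = memberships.map { |row| row[k] } # responsibilities for cluster k
    nk = resp.sum                           # effective number of samples in k
    weights << nk / n_samples
    mean = Array.new(n_features) do |j|
      x.each_index.sum { |i| resp[i] * x[i][j] } / nk
    end
    var = Array.new(n_features) do |j|
      x.each_index.sum { |i| resp[i] * (x[i][j] - mean[j])**2 } / nk
    end
    means << mean
    covariances << var # diagonal only: one variance per feature
  end
  [weights, means, covariances]
end

x = [[0.0, 0.0], [2.0, 2.0], [10.0, 10.0], [12.0, 12.0]]
memberships = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
weights, means, covariances = m_step(x, memberships)
p weights, means, covariances
```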

Instance Method Details

#fit(x) ⇒ GaussianMixture

Analyze clusters with the given training data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for cluster analysis.

Returns:

  • (GaussianMixture)

    The learned cluster analyzer itself.


# File 'lib/rumale/clustering/gaussian_mixture.rb', line 70

def fit(x, _y = nil)
  check_sample_array(x)
  n_samples = x.shape[0]
  memberships = init_memberships(x)
  @params[:max_iter].times do |t|
    @n_iter = t
    @weights = calc_weights(n_samples, memberships)
    @means = calc_means(x, memberships)
    @covariances = calc_diag_covariances(x, @means, memberships) + @params[:reg_covar]
    new_memberships = calc_memberships(x, @weights, @means, @covariances)
    error = (memberships - new_memberships).abs.max
    break if error <= @params[:tol]
    memberships = new_memberships.dup
  end
  self
end
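Each pass of the loop above recomputes memberships with calc_memberships and stops when their largest absolute change is at most tol. That helper is not shown here; the sketch below assumes it computes standard posterior responsibilities and uses hypothetical names (`e_step`, `max_abs_change`) with plain arrays. Note also that @n_iter receives the block index t of Integer#times, which starts at 0.

```ruby
# Log N(x | mean, diag(vars)) for one sample.
def diag_log_pdf(x, mean, vars)
  x.each_index.sum do |j|
    d = x[j] - mean[j]
    -0.5 * (Math.log(2.0 * Math::PI * vars[j]) + d * d / vars[j])
  end
end

# E-step: posterior membership of each sample in each cluster,
# proportional to weights[k] * N(x_i | means[k], diag(covariances[k])),
# normalized per sample with the log-sum-exp trick to avoid underflow.
def e_step(x, weights, means, covariances)
  x.map do |xi|
    logs = weights.each_index.map do |k|
      Math.log(weights[k]) + diag_log_pdf(xi, means[k], covariances[k])
    end
    top = logs.max
    exps = logs.map { |l| Math.exp(l - top) }
    total = exps.sum
    exps.map { |e| e / total }
  end
end

# Termination test in the style of #fit: largest absolute change
# in any membership between two iterations.
def max_abs_change(old_m, new_m)
  old_m.zip(new_m).flat_map { |a, b| a.zip(b).map { |u, v| (u - v).abs } }.max
end

r = e_step([[0.0], [10.0]], [0.5, 0.5], [[0.0], [10.0]], [[1.0], [1.0]])
p r
puts max_abs_change([[1.0, 0.0]], [[0.75, 0.25]]) # 0.25
```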

#fit_predict(x) ⇒ Numo::Int32

Analyze clusters and assign samples to clusters.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for cluster analysis.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) Predicted cluster label per sample.



# File 'lib/rumale/clustering/gaussian_mixture.rb', line 101

def fit_predict(x)
  check_sample_array(x)
  fit(x).predict(x)
end

#marshal_dump ⇒ Hash

Dump marshal data.

Returns:

  • (Hash)

    The marshal data.



# File 'lib/rumale/clustering/gaussian_mixture.rb', line 108

def marshal_dump
  { params: @params,
    n_iter: @n_iter,
    weights: @weights,
    means: @means,
    covariances: @covariances }
end

#marshal_load(obj) ⇒ nil

Load marshal data.

Returns:

  • (nil)


# File 'lib/rumale/clustering/gaussian_mixture.rb', line 118

def marshal_load(obj)
  @params = obj[:params]
  @n_iter = obj[:n_iter]
  @weights = obj[:weights]
  @means = obj[:means]
  @covariances = obj[:covariances]
  nil
end
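marshal_dump and marshal_load are the hooks Ruby's built-in Marshal module calls, so a fitted analyzer can be saved and restored with Marshal.dump and Marshal.load. The sketch below uses a hypothetical stand-in class, `ToyMixture`, with the same five fields so it runs without Rumale or Numo installed; it is not the library's class.

```ruby
# Hypothetical stand-in following the same marshal_dump / marshal_load pattern.
class ToyMixture
  attr_reader :params, :n_iter, :weights, :means, :covariances

  def initialize(params = {})
    @params = params
    @n_iter = nil
    @weights = nil
    @means = nil
    @covariances = nil
  end

  # Pretend to fit by filling in fixed values.
  def fake_fit
    @n_iter = 3
    @weights = [0.5, 0.5]
    @means = [[1.0, 1.0], [11.0, 11.0]]
    @covariances = [[1.0, 1.0], [1.0, 1.0]]
    self
  end

  # Called by Marshal.dump: return the state to serialize.
  def marshal_dump
    { params: @params, n_iter: @n_iter, weights: @weights,
      means: @means, covariances: @covariances }
  end

  # Called by Marshal.load on a freshly allocated object (initialize is skipped).
  def marshal_load(obj)
    @params = obj[:params]
    @n_iter = obj[:n_iter]
    @weights = obj[:weights]
    @means = obj[:means]
    @covariances = obj[:covariances]
    nil
  end
end

original = ToyMixture.new({ n_clusters: 2 }).fake_fit
restored = Marshal.load(Marshal.dump(original))
p restored.means
```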

#predict(x) ⇒ Numo::Int32

Predict cluster labels for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to predict the cluster label.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) Predicted cluster label per sample.



# File 'lib/rumale/clustering/gaussian_mixture.rb', line 91

def predict(x)
  check_sample_array(x)
  memberships = calc_memberships(x, @weights, @means, @covariances)
  assign_cluster(memberships)
end
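predict delegates the final step to assign_cluster, which this page does not document. Assuming it picks the cluster with the largest membership for each sample (the usual hard assignment for a mixture model), a plain-Ruby equivalent over arrays is sketched below; `assign_cluster` here is a hypothetical stand-in, not Rumale's private method.

```ruby
# Hypothetical stand-in: index of the largest membership in each row.
def assign_cluster(memberships)
  memberships.map { |row| row.each_with_index.max_by { |value, _k| value }[1] }
end

labels = assign_cluster([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])
p labels
```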