Class: Rumale::Clustering::KMeans
- Inherits: Object
- Includes: Base::BaseEstimator, Base::ClusterAnalyzer
- Defined in: lib/rumale/clustering/k_means.rb
Overview
KMeans implements K-Means cluster analysis. The current implementation uses the Euclidean distance to measure how close each sample is to each cluster center.
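The assignment rule can be illustrated without Rumale. The following is a minimal pure-Ruby sketch, assuming samples and centroids are plain Arrays of Floats rather than Numo::DFloat. The helper names euclidean_distance and nearest_centroid_labels are hypothetical and not part of Rumale's API; the class itself delegates this step to a helper, assign_cluster, whose body is not documented on this page.

```ruby
# Hypothetical helper (not Rumale API): Euclidean distance between two vectors.
def euclidean_distance(a, b)
  Math.sqrt(a.zip(b).sum { |ai, bi| (ai - bi)**2 })
end

# Hypothetical helper (not Rumale API): label each sample with the index of
# its nearest centroid; ties go to the lower index.
def nearest_centroid_labels(samples, centroids)
  samples.map do |x|
    centroids.each_index.min_by { |k| euclidean_distance(x, centroids[k]) }
  end
end

samples = [[0.0, 0.0], [0.5, 0.2], [9.0, 9.5], [10.0, 10.0]]
centroids = [[0.0, 0.0], [10.0, 10.0]]
p nearest_centroid_labels(samples, centroids) # => [0, 0, 1, 1]
```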
Reference
- D. Arthur and S. Vassilvitskii, “k-means++: the advantages of careful seeding,” Proc. SODA’07, pp. 1027–1035, 2007.
Instance Attribute Summary
- #cluster_centers ⇒ Numo::DFloat (readonly): Return the centroids.
- #rng ⇒ Random (readonly): Return the random generator.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #fit(x) ⇒ KMeans: Analyze clusters with the given training data.
- #fit_predict(x) ⇒ Numo::Int32: Analyze clusters and assign samples to clusters.
- #initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil) ⇒ KMeans (constructor): Create a new cluster analyzer with the K-Means method.
- #marshal_dump ⇒ Hash: Dump marshal data.
- #marshal_load(obj) ⇒ nil: Load marshal data.
- #predict(x) ⇒ Numo::Int32: Predict cluster labels for samples.
Methods included from Base::ClusterAnalyzer
Constructor Details
#initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil) ⇒ KMeans
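Create a new cluster analyzer with the K-Means method. n_clusters and max_iter must be positive integers; any init value other than 'random' is treated as 'k-means++'; if random_seed is nil, the value returned by srand is used as the seed.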
# File 'lib/rumale/clustering/k_means.rb', line 38
def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
  check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
  check_params_float(tol: tol)
  check_params_string(init: init)
  check_params_type_or_nil(Integer, random_seed: random_seed)
  check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
  @params = {}
  @params[:n_clusters] = n_clusters
  # Any init value other than 'random' falls back to 'k-means++'.
  @params[:init] = init == 'random' ? 'random' : 'k-means++'
  @params[:max_iter] = max_iter
  @params[:tol] = tol
  @params[:random_seed] = random_seed
  # Without an explicit seed, the value returned by srand is used.
  @params[:random_seed] ||= srand
  @cluster_centers = nil
  @rng = Random.new(@params[:random_seed])
end
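The seeding for the default init: 'k-means++' happens in init_cluster_centers, which is not shown on this page. As a rough illustration of the k-means++ idea from the cited paper (the first center is drawn uniformly; each further center is drawn with probability proportional to its squared distance to the nearest center already chosen), here is a sketch on plain Ruby Arrays. kmeanspp_seeds and squared_distance are hypothetical names, and Rumale's own initialization may differ in its details.

```ruby
# Hypothetical helper (not Rumale API): squared Euclidean distance.
def squared_distance(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Hypothetical sketch of k-means++ seeding; rng is a Random instance.
def kmeanspp_seeds(samples, n_clusters, rng)
  # First center: a sample chosen uniformly at random.
  centers = [samples[rng.rand(samples.size)]]
  while centers.size < n_clusters
    # D(x)^2: squared distance from each sample to its closest chosen center.
    d2 = samples.map { |x| centers.map { |c| squared_distance(x, c) }.min }
    total = d2.inject(0.0, :+)
    # Every sample coincides with a chosen center: fall back to a uniform pick.
    if total.zero?
      centers << samples[rng.rand(samples.size)]
      next
    end
    # Draw the next center with probability proportional to D(x)^2.
    # The strict '>' guarantees the chosen sample has D(x)^2 > 0.
    threshold = rng.rand * total
    cumulative = 0.0
    idx = d2.index { |d| (cumulative += d) > threshold } || d2.rindex(&:positive?)
    centers << samples[idx]
  end
  centers
end

samples = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0], [10.0, 0.0]]
seeds = kmeanspp_seeds(samples, 3, Random.new(1))
p seeds
```

Passing the same seeded Random to the sketch reproduces the same seeds, which mirrors why the constructor keeps a seeded generator.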
Instance Attribute Details
#cluster_centers ⇒ Numo::DFloat (readonly)
Return the centroids, one row per cluster; nil until #fit has been called.
# File 'lib/rumale/clustering/k_means.rb', line 25
def cluster_centers
  @cluster_centers
end
#rng ⇒ Random (readonly)
Return the random generator.
# File 'lib/rumale/clustering/k_means.rb', line 29
def rng
  @rng
end
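The stored generator is built with Random.new(@params[:random_seed]), so passing an explicit random_seed makes runs repeatable. A small stdlib-only check of that property of Ruby's Random, independent of Rumale:

```ruby
# Two generators built from the same seed produce identical draws.
a = Random.new(42)
b = Random.new(42)
draws_a = Array.new(5) { a.rand(100) }
draws_b = Array.new(5) { b.rand(100) }
p draws_a == draws_b # => true
```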
Instance Method Details
#fit(x) ⇒ KMeans
Analyze clusters with the given training data.
# File 'lib/rumale/clustering/k_means.rb', line 61
def fit(x, _y = nil)
  check_sample_array(x)
  init_cluster_centers(x)
  @params[:max_iter].times do |_t|
    # Assignment step: label each sample with its nearest center.
    cluster_labels = assign_cluster(x)
    old_centers = @cluster_centers.dup
    # Update step: move each non-empty cluster's center to the mean of its members.
    @params[:n_clusters].times do |n|
      assigned_bits = cluster_labels.eq(n)
      @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
    end
    # Stop once the mean Euclidean shift of the centers is at most tol.
    error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
    break if error <= @params[:tol]
  end
  self
end
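The loop above alternates an assignment step and an update step, then stops once the mean Euclidean shift of the centers is at most tol. A pure-Ruby sketch of the same iteration, assuming plain Arrays instead of Numo arrays and caller-supplied initial centers; lloyd and sq_dist are hypothetical names, not Rumale methods.

```ruby
# Hypothetical helper (not Rumale API): squared Euclidean distance.
def sq_dist(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Hypothetical sketch of the iteration in #fit; returns the final centers.
def lloyd(samples, centers, max_iter: 50, tol: 1.0e-4)
  centers = centers.map(&:dup)
  max_iter.times do
    # Assignment step: index of the nearest center for every sample.
    labels = samples.map { |x| centers.each_index.min_by { |k| sq_dist(x, centers[k]) } }
    old_centers = centers.map(&:dup)
    # Update step: a center with no assigned samples is left unchanged, as in #fit.
    centers.each_index do |k|
      members = samples.each_index.select { |i| labels[i] == k }.map { |i| samples[i] }
      next if members.empty?
      centers[k] = members.transpose.map { |col| col.sum / members.size }
    end
    # Convergence check: mean Euclidean shift of the centers.
    shift = old_centers.zip(centers).sum { |o, c| Math.sqrt(sq_dist(o, c)) } / centers.size
    break if shift <= tol
  end
  centers
end

samples = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
p lloyd(samples, [[0.0, 0.0], [10.0, 10.0]]) # => [[0.0, 0.5], [10.0, 10.5]]
```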
#fit_predict(x) ⇒ Numo::Int32
Analyze clusters and assign samples to clusters.
# File 'lib/rumale/clustering/k_means.rb', line 90
def fit_predict(x)
  check_sample_array(x)
  fit(x)
  predict(x)
end
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/rumale/clustering/k_means.rb', line 98
def marshal_dump
  { params: @params,
    cluster_centers: @cluster_centers,
    rng: @rng }
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/rumale/clustering/k_means.rb', line 106
def marshal_load(obj)
  @params = obj[:params]
  @cluster_centers = obj[:cluster_centers]
  @rng = obj[:rng]
  nil
end
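marshal_dump returns a Hash of the estimator's state and marshal_load restores it, which lets Marshal.dump and Marshal.load round-trip a fitted model. The sketch below applies the same pattern to a hypothetical TinyModel class with plain-Ruby fields, so it runs with the standard library alone; it is not a Rumale class.

```ruby
# Hypothetical class illustrating the marshal_dump / marshal_load pattern.
class TinyModel
  attr_reader :params, :centers

  def initialize(params, centers)
    @params = params
    @centers = centers
  end

  # Marshal.dump serializes the Hash returned here instead of the ivars.
  def marshal_dump
    { params: @params, centers: @centers }
  end

  # Marshal.load allocates an instance without calling initialize,
  # then passes the dumped Hash here to restore state.
  def marshal_load(obj)
    @params = obj[:params]
    @centers = obj[:centers]
    nil
  end
end

model = TinyModel.new({ n_clusters: 2, tol: 1.0e-4 }, [[0.0, 0.5], [10.0, 10.5]])
restored = Marshal.load(Marshal.dump(model))
p restored.centers # => [[0.0, 0.5], [10.0, 10.5]]
```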
#predict(x) ⇒ Numo::Int32
Predict cluster labels for samples.
# File 'lib/rumale/clustering/k_means.rb', line 81
def predict(x)
  check_sample_array(x)
  assign_cluster(x)
end