Class: SVMKit::NearestNeighbors::KNeighborsClassifier
- Inherits:
- Object
  - SVMKit::NearestNeighbors::KNeighborsClassifier
- Includes:
- Base::BaseEstimator, Base::Classifier
- Defined in:
- lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb
Overview
KNeighborsClassifier implements a classifier based on the k-nearest neighbors rule. The current implementation uses the Euclidean distance to find the neighbors.
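As an illustration of the rule described above, here is a minimal plain-Ruby sketch that works on ordinary arrays instead of Numo arrays. The helper names `euclidean_distance` and `knn_predict` and the toy data are assumptions made for this example, and the tie-breaking rule (smallest label wins) is an arbitrary choice of the sketch, not something taken from SVMKit.

```ruby
# Plain-Ruby sketch of the k-nearest neighbors rule (not the SVMKit code).

# Euclidean distance between two equal-length feature vectors.
def euclidean_distance(a, b)
  Math.sqrt(a.zip(b).sum { |p, q| (p - q)**2 })
end

# Predict the label of `sample` by majority vote among its k nearest
# training points. Ties go to the smallest label (a choice of this sketch).
def knn_predict(train_x, train_y, sample, k: 5)
  nearest = train_x.each_index
                   .sort_by { |i| [euclidean_distance(train_x[i], sample), i] }
                   .first(k)
  votes = nearest.each_with_object(Hash.new(0)) { |i, h| h[train_y[i]] += 1 }
  best = votes.values.max
  votes.select { |_label, count| count == best }.keys.min
end

# Hypothetical toy data: two well-separated clusters.
train_x = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]]
train_y = [0, 0, 0, 1, 1, 1]

near_origin = knn_predict(train_x, train_y, [0.2, 0.1], k: 3)
near_corner = knn_predict(train_x, train_y, [5.5, 5.5], k: 3)
```

With this data, the three nearest neighbors of each query point all belong to a single cluster, so the vote is unanimous in both cases.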
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #labels ⇒ Numo::Int32 (readonly)
  Return the labels of the prototypes.
- #prototypes ⇒ Numo::DFloat (readonly)
  Return the prototypes for the nearest neighbor classifier.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Calculate confidence scores for samples.
- #fit(x, y) ⇒ KNeighborsClassifier
  Fit the model with given training data.
- #initialize(n_neighbors: 5) ⇒ KNeighborsClassifier (constructor)
  Create a new classifier with the nearest neighbor rule.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
- #score(x, y) ⇒ Float
  Calculate the mean accuracy of the given testing data.
Constructor Details
#initialize(n_neighbors: 5) ⇒ KNeighborsClassifier
Create a new classifier with the nearest neighbor rule.
    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 35
    def initialize(n_neighbors: 5)
      @params = {}
      @params[:n_neighbors] = n_neighbors
      @prototypes = nil
      @labels = nil
      @classes = nil
    end
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 30
    def classes
      @classes
    end
#labels ⇒ Numo::Int32 (readonly)
Return the labels of the prototypes.

    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 26
    def labels
      @labels
    end
#prototypes ⇒ Numo::DFloat (readonly)
Return the prototypes for the nearest neighbor classifier.
    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 22
    def prototypes
      @prototypes
    end
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
Calculate confidence scores for samples.
    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 59
    def decision_function(x)
      distance_matrix = PairwiseMetric.euclidean_distance(x, @prototypes)
      n_samples, n_prototypes = distance_matrix.shape
      n_classes = @classes.size
      n_neighbors = [@params[:n_neighbors], n_prototypes].min
      scores = Numo::DFloat.zeros(n_samples, n_classes)
      n_samples.times do |m|
        neighbor_ids = distance_matrix[m, true].to_a.each_with_index.sort.map(&:last)[0...n_neighbors]
        neighbor_ids.each { |n| scores[m, @classes.to_a.index(@labels[n])] += 1.0 }
      end
      scores
    end
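The listing above produces one vote count per class for each sample: it takes the k smallest distances in a row and increments the column of each neighbor's label, clamping k to the number of prototypes. The following plain-Ruby sketch reproduces that counting on nested arrays. The method name `vote_scores` and the distance values are assumptions for illustration; the distance matrix is supplied directly instead of being computed.

```ruby
# Build an n_samples x n_classes vote matrix from a precomputed
# distance matrix (rows: samples, columns: prototypes).
def vote_scores(distance_matrix, labels, classes, n_neighbors)
  # Never ask for more neighbors than there are prototypes.
  k = [n_neighbors, labels.size].min
  distance_matrix.map do |row|
    scores = Array.new(classes.size, 0.0)
    # Sorting [distance, index] pairs orders prototypes by distance.
    row.each_with_index.sort.first(k).each do |_distance, j|
      scores[classes.index(labels[j])] += 1.0
    end
    scores
  end
end

# Hypothetical distances from 2 samples to 4 prototypes.
distances = [[0.1, 0.2, 0.9, 1.5],
             [2.0, 1.8, 0.3, 0.1]]
labels  = [0, 0, 1, 1]
classes = [0, 1]

scores = vote_scores(distances, labels, classes, 3)
clamped = vote_scores(distances, labels, classes, 10)
```

Each row of the result sums to the effective number of neighbors, so the scores are raw vote counts rather than probabilities.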
#fit(x, y) ⇒ KNeighborsClassifier
Fit the model with given training data.
    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 48
    def fit(x, y)
      @prototypes = Numo::DFloat.asarray(x.to_a)
      @labels = Numo::Int32.asarray(y.to_a)
      @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
      self
    end
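As the listing shows, fitting performs no optimization: the samples are stored as prototypes and the distinct labels are collected in ascending order. The position of a label in that sorted list is the column it occupies in the scores returned by decision_function. A small plain-Ruby sketch of that mapping, using hypothetical labels:

```ruby
# Hypothetical training labels, in the order the samples were given.
y = [2, 0, 1, 2, 0]

# Distinct labels in ascending order, mirroring `y.to_a.uniq.sort` in fit.
classes = y.uniq.sort

# Column index that each training label would vote into.
columns = y.map { |label| classes.index(label) }
```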
#marshal_dump ⇒ Hash
Dump marshal data.
    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 95
    def marshal_dump
      { params: params,
        prototypes: @prototypes,
        labels: @labels,
        classes: @classes }
    end
#marshal_load(obj) ⇒ nil
Load marshal data.
    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 104
    def marshal_load(obj)
      @params = obj[:params]
      @prototypes = obj[:prototypes]
      @labels = obj[:labels]
      @classes = obj[:classes]
      nil
    end
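Ruby's Marshal module calls marshal_dump when serializing an object and marshal_load when restoring it, which is how a fitted model can be saved and reloaded. The sketch below shows the round trip with a hypothetical stand-in class, `TinyModel`, which is not part of SVMKit and keeps only two fields.

```ruby
# Hypothetical stand-in class used only to illustrate the Marshal hooks.
class TinyModel
  attr_reader :params, :labels

  def initialize(n_neighbors: 5)
    @params = { n_neighbors: n_neighbors }
    @labels = nil
  end

  def fit(labels)
    @labels = labels
    self
  end

  # Called by Marshal.dump; the returned Hash is what gets serialized.
  def marshal_dump
    { params: @params, labels: @labels }
  end

  # Called by Marshal.load on a freshly allocated instance
  # (initialize is not run), so every field must be restored here.
  def marshal_load(obj)
    @params = obj[:params]
    @labels = obj[:labels]
    nil
  end
end

model = TinyModel.new(n_neighbors: 3).fit([0, 1, 1])
restored = Marshal.load(Marshal.dump(model))
```

Because Marshal.load can instantiate arbitrary classes, it should only be used on data from a trusted source.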
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 76
    def predict(x)
      n_samples = x.shape.first
      decision_values = decision_function(x)
      Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
    end
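Prediction reduces each row of vote counts to the class whose column holds the largest value. A plain-Ruby sketch of that step follows; the name `predict_from_scores` and the score values are assumptions for illustration. In this sketch a tie resolves to the first maximal column because `Array#index` returns the first match; that is a property of the sketch and is not a statement about how Numo's `max_index` breaks ties.

```ruby
# Map each row of scores to the class label of its largest entry.
def predict_from_scores(scores, classes)
  scores.map { |row| classes[row.index(row.max)] }
end

# Hypothetical class labels and vote counts for three samples.
classes = [3, 7, 9]
scores = [[0.0, 2.0, 1.0],
          [1.0, 1.0, 1.0],
          [0.0, 1.0, 2.0]]

predicted = predict_from_scores(scores, classes)
```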
#score(x, y) ⇒ Float
Calculate the mean accuracy of the given testing data.

    # File 'lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb', line 87
    def score(x, y)
      p = predict(x)
      n_hits = (y.to_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
      n_hits / y.size.to_f
    end
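Mean accuracy is the fraction of samples whose predicted label equals the true label. A plain-Ruby sketch of the same computation is shown below, with the hypothetical helper `mean_accuracy` taking the predictions directly instead of calling predict.

```ruby
# Fraction of positions where the predicted label matches the expected one.
def mean_accuracy(predicted, expected)
  hits = predicted.zip(expected).count { |p, e| p == e }
  # Convert to Float so the division is not truncated to an Integer.
  hits / expected.size.to_f
end

# Hypothetical predictions and ground truth: 3 of 4 match.
accuracy = mean_accuracy([0, 1, 1, 0], [0, 1, 0, 0])
```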