Class: SVMKit::Multiclass::OneVsRestClassifier
- Inherits: Object
- Includes: Base::BaseEstimator, Base::Classifier
- Defined in: lib/svmkit/multiclass/one_vs_rest_classifier.rb
Overview
Since version 0.2.7, all classifiers in SVMKit support multi-class classification, so there is no need to use this class explicitly for multi-class classification.
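OneVsRestClassifier implements the One-vs-Rest (OvR) strategy for multi-class classification. It trains one copy of a given binary classifier per class, treating that class as positive and all other classes as negative. It then predicts, for each sample, the class whose classifier gives the highest decision value.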
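The key step of the OvR strategy is relabeling. For each distinct class label, the targets become a binary vector: +1 for samples of that class and -1 for every other sample. The following is a minimal pure-Ruby sketch on plain arrays (no Numo). The helper name `ovr_binary_targets` is hypothetical and not part of SVMKit.

```ruby
# Hypothetical helper (not part of SVMKit): build one binary target vector
# per class, mirroring the relabeling done in #fit.
def ovr_binary_targets(y)
  classes = y.uniq.sort
  targets = classes.map do |label|
    # +1 for samples of this class, -1 for all other samples
    y.map { |l| l == label ? 1 : -1 }
  end
  [classes, targets]
end

classes, targets = ovr_binary_targets([2, 0, 1, 0])
p classes  # => [0, 1, 2]
p targets  # => [[-1, 1, -1, 1], [-1, -1, 1, -1], [1, -1, -1, -1]]
```

Each of the three target vectors would train a separate binary estimator.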
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly): Return the class labels.
- #estimators ⇒ Array<Classifier> (readonly): Return the set of estimators.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat: Calculate confidence scores for samples.
- #fit(x, y) ⇒ OneVsRestClassifier: Fit the model with the given training data.
- #initialize(estimator: nil) ⇒ OneVsRestClassifier (constructor): Create a new multi-class classifier with the one-vs-rest strategy.
- #marshal_dump ⇒ Hash: Dump marshal data.
- #marshal_load(obj) ⇒ nil: Load marshal data.
- #predict(x) ⇒ Numo::Int32: Predict class labels for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(estimator: nil) ⇒ OneVsRestClassifier
Create a new multi-class classifier with the one-vs-rest strategy. The estimator keyword takes the binary classifier that #fit duplicates and trains once per class.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 35
    def initialize(estimator: nil)
      @params = {}
      @params[:estimator] = estimator
      @estimators = nil
      @classes = nil
    end
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 30
    def classes
      @classes
    end
#estimators ⇒ Array<Classifier> (readonly)
Return the set of estimators.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 26
    def estimators
      @estimators
    end
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
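Calculate confidence scores for samples. The result has shape [n_samples, n_classes]; its m-th column holds the decision values of the m-th binary estimator.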
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 61
    def decision_function(x)
      n_classes = @classes.size
      Numo::DFloat.asarray(Array.new(n_classes) { |m| @estimators[m].decision_function(x).to_a }).transpose
    end
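The stacking in #decision_function can be illustrated with plain arrays. Each binary estimator returns one score per sample, so the per-class vectors form an n_classes × n_samples array. Transposing that array gives one row per sample and one column per class. In the sketch below, the helper `stack_scores` is hypothetical and the score values are purely illustrative.

```ruby
# Hypothetical helper (not part of SVMKit): per_class_scores[m][n] is the
# score the m-th binary estimator assigns to the n-th sample.
def stack_scores(per_class_scores)
  # n_classes x n_samples -> n_samples x n_classes, as #decision_function
  # does with Numo::DFloat#transpose.
  per_class_scores.transpose
end

# Illustrative scores for 3 classes and 2 samples.
per_class_scores = [
  [0.9, -0.4],  # estimator for the 1st class
  [-0.2, 0.3],  # estimator for the 2nd class
  [-0.7, 0.1]   # estimator for the 3rd class
]
p stack_scores(per_class_scores)  # => [[0.9, -0.2, -0.7], [-0.4, 0.3, 0.1]]
```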
#fit(x, y) ⇒ OneVsRestClassifier
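Fit the model with the given training data. For each class label, a copy of the base estimator is trained on binary targets: 1 for samples of that class and -1 for all others.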
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 47
    def fit(x, y)
      y_arr = y.to_a
      @classes = Numo::Int32.asarray(y_arr.uniq.sort)
      @estimators = @classes.to_a.map do |label|
        bin_y = Numo::Int32.asarray(y_arr.map { |l| l == label ? 1 : -1 })
        @params[:estimator].dup.fit(x, bin_y)
      end
      self
    end
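The training loop in #fit is shown below as a self-contained sketch. It uses a hypothetical stand-in for an SVMKit binary classifier: `CentroidBinaryClassifier` scores a sample by how much closer it lies to the mean of the +1 samples than to the mean of the -1 samples. Only the `fit`/`decision_function` shape is taken from the document; the class itself and the helper `ovr_fit` are assumptions for illustration. Plain nested arrays replace Numo arrays.

```ruby
# Hypothetical binary classifier (not part of SVMKit). Assumes the targets
# contain both +1 and -1 samples.
class CentroidBinaryClassifier
  def fit(x, y)
    @pos = mean(x.zip(y).select { |_xi, t| t == 1 }.map(&:first))
    @neg = mean(x.zip(y).select { |_xi, t| t == -1 }.map(&:first))
    self
  end

  # Positive when a sample is closer to the +1 centroid than the -1 centroid.
  def decision_function(x)
    x.map { |xi| sq_dist(xi, @neg) - sq_dist(xi, @pos) }
  end

  private

  def mean(rows)
    rows.transpose.map { |col| col.sum / rows.size.to_f }
  end

  def sq_dist(a, b)
    a.zip(b).sum { |ai, bi| (ai - bi)**2 }
  end
end

# One-vs-rest training following the structure of #fit above: the prototype
# is duplicated once per class and trained on +1/-1 targets.
def ovr_fit(prototype, x, y)
  classes = y.uniq.sort
  estimators = classes.map do |label|
    bin_y = y.map { |l| l == label ? 1 : -1 }
    prototype.dup.fit(x, bin_y)
  end
  [classes, estimators]
end

x = [[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9], [0.0, 5.0], [0.1, 5.2]]
y = [0, 0, 1, 1, 2, 2]
classes, estimators = ovr_fit(CentroidBinaryClassifier.new, x, y)
p classes          # => [0, 1, 2]
p estimators.size  # => 3
```

Because `dup` is applied to the unfitted prototype, each class gets its own estimator object. The prototype passed in is never fitted itself.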
#marshal_dump ⇒ Hash
Dump marshal data.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 78
    def marshal_dump
      { params: @params,
        classes: @classes,
        estimators: @estimators.map { |e| Marshal.dump(e) } }
    end
#marshal_load(obj) ⇒ nil
Load marshal data.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 86
    def marshal_load(obj)
      @params = obj[:params]
      @classes = obj[:classes]
      @estimators = obj[:estimators].map { |e| Marshal.load(e) }
      nil
    end
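Ruby's Marshal calls a class's `marshal_dump` when serializing. When deserializing, it calls `marshal_load` on a freshly allocated object without running `initialize`. The sketch below applies the same dump/load pattern to a hypothetical plain-Ruby class, `EnsembleBox` (not SVMKit code), so it runs without SVMKit or Numo. Each member object is dumped to its own byte string, as #marshal_dump does for the estimators.

```ruby
# Hypothetical container class (not SVMKit code) using the same pattern as
# #marshal_dump / #marshal_load above.
class EnsembleBox
  attr_reader :params, :classes, :members

  def initialize(params, classes, members)
    @params = params
    @classes = classes
    @members = members
  end

  # Called by Marshal.dump; the returned Hash is what gets serialized.
  def marshal_dump
    { params: @params, classes: @classes, members: @members.map { |m| Marshal.dump(m) } }
  end

  # Called by Marshal.load on a newly allocated, uninitialized instance.
  def marshal_load(obj)
    @params = obj[:params]
    @classes = obj[:classes]
    @members = obj[:members].map { |m| Marshal.load(m) }
    nil
  end
end

box = EnsembleBox.new({ estimator: nil }, [0, 1, 2], [{ w: [0.5, -1.0] }, { w: [1.5, 2.0] }])
restored = Marshal.load(Marshal.dump(box))
p restored.classes                  # => [0, 1, 2]
p restored.members == box.members   # => true
```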
#predict(x) ⇒ Numo::Int32
Predict class labels for samples. Each sample is assigned the label whose estimator produces the largest decision value.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 70
    def predict(x)
      n_samples, = x.shape
      decision_values = decision_function(x)
      Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
    end
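The prediction rule reduces to an argmax over each row of the decision values. The resulting column index is then mapped back through the sorted class labels. This lets the labels be arbitrary integers rather than 0..n_classes-1. The following is a plain-array sketch with a hypothetical helper `ovr_predict`. In this sketch ties go to the lowest column index; whether Numo's `max_index` breaks ties the same way is not stated in this document.

```ruby
# Hypothetical helper (not part of SVMKit): for each sample (row), pick the
# class whose column holds the largest decision value.
def ovr_predict(classes, decision_values)
  decision_values.map do |row|
    best = row.index(row.max)  # first (lowest-index) maximum
    classes[best]
  end
end

classes = [3, 7, 9]  # illustrative labels, sorted as in #fit
decision_values = [[0.9, -0.2, -0.7], [-0.4, 0.3, 0.1]]
p ovr_predict(classes, decision_values)  # => [3, 7]
```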