Class: Rumale::Multiclass::OneVsRestClassifier
- Inherits:
  - Object
    - Rumale::Multiclass::OneVsRestClassifier
- Includes:
  - Base::BaseEstimator, Base::Classifier
- Defined in:
  - lib/rumale/multiclass/one_vs_rest_classifier.rb
Overview
All classifiers in Rumale support multi-class classification since version 0.2.7, so there is no need to use this class explicitly for multi-class classification.
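OneVsRestClassifier implements the One-vs-Rest (OvR) strategy for multi-class classification: it trains one binary classifier per class, treating that class as positive and all other classes as negative, and predicts the class whose classifier returns the highest decision score.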
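The strategy can be written compactly as follows. Here c_1 < ... < c_K are the sorted distinct training labels (what #classes stores) and f_m is the decision function of the m-th fitted estimator; the notation is 1-based, whereas the Ruby code indexes classes and estimators from 0.

```latex
% Training: the m-th binary problem recodes the labels, m = 1, ..., K
y_i^{(m)} =
\begin{cases}
  +1 & \text{if } y_i = c_m, \\
  -1 & \text{otherwise},
\end{cases}
\qquad c_1 < c_2 < \dots < c_K

% Prediction: choose the class whose estimator scores highest
\hat{y}(\mathbf{x}) = c_{\hat{m}}, \qquad
\hat{m} = \operatorname*{arg\,max}_{m \in \{1, \dots, K\}} f_m(\mathbf{x})
```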
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #estimators ⇒ Array<Classifier> (readonly)
  Return the set of estimators.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Calculate confidence scores for samples.
- #fit(x, y) ⇒ OneVsRestClassifier
  Fit the model with the given training data.
- #initialize(estimator: nil) ⇒ OneVsRestClassifier (constructor)
  Create a new multi-class classifier with the one-vs-rest strategy.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(estimator: nil) ⇒ OneVsRestClassifier
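Create a new multi-class classifier with the one-vs-rest strategy. The estimator argument (checked to be a Rumale::Base::BaseEstimator) is the binary classifier used as a template; #fit trains a separate copy of it, made with dup, for each class.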
# File 'lib/rumale/multiclass/one_vs_rest_classifier.rb', line 35

def initialize(estimator: nil)
  check_params_type(Rumale::Base::BaseEstimator, estimator: estimator)
  @params = {}
  @params[:estimator] = estimator
  @estimators = nil
  @classes = nil
end
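The constructor only stores the template estimator; #fit later calls dup on it once per class. The plain-Ruby sketch below illustrates why that works, using a hypothetical ToyEstimator (not a Rumale class): Object#dup copies instance variables shallowly, so fitting a copy reassigns the copy's own state and leaves the template unfitted, while nested objects such as a params hash remain shared. This assumes the estimator's fit reassigns its instance variables rather than mutating shared objects in place.

```ruby
# ToyEstimator is hypothetical (not a Rumale class): a 1-D "binary estimator"
# whose fit just records the mean of the positive (+1) samples.
class ToyEstimator
  attr_reader :params, :pos_mean

  def initialize(params = {})
    @params = params
    @pos_mean = nil
  end

  # Reassigns @pos_mean on the receiver only; the shared params hash is untouched.
  def fit(x, y)
    positives = x.zip(y).select { |_, label| label == 1 }.map(&:first)
    @pos_mean = positives.sum / positives.size.to_f
    self
  end
end

template = ToyEstimator.new({ reg_param: 1.0 })

# Mirrors `@params[:estimator].dup.fit(x, bin_y)` for two classes.
fitted_a = template.dup.fit([1.0, 2.0, 3.0], [1, -1, -1])
fitted_b = template.dup.fit([1.0, 2.0, 3.0], [-1, 1, 1])

p template.pos_mean # => nil (the template itself is never fitted)
p fitted_a.pos_mean # => 1.0
p fitted_b.pos_mean # => 2.5
p fitted_a.params.equal?(template.params) # => true (dup is shallow)
```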
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
# File 'lib/rumale/multiclass/one_vs_rest_classifier.rb', line 30

def classes
  @classes
end
#estimators ⇒ Array<Classifier> (readonly)
Return the fitted binary estimators, one per class, in the same order as #classes.
# File 'lib/rumale/multiclass/one_vs_rest_classifier.rb', line 26

def estimators
  @estimators
end
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
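Calculate confidence scores for samples. The result has shape (n_samples, n_classes); column m holds the decision scores of the estimator trained for classes[m].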
# File 'lib/rumale/multiclass/one_vs_rest_classifier.rb', line 65

def decision_function(x)
  check_sample_array(x)
  n_classes = @classes.size
  # Stack one score array per class (n_classes x n_samples), then transpose
  # so that rows are samples and column m belongs to @classes[m].
  Numo::DFloat.asarray(Array.new(n_classes) { |m| @estimators[m].decision_function(x).to_a }).transpose
end
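#decision_function collects one score array per class and transposes the stack so that rows are samples. The sketch below shows the same reshaping with plain Ruby arrays; the scores are made-up values, and Array#transpose stands in for the Numo call.

```ruby
# Made-up scores for K = 3 classes and 2 samples:
# scores_per_class[m][n] is the score of the estimator for classes[m] on sample n.
scores_per_class = [
  [0.9, -0.2],  # estimator for classes[0]
  [-0.5, 0.4],  # estimator for classes[1]
  [-0.8, -0.1]  # estimator for classes[2]
]

# K x n_samples -> n_samples x K, the (n_samples, n_classes) layout
# that decision_function returns.
decision_values = scores_per_class.transpose

p decision_values # => [[0.9, -0.5, -0.8], [-0.2, 0.4, -0.1]]
```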
#fit(x, y) ⇒ OneVsRestClassifier
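Fit the model with the given training data. For each distinct label in y, a copy of the estimator is fitted on labels recoded to +1 for that class and -1 for all other classes. Returns self.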
# File 'lib/rumale/multiclass/one_vs_rest_classifier.rb', line 48

def fit(x, y)
  check_sample_array(x)
  check_label_array(y)
  check_sample_label_size(x, y)
  y_arr = y.to_a
  # Sorted distinct labels; estimator m is trained for @classes[m].
  @classes = Numo::Int32.asarray(y_arr.uniq.sort)
  @estimators = @classes.to_a.map do |label|
    # Recode labels: +1 for the current class, -1 for every other class.
    bin_y = Numo::Int32.asarray(y_arr.map { |l| l == label ? 1 : -1 })
    # Fit an independent copy of the template estimator.
    @params[:estimator].dup.fit(x, bin_y)
  end
  self
end
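The relabeling step inside #fit can be reproduced with plain Ruby arrays (no Numo); the label vector below is a made-up example. Each class gets its own +1/-1 target vector, and every sample is positive in exactly one of them.

```ruby
y_arr = [2, 0, 1, 2, 0] # made-up training labels

# Sorted distinct labels; estimator m would be trained for classes[m].
classes = y_arr.uniq.sort

# One binary target vector per class: +1 for that class, -1 for the rest.
binary_targets = classes.map do |label|
  y_arr.map { |l| l == label ? 1 : -1 }
end

classes.zip(binary_targets).each do |label, bin_y|
  puts "class #{label}: #{bin_y.inspect}"
end
# class 0: [-1, 1, -1, -1, 1]
# class 1: [-1, -1, 1, -1, -1]
# class 2: [1, -1, -1, 1, -1]
```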
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/rumale/multiclass/one_vs_rest_classifier.rb', line 84

def marshal_dump
  { params: @params,
    classes: @classes,
    estimators: @estimators.map { |e| Marshal.dump(e) } }
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/rumale/multiclass/one_vs_rest_classifier.rb', line 92

def marshal_load(obj)
  @params = obj[:params]
  @classes = obj[:classes]
  @estimators = obj[:estimators].map { |e| Marshal.load(e) }
  nil
end
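Ruby's Marshal calls marshal_dump when serializing an object that defines it; when loading, it allocates a fresh instance and passes the dumped value to marshal_load. The methods above additionally serialize each estimator into its own String. A plain-Ruby sketch of the same pattern, using a hypothetical ToyEnsemble class whose members are simple hashes:

```ruby
# ToyEnsemble is hypothetical; it mirrors the marshal_dump/marshal_load pattern
# above, storing each member as a separately marshaled String.
class ToyEnsemble
  attr_reader :params, :classes, :members

  def initialize(params, classes, members)
    @params = params
    @classes = classes
    @members = members
  end

  def marshal_dump
    { params: @params, classes: @classes, members: @members.map { |m| Marshal.dump(m) } }
  end

  # Called on a freshly allocated instance (initialize is not run).
  def marshal_load(obj)
    @params = obj[:params]
    @classes = obj[:classes]
    @members = obj[:members].map { |m| Marshal.load(m) }
    nil
  end
end

original = ToyEnsemble.new({ estimator: nil }, [0, 1, 2], [{ w: [0.1, 0.2] }, { w: [0.3] }, { w: [] }])
restored = Marshal.load(Marshal.dump(original))

p restored.classes                    # => [0, 1, 2]
p restored.members == original.members # => true
p restored.equal?(original)           # => false (a new object)
```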
#predict(x) ⇒ Numo::Int32
Predict class labels for samples: each sample is assigned the class whose estimator gives the highest decision score.
# File 'lib/rumale/multiclass/one_vs_rest_classifier.rb', line 75

def predict(x)
  check_sample_array(x)
  n_samples, = x.shape
  decision_values = decision_function(x)
  # For each sample (row), map the column with the highest score back to its label.
  Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
end
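The prediction rule is a row-wise argmax followed by a lookup into the class labels. The plain-Ruby sketch below uses made-up labels and scores, with Array#index(row.max) standing in for Numo's max_index. The last row shows that a class is still chosen when no estimator returns a positive score.

```ruby
classes = [3, 5, 9] # hypothetical sorted labels, as fit would store them

# Rows are samples; columns follow the order of classes (made-up scores).
decision_values = [
  [0.9, -0.5, -0.8],
  [-0.2, 0.4, -0.1],
  [-1.0, -0.3, 0.7],
  [-0.4, -0.9, -0.7] # no estimator is positive; the largest score still wins
]

# Index of the row maximum (first one on ties) mapped back to its label.
predictions = decision_values.map { |row| classes[row.index(row.max)] }

p predictions # => [3, 5, 9, 3]
```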