Class: Rumale::ModelSelection::GridSearchCV
- Inherits: Object
- Includes: Base::BaseEstimator
- Defined in: lib/rumale/model_selection/grid_search_cv.rb
Overview
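GridSearchCV performs hyperparameter optimization by grid search. It cross-validates the estimator with every combination of the given parameter values, then refits the estimator with the best-scoring combination.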
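A grid search evaluates the Cartesian product of the candidate values. The nested loop in #fit iterates over a list of lists of parameter sets, which suggests that param_grid may also be given as an Array of Hashes, each expanded on its own; the sketch below assumes so. It is plain Ruby: expand_param_grid and expand_grids are hypothetical helpers (not Rumale's private param_combinations), and the parameter names are only illustrative.

```ruby
# Hypothetical helper: expand one Hash of candidate values into every
# combination of parameter settings (the Cartesian product).
def expand_param_grid(param_grid)
  keys = param_grid.keys
  first, *rest = param_grid.values
  first.product(*rest).map { |values| keys.zip(values).to_h }
end

# Hypothetical helper: accept either one Hash or an Array of Hashes and
# expand each one separately, giving a list of lists of parameter sets.
def expand_grids(param_grid)
  grids = param_grid.is_a?(Array) ? param_grid : [param_grid]
  grids.map { |grid| expand_param_grid(grid) }
end

grid = { reg_param: [0.1, 1.0, 10.0], fit_bias: [true, false] }
combos = expand_param_grid(grid)
puts combos.size # prints 6 (3 values x 2 values)
combos.each { |prms| p prms }
```

The grid grows multiplicatively with each parameter, so the number of cross-validation runs is the product of the list lengths times the number of folds.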
Instance Attribute Summary
- #best_estimator ⇒ Estimator (readonly)
  Returns the estimator refitted on the full training data with the best parameter set.
- #best_index ⇒ Integer (readonly)
  Returns the index of the best parameter set.
- #best_params ⇒ Hash (readonly)
  Returns the best parameter set.
- #best_score ⇒ Float (readonly)
  Returns the score of the estimator learned with the best parameter set.
- #cv_results ⇒ Hash (readonly)
  Returns the cross-validation results for each parameter set.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Calls the decision_function method of the estimator learned with the best parameter set.
- #fit(x, y) ⇒ GridSearchCV
  Fits the model to the given training data with every parameter set, then refits the estimator with the best one.
- #initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true) ⇒ GridSearchCV (constructor)
  Creates a new grid search.
- #marshal_dump ⇒ Hash
  Returns the object's state as a Hash for Marshal.dump.
- #marshal_load(obj) ⇒ nil
  Restores the object's state from the Hash produced by #marshal_dump.
- #predict(x) ⇒ Numo::NArray
  Calls the predict method of the estimator learned with the best parameter set.
- #predict_log_proba(x) ⇒ Numo::DFloat
  Calls the predict_log_proba method of the estimator learned with the best parameter set.
- #predict_proba(x) ⇒ Numo::DFloat
  Calls the predict_proba method of the estimator learned with the best parameter set.
- #score(x, y) ⇒ Float
  Calls the score method of the estimator learned with the best parameter set.
Constructor Details
#initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true) ⇒ GridSearchCV
Creates a new grid search. Although estimator and splitter default to nil, the type checks at the top of the constructor reject nil for both, so they must be given; only evaluator may be nil. The estimator, splitter, and evaluator are deep-copied, so later changes to the objects passed in do not affect the search.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 67
    def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
      check_params_type(Rumale::Base::BaseEstimator, estimator: estimator)
      check_params_type(Rumale::Base::Splitter, splitter: splitter)
      check_params_type_or_nil(Rumale::Base::Evaluator, evaluator: evaluator)
      check_params_boolean(greater_is_better: greater_is_better)
      @params = {}
      @params[:param_grid] = valid_param_grid(param_grid)
      @params[:estimator] = Marshal.load(Marshal.dump(estimator))
      @params[:splitter] = Marshal.load(Marshal.dump(splitter))
      @params[:evaluator] = Marshal.load(Marshal.dump(evaluator))
      @params[:greater_is_better] = greater_is_better
      @cv_results = nil
      @best_score = nil
      @best_params = nil
      @best_index = nil
      @best_estimator = nil
    end
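The constructor copies its estimator, splitter, and evaluator with Marshal.load(Marshal.dump(...)), which rebuilds the whole object graph so the copy shares no mutable state with the original. The sketch below shows the idiom in isolation; ToyEstimator is a hypothetical stand-in for a Rumale estimator. The idiom also works for nil (Marshal.load(Marshal.dump(nil)) is nil), which is why a nil evaluator passes through unchanged; one consequence is that the objects must be marshalable, so an object holding a Proc, for example, raises TypeError.

```ruby
# Hypothetical stand-in for an estimator with a mutable params Hash.
class ToyEstimator
  attr_reader :params

  def initialize(reg_param: 1.0)
    @params = { reg_param: reg_param }
  end
end

original = ToyEstimator.new(reg_param: 1.0)
# Serialize and rebuild the object: a deep copy of everything it references.
copy = Marshal.load(Marshal.dump(original))

# Mutating the original afterwards does not reach the copy.
original.params[:reg_param] = 100.0
puts copy.params[:reg_param] # prints 1.0
```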
Instance Attribute Details
#best_estimator ⇒ Estimator (readonly)
Returns the estimator refitted on the full training data with the best parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 55
    def best_estimator
      @best_estimator
    end
#best_index ⇒ Integer (readonly)
Returns the index of the best parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 51
    def best_index
      @best_index
    end
#best_params ⇒ Hash (readonly)
Returns the best parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 47
    def best_params
      @best_params
    end
#best_score ⇒ Float (readonly)
Returns the score of the estimator learned with the best parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 43
    def best_score
      @best_score
    end
#cv_results ⇒ Hash (readonly)
Returns the cross-validation results for each parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 39
    def cv_results
      @cv_results
    end
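After every parameter set has been cross-validated, #fit calls a private find_best_params step that fills in best_index, best_score, and best_params, with the constructor's greater_is_better flag deciding whether a larger or smaller score wins. The sketch below shows that selection under assumptions: the results layout (parallel :params and :mean_test_score arrays) and the method name pick_best are hypothetical, not taken from the library.

```ruby
# Hypothetical selection step: pick the best parameter set from
# cross-validation results stored as parallel arrays.
def pick_best(cv_results, greater_is_better: true)
  scores = cv_results[:mean_test_score]
  # With greater_is_better (e.g. accuracy) the largest mean score wins;
  # otherwise (e.g. an error measure) the smallest one does.
  best_score = greater_is_better ? scores.max : scores.min
  best_index = scores.index(best_score) # first occurrence on ties
  { best_index: best_index, best_score: best_score, best_params: cv_results[:params][best_index] }
end

results = {
  params: [{ reg_param: 0.1 }, { reg_param: 1.0 }, { reg_param: 10.0 }],
  mean_test_score: [0.82, 0.91, 0.88]
}
p pick_best(results)
p pick_best(results, greater_is_better: false)
```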
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
Calls the decision_function method of the estimator learned with the best parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 113
    def decision_function(x)
      check_sample_array(x)
      @best_estimator.decision_function(x)
    end
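#decision_function, #predict, #predict_log_proba, #predict_proba, and #score all follow the same pattern: validate x, then forward the call to @best_estimator. Before #fit runs, @best_estimator is nil, so the code above would raise NoMethodError. The sketch below shows the delegation with an explicit not-fitted guard added; NotFittedError, ConstantModel, and BestEstimatorDelegator are hypothetical names, and the guard is an addition that GridSearchCV itself does not perform.

```ruby
# Hypothetical error class for calls made before fitting.
class NotFittedError < StandardError; end

# Hypothetical stand-in for a fitted estimator: returns the same decision
# value for every sample (row) of x.
class ConstantModel
  def initialize(value)
    @value = value
  end

  def decision_function(x)
    x.map { @value }
  end
end

# Hypothetical search object that forwards inference calls to the estimator
# chosen by the search.
class BestEstimatorDelegator
  attr_reader :best_estimator

  def initialize(best_estimator = nil)
    @best_estimator = best_estimator
  end

  def decision_function(x)
    # Extra guard: fail with a clear message instead of NoMethodError on nil.
    raise NotFittedError, 'call fit before decision_function' if @best_estimator.nil?

    @best_estimator.decision_function(x)
  end
end

search = BestEstimatorDelegator.new(ConstantModel.new(0.5))
p search.decision_function([[1.0, 2.0], [3.0, 4.0]])
```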
#fit(x, y) ⇒ GridSearchCV
Fits the model with the given training data and all parameter sets: each parameter set is cross-validated and its result stored, the best set is selected, and the estimator is refitted on all of x and y with that set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 90
    def fit(x, y)
      check_sample_array(x)
      init_attrs
      param_combinations.each do |prm_set|
        prm_set.each do |prms|
          report = perform_cross_validation(x, y, prms)
          store_cv_result(prms, report)
        end
      end
      find_best_params
      @best_estimator = configurated_estimator(@best_params)
      @best_estimator.fit(x, y)
      self
    end
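The steps in #fit (cross-validate every parameter set, record each score, pick the best, refit on all of x and y) can be sketched without Rumale. Everything here is hypothetical: ToyKNN is a 1-D k-nearest-neighbours classifier, k_fold_indices is a plain shuffled k-fold standing in for whatever splitter the caller passes, and grid_search assumes a greater-is-better score.

```ruby
# Hypothetical toy classifier: k-nearest neighbours on 1-D inputs.
class ToyKNN
  def initialize(k: 1)
    @k = k
  end

  def fit(x, y)
    @x = x
    @y = y
    self
  end

  def predict(x)
    x.map do |xi|
      # Indices of the k training points closest to xi.
      nearest = @x.each_index.min_by(@k) { |i| (@x[i] - xi).abs }
      # Majority vote over their labels.
      nearest.map { |i| @y[i] }.tally.max_by { |_, count| count }.first
    end
  end

  # Mean accuracy on the given samples.
  def score(x, y)
    predict(x).zip(y).count { |pred, truth| pred == truth }.fdiv(y.size)
  end
end

# Shuffled k-fold split; returns [train_indices, test_indices] pairs.
def k_fold_indices(n_samples, n_splits, seed: 1)
  order = (0...n_samples).to_a.shuffle(random: Random.new(seed))
  fold_size = (n_samples.to_f / n_splits).ceil
  order.each_slice(fold_size).map { |test| [order - test, test] }
end

# Cross-validate every parameter set, keep the best mean test score,
# then refit on all of x and y with the winning parameters.
def grid_search(model_class, param_list, x, y, n_splits: 3)
  results = param_list.map do |prms|
    scores = k_fold_indices(x.size, n_splits).map do |train, test|
      model = model_class.new(**prms).fit(x.values_at(*train), y.values_at(*train))
      model.score(x.values_at(*test), y.values_at(*test))
    end
    { params: prms, mean_test_score: scores.sum / scores.size }
  end
  best = results.max_by { |r| r[:mean_test_score] }
  best_estimator = model_class.new(**best[:params]).fit(x, y)
  [best, best_estimator, results]
end

x = [0.1, 0.2, 0.35, 0.4, 0.8, 1.0, 2.9, 3.0, 3.1, 3.3, 3.6, 4.0]
y = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
param_list = [{ k: 1 }, { k: 3 }, { k: 5 }]

best, best_estimator, results = grid_search(ToyKNN, param_list, x, y)
results.each { |r| puts "#{r[:params]} => #{r[:mean_test_score].round(3)}" }
puts "best: #{best[:params]}"
p best_estimator.predict([0.0, 3.5])
```

Refitting on the full data after selection, as #fit does, means the final model sees every sample rather than only one fold's training split.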
#marshal_dump ⇒ Hash
Returns the object's state as a Hash for Marshal.dump.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 157
    def marshal_dump
      { params: @params,
        cv_results: @cv_results,
        best_score: @best_score,
        best_params: @best_params,
        best_index: @best_index,
        best_estimator: @best_estimator }
    end
#marshal_load(obj) ⇒ nil
Restores the object's state from the Hash produced by #marshal_dump.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 168
    def marshal_load(obj)
      @params = obj[:params]
      @cv_results = obj[:cv_results]
      @best_score = obj[:best_score]
      @best_params = obj[:best_params]
      @best_index = obj[:best_index]
      @best_estimator = obj[:best_estimator]
      nil
    end
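#marshal_dump and #marshal_load are Ruby's standard hooks for Marshal: Marshal.dump calls marshal_dump and serializes the Hash it returns, and Marshal.load allocates a new instance without calling initialize, then passes that Hash to marshal_load. The sketch below shows the round trip with a hypothetical TinySearch class that follows the same pattern on a reduced set of attributes.

```ruby
# Hypothetical class mirroring the marshal_dump / marshal_load pattern above.
class TinySearch
  attr_reader :params, :best_params, :best_score

  def initialize(params)
    @params = params
    @best_params = nil
    @best_score = nil
  end

  def record_best(best_params, best_score)
    @best_params = best_params
    @best_score = best_score
    self
  end

  # Called by Marshal.dump; the returned Hash is what gets serialized.
  def marshal_dump
    { params: @params, best_params: @best_params, best_score: @best_score }
  end

  # Called by Marshal.load on a freshly allocated instance.
  def marshal_load(obj)
    @params = obj[:params]
    @best_params = obj[:best_params]
    @best_score = obj[:best_score]
    nil
  end
end

search = TinySearch.new({ greater_is_better: true }).record_best({ k: 3 }, 0.92)
restored = Marshal.load(Marshal.dump(search))
p restored.best_params
```

In practice the dumped bytes are often written to disk, for example with File.binwrite('search.dat', Marshal.dump(search)), and read back with Marshal.load(File.binread('search.dat')); the file name here is only illustrative.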
#predict(x) ⇒ Numo::NArray
Calls the predict method of the estimator learned with the best parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 122
    def predict(x)
      check_sample_array(x)
      @best_estimator.predict(x)
    end
#predict_log_proba(x) ⇒ Numo::DFloat
Calls the predict_log_proba method of the estimator learned with the best parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 131
    def predict_log_proba(x)
      check_sample_array(x)
      @best_estimator.predict_log_proba(x)
    end
#predict_proba(x) ⇒ Numo::DFloat
Calls the predict_proba method of the estimator learned with the best parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 140
    def predict_proba(x)
      check_sample_array(x)
      @best_estimator.predict_proba(x)
    end
#score(x, y) ⇒ Float
Calls the score method of the estimator learned with the best parameter set.

    # File 'lib/rumale/model_selection/grid_search_cv.rb', line 150
    def score(x, y)
      check_sample_array(x)
      @best_estimator.score(x, y)
    end
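#score returns whatever the best estimator's own score method computes. For classifiers this is typically mean accuracy and for regressors the coefficient of determination (R²); that mapping is an assumption about the wrapped estimator, not something GridSearchCV enforces. The plain-Ruby sketch below computes both measures so the returned Float is easier to interpret; accuracy and r2_score are hypothetical helpers, not library functions.

```ruby
# Fraction of predictions that match the true labels.
def accuracy(y_true, y_pred)
  y_true.zip(y_pred).count { |truth, pred| truth == pred }.fdiv(y_true.size)
end

# Coefficient of determination: 1 - (residual sum of squares / total sum of squares).
# 1.0 is a perfect fit; 0.0 matches always predicting the mean of y_true.
def r2_score(y_true, y_pred)
  mean = y_true.sum.fdiv(y_true.size)
  ss_res = y_true.zip(y_pred).sum { |truth, pred| (truth - pred)**2 }
  ss_tot = y_true.sum { |truth| (truth - mean)**2 }
  1.0 - ss_res.fdiv(ss_tot)
end

p accuracy([1, 0, 1, 1], [1, 0, 0, 1])
p r2_score([1.0, 2.0, 3.0], [1.1, 1.9, 3.2])
```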