Module: EasyML::Core::Models::XGBoostCore
- Included in:
- XGBoost, Models::XGBoost
- Defined in:
- lib/easy_ml/core/models/xgboost_core.rb
Constant Summary
# Supported XGBoost objective strings, grouped by task. Classification is
# further split into binary and multi-class objectives. The table (and every
# nested hash/array) is frozen so no caller can mutate this shared constant.
OBJECTIVES = {
  classification: {
    binary: %w[binary:logistic binary:hinge].freeze,
    multi_class: %w[multi:softmax multi:softprob].freeze
  }.freeze,
  regression: %w[reg:squarederror reg:logistic].freeze
}.freeze
Instance Attribute Summary
- #booster ⇒ Object
  Returns the value of attribute booster.
- #model ⇒ Object
  Returns the value of attribute model.
Class Method Summary
- .included(base) ⇒ Object
Instance Method Summary
- #_save_model_file(path) ⇒ Object
- #base_model ⇒ Object
- #customize_callbacks {|callbacks| ... } ⇒ Object
- #feature_importances ⇒ Object
- #load(path = nil) ⇒ Object
- #predict(xs) ⇒ Object
- #predict_proba(data) ⇒ Object
Instance Attribute Details
#booster ⇒ Object
Returns the value of attribute booster.
Source lines 41–43:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 41
# Reader for the @booster instance variable (the object #predict,
# #predict_proba and #_save_model_file call into).
#
# @return [Object] the current @booster value
def booster
  @booster
end
#model ⇒ Object
Returns the value of attribute model.
Source lines 41–43:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 41
# Reader for the @model instance variable (the object #feature_importances
# reads names and scores from).
#
# @return [Object] the current @model value
def model
  @model
end
Class Method Details
.included(base) ⇒ Object
Source lines 14–39:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 14
# Mixin hook: declares on the including class an `evaluator` attribute
# and two dependencies:
#
# * `callbacks` - an array dependency whose `:wandb` option is built as a
#   Wandb::XGBoostCallback with the logging attributes bound below.
# * `hyperparameters` - built as EasyML::Models::Hyperparameters::XGBoost
#   with the defaults bound below.
#
# @param base [Class] the including class; it must respond to `attribute`
#   and `dependency` (presumably the project's dependency DSL - confirm)
# @return [Object] whatever `base.class_eval` returns
def self.included(base)
  base.class_eval do
    attribute :evaluator

    dependency :callbacks, { array: true } do |dep|
      dep.option :wandb do |opt|
        opt.set_class Wandb::XGBoostCallback
        opt.bind_attribute :log_model, default: false
        opt.bind_attribute :log_feature_importance, default: true
        opt.bind_attribute :importance_type, default: "gain"
        opt.bind_attribute :define_metric, default: true
        opt.bind_attribute :project_name
      end
    end

    dependency :hyperparameters do |dep|
      dep.set_class EasyML::Models::Hyperparameters::XGBoost
      dep.bind_attribute :batch_size, default: 32
      # XGBoost documents learning_rate (eta) as ranging over [0, 1]; keep the
      # default inside that range (1.1 here was a typo for 0.1).
      dep.bind_attribute :learning_rate, default: 0.1
      dep.bind_attribute :max_depth, default: 6
      dep.bind_attribute :n_estimators, default: 100
      dep.bind_attribute :booster, default: "gbtree"
      dep.bind_attribute :objective, default: "reg:squarederror"
    end
  end
end
Instance Method Details
#_save_model_file(path) ⇒ Object
Source lines 80–83:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 80
# Writes the current booster to +path+ using the booster's own +save_model+.
# (The previous debug `puts` is gone; a library save should not print to
# stdout.)
#
# @param path [String] destination file path
# @return [Object] whatever +@booster.save_model+ returns
def _save_model_file(path)
  @booster.save_model(path)
end
#base_model ⇒ Object
Source lines 89–91:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 89
# The top-level XGBoost library namespace that this core wraps.
#
# @return [Module] the ::XGBoost constant
def base_model
  ::XGBoost
end
#customize_callbacks {|callbacks| ... } ⇒ Object
Source lines 93–95:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 93
# Hands the current callbacks collection to the caller's block so it can be
# adjusted in place.
#
# @yieldparam callbacks [Object] the value returned by #callbacks
# @return [Object] the block's return value
def customize_callbacks
  current_callbacks = callbacks
  yield(current_callbacks)
end
#feature_importances ⇒ Object
Source lines 85–87:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 85
# Builds a feature-name => importance lookup. Names come from the wrapped
# model's booster, scores from the model itself, and the two are paired
# by position.
#
# @return [Hash] feature name => importance score
def feature_importances
  names = @model.booster.feature_names
  scores = @model.feature_importances
  Hash[names.zip(scores)]
end
#load(path = nil) ⇒ Object
Source lines 69–78:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 69
# Loads a saved booster from disk and installs it through
# +initialize_model+.
#
# @param path [String, CarrierWave::Uploader::Base, nil] where the model
#   file lives; falls back to +file+ when nil. A CarrierWave uploader is
#   unwrapped via its stored file (+uploader.file.file+).
# @return [Object] whatever +initialize_model+ returns
# @raise [RuntimeError] "No existing model at ..." when no path can be
#   resolved or nothing exists at the resolved path
def load(path = nil)
  path ||= file
  # Unwrap an uploader to its stored file; the stored file may itself be nil.
  path = path.file&.file if path.is_a?(CarrierWave::Uploader::Base)
  # Check for nil explicitly: File.exist?(nil) raises TypeError, which would
  # hide this descriptive error from the caller.
  raise "No existing model at #{path}" if path.nil? || !File.exist?(path)

  initialize_model do
    booster_class.new(params: hyperparameters.to_h, model_file: path)
  end
end
#predict(xs) ⇒ Object
Source lines 43–55:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 43
# Runs the booster on +xs+ after passing it through +preprocess+.
# Classification tasks get their raw output mapped through
# +to_classification+; every other task returns the raw output unchanged.
#
# @param xs [Object] input accepted by +preprocess+
# @return [Object] predictions (post-processed for classification)
# @raise [RuntimeError] if no booster is present or +xs+ is nil
def predict(xs)
  raise "No trained model! Train a model before calling predict" unless @booster.present?
  raise "Cannot predict on nil — XGBoost" if xs.nil?

  raw_output = @booster.predict(preprocess(xs))
  return to_classification(raw_output) if task.to_sym == :classification

  raw_output
end
#predict_proba(data) ⇒ Object
Source lines 57–67:
# File 'lib/easy_ml/core/models/xgboost_core.rb', line 57
# Returns one probability row per input row.
#
# +data+ is wrapped in a DMatrix and passed to the booster. If the booster
# already returns an array per row (multi-class output), those rows come
# back unchanged. Otherwise each scalar +v+ is treated as the positive-class
# probability and expanded to +[1 - v, v]+.
#
# @param data [Object] anything DMatrix.new accepts
# @return [Array<Array<Numeric>>] per-row class probabilities
# @raise [RuntimeError] if no booster is present or +data+ is nil
def predict_proba(data)
  # Mirror #predict's guards so a missing booster or nil input fails with a
  # clear message instead of a NoMethodError deep inside the call.
  raise "No trained model! Train a model before calling predict_proba" unless @booster
  raise "Cannot predict on nil — XGBoost" if data.nil?

  y_pred = @booster.predict(DMatrix.new(data))
  return y_pred if y_pred.first.is_a?(Array) # multi-class: already one row per sample

  y_pred.map { |v| [1 - v, v] }
end