Class: PytorchLightningModel
- Inherits: TorchModel
  - Object
  - VectorModel
  - PythonModel
  - TorchModel
  - PytorchLightningModel
- Defined in: lib/rbbt/vector/model/pytorch_lightning.rb
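The inheritance chain above decides where PytorchLightningModel's inherited methods come from. Ruby searches the class itself, then TorchModel, PythonModel and VectorModel, in that order. The toy sketch below shows that lookup order. Its empty stand-in classes share only their names with the real ones, and the `describe` method is hypothetical.

```ruby
# Stand-in classes reproducing only the inheritance chain; no rbbt code.
class VectorModel
  def describe
    "VectorModel#describe"
  end
end

class PythonModel < VectorModel; end

class TorchModel < PythonModel
  # Overrides the VectorModel version for every subclass below it
  def describe
    "TorchModel#describe"
  end
end

class PytorchLightningModel < TorchModel; end

# Method lookup order, stopping before Object and its mixins
chain = PytorchLightningModel.ancestors.take_while { |k| k != Object }
```

Because TorchModel sits closer to PytorchLightningModel in the chain, its version of a method wins over VectorModel's.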
Instance Attribute Summary
- #loader ⇒ Object
  Returns the value of attribute loader.
- #trainer ⇒ Object
  Returns the value of attribute trainer.
- #val_loader ⇒ Object
  Returns the value of attribute val_loader.
Attributes inherited from TorchModel
Attributes inherited from PythonModel
Attributes inherited from VectorModel
#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model, #model_options, #model_path, #names, #post_process, #train_model
Instance Method Summary
- #initialize ⇒ PytorchLightningModel (constructor)
  Returns a new instance of PytorchLightningModel.
Methods inherited from TorchModel
device, dtype, feature_dataset, feature_tsv, freeze, #freeze_layer, freeze_layer, get_layer, #get_layer, #get_weights, get_weights, init_python, load_architecture, load_state, model_architecture, optimizer, #reset_model, save_architecture, save_state, tensor, text_dataset
Methods inherited from VectorModel
R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train
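The constructor does not train or predict anything itself. Instead, it hands one block to the train_model hook and another to the eval_model hook, both inherited from VectorModel. The minimal plain-Ruby sketch below shows that hook pattern. It assumes each hook stores its block for later use by train, eval and eval_list. HookedModel and MeanModel are hypothetical classes, and this is not rbbt's actual VectorModel code.

```ruby
# Hypothetical base class: only the "store a block, call it later" idea.
class HookedModel
  # Calling with a block stores it; calling without returns the stored block
  def train_model(&block)
    @train_model = block if block
    @train_model
  end

  def eval_model(&block)
    @eval_model = block if block
    @eval_model
  end

  def train(features, labels)
    @train_model.call(features, labels)
  end

  # Single input: list flag false
  def eval(features)
    @eval_model.call(features, false)
  end

  # Many inputs: list flag true
  def eval_list(list)
    @eval_model.call(list, true)
  end
end

# Registers its behaviour in the constructor, like PytorchLightningModel
class MeanModel < HookedModel
  def initialize
    super
    train_model do |_features, labels|
      @mean = labels.sum.to_f / labels.size
    end
    eval_model do |features, list = false|
      list ? features.map { @mean } : @mean
    end
  end
end

model = MeanModel.new
model.train([[1], [2]], [1.0, 3.0])
```

The blocks are closures created inside `initialize`, so they keep `self` bound to the model instance and can read and write its instance variables when called later.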
Constructor Details
#initialize ⇒ PytorchLightningModel
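Returns a new instance of PytorchLightningModel. The constructor forwards all of its arguments to TorchModel's constructor and then registers two blocks. The train_model block fits the model with trainer.fit. When no loader has been assigned, it first builds a DataLoader from the features and labels. It also saves the model's architecture and state when a directory is configured. The eval_model block runs trainer.predict over a loader when given a list of inputs, and calls the model directly on a single input.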
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 5

def initialize(...)
  super(...)

  # Training: fit with the Lightning trainer, preferring self.loader when set
  train_model do |features,labels|
    model = init
    train_loader = self.loader
    val_loader = self.val_loader
    if train_loader.nil?
      # Batch size: training_args first, then :batch_size, then 1
      batch_size ||= [:training_args][:batch_size]
      batch_size ||= [:batch_size]
      batch_size ||= 1

      # Shuffle unless explicitly disabled
      shuffle = [:training_args][:shuffle]
      shuffle = true if shuffle.nil?

      num_workers = Rbbt.config(:num_workers, :dataloader, :default => 2)
      # Build a DataLoader over (tensor, label) pairs
      train_loader = RbbtPython.run :torch do
        dataset = features.zip(labels).collect{|f,l| [torch.tensor(f), l] }
        torch.utils.data.DataLoader.call(dataset, batch_size: batch_size, shuffle: shuffle, num_workers: num_workers.to_i)
      end
    end
    trainer.fit(model, train_loader, val_loader)
    # Persist architecture and weights only when the model has a directory
    TorchModel.save_architecture(model, model_path) if @directory
    TorchModel.save_state(model, model_path) if @directory
  end

  # Evaluation: list=true predicts in batches; otherwise one forward pass
  eval_model do |features,list=false|
    model = init
    eval_loader = self.loader
    if list
      if eval_loader.nil?
        batch_size ||= [:batch_size]
        batch_size ||= [:training_args][:batch_size]
        batch_size ||= 1
        num_workers = Rbbt.config(:num_workers, :dataloader, :default => 2)
        eval_loader = RbbtPython.run :torch do
          dataset = torch.tensor(features)
          torch.utils.data.DataLoader.call(dataset, batch_size: batch_size, num_workers: num_workers.to_i)
        end
      end
      # Flatten per-batch predictions into a single Ruby array
      trainer.predict(model, eval_loader).inject([]){|acc,res| acc.concat RbbtPython.numpy2ruby(res[1])}
    else
      model.call(torch.tensor(features))
    end
  end
end
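When no loader has been assigned, the train_model block builds its own DataLoader. The batch size comes from the training arguments, then the top-level :batch_size option, then 1. Shuffling defaults to true only when it is unset. Features are zipped with their labels. The eval branch then flattens per-batch predictions with inject/concat. The sketch below reproduces that control flow in plain Ruby under several assumptions. The options are taken to be an ordinary Hash with a nested :training_args Hash. Arrays stand in for tensors, and each_slice stands in for torch's DataLoader. All method names are hypothetical.

```ruby
# Hypothetical helpers mirroring the constructor's default-loader logic.
# Plain arrays stand in for torch tensors; each_slice stands in for DataLoader.

# Batch size precedence used in train_model:
# training_args[:batch_size], then the top-level :batch_size, then 1.
def resolve_batch_size(options)
  batch_size = options.dig(:training_args, :batch_size)
  batch_size ||= options[:batch_size]
  batch_size ||= 1
  batch_size
end

# Shuffle defaults to true only when unset, so an explicit false is kept.
def resolve_shuffle(options)
  shuffle = options.dig(:training_args, :shuffle)
  shuffle = true if shuffle.nil?
  shuffle
end

# Pair each feature vector with its label, optionally shuffle, then batch.
def make_batches(features, labels, batch_size:, shuffle:, rng: Random.new(0))
  dataset = features.zip(labels)
  dataset = dataset.shuffle(random: rng) if shuffle
  dataset.each_slice(batch_size).to_a
end

# Flatten per-batch predictions into one array, as the eval branch does
# with inject/concat; the block plays the role of the per-batch predictor.
def predict_all(batches)
  batches.inject([]) { |acc, batch| acc.concat(yield(batch)) }
end

options = { training_args: { batch_size: 2 } }
batches = make_batches([[1], [2], [3]], [:a, :b, :c],
                       batch_size: resolve_batch_size(options),
                       shuffle: false)
predictions = predict_all(batches) { |batch| batch.map { |f, _l| f.first * 10 } }
```

Writing `shuffle = true if shuffle.nil?` rather than `shuffle ||= true` lets an explicit `false` through. For example, `resolve_shuffle({ training_args: { shuffle: false } })` returns `false`. Note that the eval branch checks the top-level :batch_size before the training arguments, the reverse of the training order modelled here.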
Instance Attribute Details
#loader ⇒ Object
Returns the value of attribute loader.
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

def loader
  @loader
end
#trainer ⇒ Object
Returns the value of attribute trainer.
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

def trainer
  @trainer
end
#val_loader ⇒ Object
Returns the value of attribute val_loader.
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

def val_loader
  @val_loader
end
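Each reader above simply returns an instance variable, which is what Ruby's `attr_reader` and `attr_accessor` generate. All three are reported at line 4, consistent with a single declaration. The constructor consults `self.loader` and `self.val_loader` before building its own DataLoader. It also never assigns `trainer`, so these values are presumably supplied from outside before training. The sketch below assumes that writers exist, which this page does not show. It uses a hypothetical LoaderHolder class to demonstrate the "preset loader wins" rule.

```ruby
# Hypothetical class; assumes the real attributes are read/write accessors,
# which the documentation above (readers only) does not confirm.
class LoaderHolder
  # Generates loader/loader=, val_loader/val_loader=, trainer/trainer=
  attr_accessor :loader, :val_loader, :trainer

  # Same rule as train_model: a preset loader wins; otherwise build a
  # default (plain feature/label pairs stand in for a DataLoader here).
  def train_loader_for(features, labels)
    train_loader = loader
    train_loader = features.zip(labels) if train_loader.nil?
    train_loader
  end
end

holder = LoaderHolder.new
default_loader = holder.train_loader_for([1, 2], [:a, :b])
holder.loader = :custom_loader
preset_loader = holder.train_loader_for([1, 2], [:a, :b])
```

In the real class, the assigned objects would be a Lightning trainer and torch DataLoaders created on the Python side. The symbol used here is only a placeholder.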