Class: PytorchLightningModel

Inherits:
TorchModel
Defined in:
lib/rbbt/vector/model/pytorch_lightning.rb

Instance Attribute Summary

Attributes inherited from TorchModel

#criterion, #optimizer

Attributes inherited from PythonModel

#python_class, #python_module

Attributes inherited from VectorModel

#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model, #model_options, #model_path, #names, #post_process, #train_model

Instance Method Summary

Methods inherited from TorchModel

device, dtype, feature_dataset, feature_tsv, freeze, #freeze_layer, freeze_layer, get_layer, #get_layer, #get_weights, get_weights, init_python, load_architecture, load_state, model_architecture, optimizer, #reset_model, save_architecture, save_state, tensor, text_dataset

Methods inherited from VectorModel

R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train

Constructor Details

#initialize ⇒ PytorchLightningModel

Returns a new instance of PytorchLightningModel.



Source lines 5–51:
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 5

# Builds a PyTorch Lightning-backed vector model.
#
# All arguments are forwarded unchanged to the superclass constructor, which
# is expected to set up model_options, model_path, etc. It then registers:
#
# * a train_model block: obtains the model via +init+ and calls
#   +trainer.fit(model, train_loader, val_loader)+. +train_loader+ is
#   +loader+ when that is set; otherwise a DataLoader is built from
#   +features+ zipped with +labels+ (each feature wrapped in a tensor).
#   When +@directory+ is set, the model architecture and state are saved
#   to +model_path+ through TorchModel.
# * an eval_model block: in list mode, runs +trainer.predict+ over +loader+
#   (or over a DataLoader built from a tensor of +features+). It returns the
#   concatenation of element [1] of each prediction result, converted with
#   RbbtPython.numpy2ruby. Otherwise it calls the model directly on a tensor
#   of +features+.
#
# model_options[:training_args] is optional. When it is absent it is treated
# as an empty hash instead of raising NoMethodError on nil.
def initialize(...)
  super(...)

  train_model do |features,labels|
    model = init
    train_loader = self.loader
    val_loader = self.val_loader
    if train_loader.nil?
      training_args = model_options[:training_args] || {}

      # Training prefers the training_args batch size over the top-level one.
      batch_size = training_args[:batch_size] || model_options[:batch_size] || 1

      # Shuffle by default, but honour an explicit false.
      shuffle = training_args[:shuffle]
      shuffle = true if shuffle.nil?

      num_workers = Rbbt.config(:num_workers, :dataloader, :default => 2)
      train_loader = RbbtPython.run :torch do
        dataset = features.zip(labels).collect{|f,l| [torch.tensor(f), l] }
        torch.utils.data.DataLoader.call(dataset, batch_size: batch_size, shuffle: shuffle, num_workers: num_workers.to_i)
      end
    end
    trainer.fit(model, train_loader, val_loader)
    if @directory
      TorchModel.save_architecture(model, model_path)
      TorchModel.save_state(model, model_path)
    end
  end

  eval_model do |features,list=false|
    model = init
    eval_loader = self.loader
    if list
      if eval_loader.nil?
        training_args = model_options[:training_args] || {}

        # Evaluation prefers the top-level batch size over training_args
        # (the opposite of training). NOTE(review): confirm this is intended.
        batch_size = model_options[:batch_size] || training_args[:batch_size] || 1

        num_workers = Rbbt.config(:num_workers, :dataloader, :default => 2)
        eval_loader = RbbtPython.run :torch do
          dataset = torch.tensor(features)
          torch.utils.data.DataLoader.call(dataset, batch_size: batch_size, num_workers: num_workers.to_i)
        end
      end
      # Keep element [1] of each prediction batch result and flatten them
      # into a single Ruby array.
      trainer.predict(model, eval_loader).inject([]){|acc,res| acc.concat RbbtPython.numpy2ruby(res[1])}
    else
      model.call(torch.tensor(features))
    end
  end
end

Instance Attribute Details

#loader ⇒ Object

Returns the value of attribute loader.



Source lines 4–6:
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

# Reader for the optional DataLoader held in @loader (nil when unset).
# When non-nil, the training block registered in #initialize uses it instead
# of building a loader from features/labels, and so does the evaluation
# block in list mode.
def loader; @loader; end

#trainer ⇒ Object

Returns the value of attribute trainer.



Source lines 4–6:
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

# Reader for the trainer object held in @trainer (nil when unset).
# The blocks registered in #initialize call its +fit+ (training) and
# +predict+ (list-mode evaluation) methods.
# NOTE(review): presumably a pytorch_lightning Trainer — confirm.
def trainer; @trainer; end

#val_loader ⇒ Object

Returns the value of attribute val_loader.



Source lines 4–6:
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

# Reader for the optional validation DataLoader held in @val_loader.
# The training block registered in #initialize passes it, possibly nil,
# as the third argument to trainer.fit.
def val_loader; @val_loader; end