Class: Transformers::Mpnet::MPNetPreTrainedModel

Inherits:
PreTrainedModel
  • Object
show all
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary collapse

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #initialize, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from Transformers::ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

This class inherits a constructor from Transformers::PreTrainedModel

Instance Method Details

#_init_weights(module_) ⇒ Object



22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 22

# Initializes the weights of a single submodule, dispatching on its type.
# Linear/Embedding weights are drawn from N(0, initializer_range); biases and
# embedding padding rows are zeroed; LayerNorm is reset to identity (gain 1, bias 0).
def _init_weights(module_)
  case module_
  when Torch::NN::Linear
    # Slightly different from the TF version which uses truncated_normal for initialization
    # cf https://github.com/pytorch/pytorch/pull/5617
    module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
    module_.bias.data.zero! unless module_.bias.nil?
  when Torch::NN::Embedding
    module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
    unless module_.padding_idx.nil?
      # Keep the padding embedding at zero so it contributes nothing.
      module_.weight.data.fetch(module_.padding_idx).zero!
    end
  when Torch::NN::LayerNorm
    module_.bias.data.zero!
    module_.weight.data.fill!(1.0)
  end
end