Class: Torch::NN::Embedding
Instance Method Summary
- #forward(input) ⇒ Object
- #initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil) ⇒ Embedding (constructor): A new instance of Embedding.
- #inspect ⇒ Object
- #reset_parameters ⇒ Object
Methods inherited from Module
#_apply, #add_module, #apply, #buffers, #call, #children, #cpu, #cuda, #double, #eval, #float, #half, #load_state_dict, #method_missing, #modules, #named_buffers, #named_children, #named_modules, #named_parameters, #parameters, #register_buffer, #register_parameter, #requires_grad!, #respond_to?, #share_memory, #state_dict, #to, #train, #type, #zero_grad
Methods included from Utils
#_ntuple, #_pair, #_quadrupal, #_single, #_triple
Constructor Details
#initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil) ⇒ Embedding
Returns a new instance of Embedding.
# File 'lib/torch/nn/embedding.rb', line 5

def initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil)
  super()
  @num_embeddings = num_embeddings
  @embedding_dim = embedding_dim
  # Validate padding_idx and convert a negative index to its non-negative equivalent
  if padding_idx
    if padding_idx > 0
      raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx < @num_embeddings
    elsif padding_idx < 0
      raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx >= -@num_embeddings
      padding_idx = @num_embeddings + padding_idx
    end
  end
  @padding_idx = padding_idx
  @max_norm = max_norm
  @norm_type = norm_type
  @scale_grad_by_freq = scale_grad_by_freq
  if _weight.nil?
    # Allocate a num_embeddings x embedding_dim weight, then initialize it
    @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
    reset_parameters
  else
    # Use the caller-supplied weight, which must have the expected shape
    raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
    @weight = Parameter.new(_weight)
  end
  @sparse = sparse
end
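The padding_idx handling in the constructor can be modeled in plain Ruby, without the torch gem, to make the accepted range explicit. normalize_padding_idx below is a hypothetical helper, not part of torch.rb: it applies the same checks as the constructor, so a positive index must be below num_embeddings, a negative index must be at least -num_embeddings and is then counted from the end (like a Ruby array index), and nil or 0 passes through unchanged.

```ruby
# Hypothetical helper (not part of torch.rb) mirroring the padding_idx
# validation performed in Torch::NN::Embedding#initialize.
def normalize_padding_idx(padding_idx, num_embeddings)
  return nil if padding_idx.nil?

  if padding_idx > 0
    raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx < num_embeddings
  elsif padding_idx < 0
    raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx >= -num_embeddings
    # A negative index counts back from the last row
    padding_idx = num_embeddings + padding_idx
  end
  padding_idx
end

normalize_padding_idx(nil, 10) # => nil
normalize_padding_idx(3, 10)   # => 3
normalize_padding_idx(-1, 10)  # => 9
normalize_padding_idx(-10, 10) # => 0
```

With num_embeddings = 10 the valid rows are 0 through 9, so normalize_padding_idx(10, 10) and normalize_padding_idx(-11, 10) both raise ArgumentError.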
Dynamic Method Handling
This class handles dynamic methods through the method_missing method in the class Torch::NN::Module.
Instance Method Details
#forward(input) ⇒ Object
# File 'lib/torch/nn/embedding.rb', line 43

def forward(input)
  # Delegate the lookup to the functional API, passing the stored options
  F.embedding(input, @weight, padding_idx: @padding_idx, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, sparse: @sparse)
end
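forward passes input and @weight, together with the options stored by the constructor, to the functional embedding lookup. Conceptually, that lookup returns the row of the weight matrix selected by each integer index. The plain-Ruby sketch below is a simplified model under stated assumptions: embedding_lookup is a hypothetical name, the weight is an Array of rows, and input is a flat Array of indices. Its max_norm branch rescales any selected row whose norm_type-norm exceeds max_norm. It does not model padding_idx (which in PyTorch-style embeddings affects gradients rather than the forward lookup), scale_grad_by_freq, or sparse. PyTorch also renormalizes the stored weight rows in place, whereas this sketch only rescales the returned copies.

```ruby
# Hypothetical pure-Ruby model of an embedding lookup; not torch.rb's API.
# weight: Array of rows (num_embeddings x embedding_dim)
# input:  flat Array of Integer row indices
def embedding_lookup(input, weight, max_norm: nil, norm_type: 2.0)
  input.map do |idx|
    raise IndexError, "index #{idx} out of range" unless idx >= 0 && idx < weight.size

    row = weight[idx].dup
    if max_norm
      # p-norm of the selected row, with p = norm_type
      norm = row.sum { |x| x.abs**norm_type }**(1.0 / norm_type)
      # Shrink rows that exceed max_norm; the epsilon guards against division by zero
      row = row.map { |x| x * max_norm / (norm + 1e-7) } if norm > max_norm
    end
    row
  end
end

weight = [
  [0.0, 0.0], # row 0
  [1.0, 2.0], # row 1
  [3.0, 4.0]  # row 2 (L2 norm 5.0)
]

embedding_lookup([2, 1, 2], weight)
# => [[3.0, 4.0], [1.0, 2.0], [3.0, 4.0]]
embedding_lookup([2], weight, max_norm: 1.0)
# => approximately [[0.6, 0.8]]
```

Repeated indices simply return the same row again, which is why an embedding layer behaves like a trainable lookup table indexed by token or category IDs.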
#inspect ⇒ Object
# File 'lib/torch/nn/embedding.rb', line 47

def inspect
  "Embedding(#{@num_embeddings}, #{@embedding_dim})"
end