Class: TorchRec::Modules::DeepFM::FactorizationMachine
- Inherits: Torch::NN::Module
  - Object
    - Torch::NN::Module
      - TorchRec::Modules::DeepFM::FactorizationMachine
- Defined in: lib/torchrec/modules/deepfm/factorization_machine.rb
Instance Method Summary
- #forward(embeddings) ⇒ Object
- #initialize ⇒ FactorizationMachine (constructor)
  A new instance of FactorizationMachine.
Constructor Details
#initialize ⇒ FactorizationMachine
Returns a new instance of FactorizationMachine.
```ruby
# File 'lib/torchrec/modules/deepfm/factorization_machine.rb', line 5

def initialize
  super()
end
```
Instance Method Details
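#forward(embeddings) ⇒ Object
Computes the second-order factorization machine cross term for the given embeddings and returns it as a tensor of shape [B, 1].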
```ruby
# File 'lib/torchrec/modules/deepfm/factorization_machine.rb', line 9

def forward(embeddings)
  # flatten_input is a helper defined elsewhere; it is expected to return a [B, N] tensor
  fm_input = flatten_input(embeddings)
  sum_of_input = Torch.sum(fm_input, dim: 1, keepdim: true)             # [B, 1]
  sum_of_square = Torch.sum(fm_input * fm_input, dim: 1, keepdim: true) # [B, 1]
  square_of_sum = sum_of_input * sum_of_input
  cross_term = square_of_sum - sum_of_square
  cross_term = Torch.sum(cross_term, dim: 1, keepdim: true) * 0.5 # [B, 1]
  cross_term
end
```
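The arithmetic in #forward relies on the identity 0.5 * ((Σ x_i)² − Σ x_i²) = Σ_{i<j} x_i * x_j, applied to all N scalar entries of each flattened row. The following pure-Ruby sketch checks this with plain nested arrays in place of Torch tensors. It assumes that flatten_input flattens each [B, ...] embedding to [B, N_i] and concatenates the results along dim 1. The helpers `flatten_rows`, `fm_cross_term`, and `pairwise_sum` are hypothetical illustrations, not part of the gem's API. At the final `Torch.sum` call, cross_term already has shape [B, 1], so summing over dim 1 leaves the values unchanged; the sketch therefore omits that step.

```ruby
# Pure-Ruby sketch of FactorizationMachine#forward using nested arrays
# instead of Torch tensors. Helper names are hypothetical, not gem API.

# Assumed behavior of flatten_input: flatten each [B, ...] embedding to
# [B, N_i] and concatenate along dim 1, giving one row of N floats per example.
def flatten_rows(embeddings)
  batch_size = embeddings.first.length
  (0...batch_size).map do |b|
    embeddings.flat_map { |emb| Array(emb[b]).flatten }
  end
end

# Mirrors the tensor arithmetic in #forward, one row at a time.
# Returns a flat [B] array instead of a [B, 1] tensor.
def fm_cross_term(rows)
  rows.map do |row|
    sum_of_input = row.sum
    sum_of_square = row.sum { |x| x * x }
    square_of_sum = sum_of_input * sum_of_input
    (square_of_sum - sum_of_square) * 0.5
  end
end

# Brute-force reference: sum of x_i * x_j over every pair i < j.
def pairwise_sum(rows)
  rows.map { |row| row.combination(2).sum { |a, b| a * b } }
end

# Two embeddings for a batch of B = 2, with shapes [2, 2] and [2, 1].
emb_a = [[1.0, 2.0], [0.5, -1.0]]
emb_b = [[3.0], [4.0]]

rows = flatten_rows([emb_a, emb_b])  # => [[1.0, 2.0, 3.0], [0.5, -1.0, 4.0]]
p fm_cross_term(rows)                # prints [11.0, -2.5]
p pairwise_sum(rows)                 # prints [11.0, -2.5]
```

The sum-of-squares form needs O(N) work per row, whereas the explicit pairwise loop visits N(N-1)/2 pairs. This is why the module computes the cross term through sums rather than through pairwise products.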