Class: TorchRec::Modules::DeepFM::FactorizationMachine

Inherits:
Torch::NN::Module
  (ancestry: Object > Torch::NN::Module > TorchRec::Modules::DeepFM::FactorizationMachine)
Defined in:
lib/torchrec/modules/deepfm/factorization_machine.rb

Instance Method Summary

Constructor Details

#initialize ⇒ FactorizationMachine

Returns a new instance of FactorizationMachine.



(source lines 5–7)
# File 'lib/torchrec/modules/deepfm/factorization_machine.rb', line 5

# Creates a FactorizationMachine.
#
# This constructor sets no state of its own; it only delegates to the
# parent class constructor (Torch::NN::Module, per the "Inherits:" header
# on this page). Bare `super` forwards this method's (empty) argument list,
# which is the same as calling `super()` explicitly.
def initialize
  super
end

Instance Method Details

#forward(embeddings) ⇒ Object



(source lines 9–17)
# File 'lib/torchrec/modules/deepfm/factorization_machine.rb', line 9

# Computes the second-order factorization-machine interaction term.
#
# The input is first flattened by +flatten_input+ (a helper not visible
# here). The pairwise term then uses the identity
#   sum_{i<j} x_i * x_j = 0.5 * ((sum_i x_i)^2 - sum_i x_i^2),
# where every sum runs over dim 1 of the flattened tensor.
#
# @param embeddings [Object] whatever +flatten_input+ accepts; presumably
#   a list of embedding tensors (TODO: confirm against that helper).
# @return [Torch::Tensor] the interaction term, reduced along dim 1 with
#   keepdim, so dim 1 has size 1 (the original annotates it as [B, 1]).
def forward(embeddings)
  flat = flatten_input(embeddings)

  # (sum_i x_i)^2
  linear_total = Torch.sum(flat, dim: 1, keepdim: true)
  squared_total = linear_total * linear_total

  # sum_i x_i^2
  total_of_squares = Torch.sum(flat * flat, dim: 1, keepdim: true)

  pairwise = squared_total - total_of_squares
  # dim 1 already has size 1 here, so this reduction leaves the values
  # unchanged; it is kept so the computation matches the original exactly.
  Torch.sum(pairwise, dim: 1, keepdim: true) * 0.5
end