Class: Transformers::DebertaV2::XSoftmax

Inherits: Object
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Overview

TODO: implement this as a Torch::Autograd::Function (it currently inherits directly from Object).

Class Method Summary

Class Method Details

.apply(input, mask, dim) ⇒ Object



(source lines 43–53)
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 43

# Masked softmax over dimension +dim+.
#
# The incoming mask is cast to boolean and inverted, so positions where the
# mask is false/zero become "excluded". Excluded positions are filled with a
# very large negative number before the softmax, so they receive ~0
# probability. They are then forced to exactly 0 afterwards. Zeroing after the
# softmax also clears any NaN/garbage in rows that were fully excluded.
#
# NOTE(review): presumably +mask+ is broadcastable to +input+'s shape — confirm
# against callers.
#
# @param input [Torch::Tensor] scores to normalise
# @param mask [Torch::Tensor] keep-mask; truthy entries are kept
# @param dim [Integer] dimension along which softmax is taken
# @return [Torch::Tensor] softmax probabilities with excluded entries set to 0
def self.apply(input, mask, dim)
  # Remembered on the class (mirrors the ctx state of the Python autograd
  # Function this ports).
  @dim = dim
  excluded = mask.to(Torch.bool).bitwise_not

  # Approximately the float32 minimum; TODO use Torch.finfo.
  very_negative = Torch.tensor(-3.40282e+38)
  masked_scores = input.masked_fill(excluded, very_negative)

  # ctx.save_for_backward(output)
  Torch.softmax(masked_scores, @dim).tap do |probs|
    probs.masked_fill!(excluded, 0)
  end
end