Class: Transformers::DebertaV2::XSoftmax
- Inherits: Object
  - Object
    - Transformers::DebertaV2::XSoftmax
- Defined in: lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
Overview
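TODO: Torch::Autograd::Function. In the Python Transformers library, XSoftmax subclasses torch.autograd.Function. This Ruby port is a plain class whose `.apply` performs only the forward computation: a softmax along `dim` in which masked-out positions receive exactly zero probability.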
Class Method Summary
Class Method Details
.apply(input, mask, dim) ⇒ Object
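```ruby
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 43

def self.apply(input, mask, dim)
  @dim = dim
  # rmask is true where mask is 0, i.e. at positions that must be ignored
  rmask = mask.to(Torch.bool).bitwise_not

  # Fill ignored positions with (approximately) the lowest float32 value
  # so softmax assigns them ~0 weight. TODO use Torch.finfo
  output = input.masked_fill(rmask, Torch.tensor(-3.40282e+38))
  output = Torch.softmax(output, @dim)
  # Force ignored positions to exactly 0, even in fully masked rows
  output.masked_fill!(rmask, 0)
  # ctx.save_for_backward(output)
  output
end
```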
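To make the masking semantics of `.apply` concrete without depending on torch-rb, here is a minimal, dependency-free sketch on plain Ruby arrays. It assumes a 2-D input and a softmax over the last dimension (the `dim = -1` case), and it uses the hypothetical helper names `masked_softmax_row` and `masked_softmax`, which are not part of the library. `FLOAT32_LOWEST` is the exact lowest finite float32 value, of which the `-3.40282e+38` literal in the source is a rounded version.

```ruby
# Exact lowest finite float32 value: -(2 - 2**-23) * 2**127.
# (Illustrative stand-in for what the source's "TODO use Torch.finfo" refers to.)
FLOAT32_LOWEST = -(((2 - Rational(1, 2**23)) * 2**127).to_f)

# Hypothetical helper: masked softmax over one row.
# mask_row holds 1 for positions to keep and 0 for positions to ignore.
def masked_softmax_row(scores, mask_row)
  # Step 1: replace ignored positions with the float32 minimum.
  filled = scores.zip(mask_row).map { |s, m| m.zero? ? FLOAT32_LOWEST : s.to_f }

  # Step 2: numerically stable softmax (subtract the row max before exp).
  max = filled.max
  exps = filled.map { |v| Math.exp(v - max) }
  sum = exps.sum
  probs = exps.map { |e| e / sum }

  # Step 3: force ignored positions to exactly 0, as masked_fill! does.
  # In a fully masked row, step 2 yields a uniform distribution; this zeroes it.
  probs.zip(mask_row).map { |p, m| m.zero? ? 0.0 : p }
end

# Hypothetical helper: apply the row-wise masked softmax to a 2-D array.
def masked_softmax(rows, mask)
  rows.each_with_index.map { |row, r| masked_softmax_row(row, mask[r]) }
end
```

For example, `masked_softmax([[1.0, 2.0, 3.0]], [[1, 1, 0]])` spreads all probability over the first two positions and returns exactly `0.0` for the third, even though its raw score is the largest. The final zeroing step mirrors `output.masked_fill!(rmask, 0)` and matters mainly for rows in which every position is masked.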