Class: Transformers::DebertaV2::StableDropout
- Inherits: Torch::NN::Module
- Object
- Torch::NN::Module
- Transformers::DebertaV2::StableDropout
- Defined in:
- lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
Instance Method Summary
- #clear_context ⇒ Object
- #forward(x) ⇒ Object
- #get_context ⇒ Object
- #init_context(reuse_mask: true, scale: 1) ⇒ Object
- #initialize(drop_prob) ⇒ StableDropout (constructor) — returns a new instance of StableDropout.
Constructor Details
#initialize(drop_prob) ⇒ StableDropout
Returns a new instance of StableDropout.
103 104 105 106 107 108 |
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 103
# Creates the dropout module.
#
# The context stack starts disabled (nil) and the context counter at zero.
# Until #init_context is called, #get_context hands back the bare drop
# probability rather than a context object.
#
# @param drop_prob drop probability; #forward only applies dropout when it is > 0
def initialize(drop_prob)
  super()
  @drop_prob = drop_prob
  @count, @context_stack = 0, nil
end
Instance Method Details
#clear_context ⇒ Object
117 118 119 120 |
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 117
# Disables context reuse.
#
# Rewinds the context counter and discards the context stack, so subsequent
# #get_context calls return the plain drop probability again.
#
# @return [nil]
def clear_context
  @count, @context_stack = 0, nil
  nil
end
#forward(x) ⇒ Object
110 111 112 113 114 115 |
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 110
# Applies dropout to +x+, but only in training mode with a positive drop
# probability; otherwise +x+ is passed through untouched.
#
# @param x the input handed to XDropout.apply (presumably a Torch::Tensor — TODO confirm)
# @return the result of XDropout.apply when dropout is active, else +x+ itself
def forward(x)
  dropout_active = @training && @drop_prob > 0
  return x unless dropout_active

  # get_context yields either the bare drop probability or a DropoutContext,
  # depending on whether #init_context has enabled the context stack.
  XDropout.apply(x, get_context)
end
#get_context ⇒ Object
133 134 135 136 137 138 139 140 141 142 143 144 145 |
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 133
# Returns the dropout "context" for the next dropout call.
#
# When the context stack is disabled (nil), this simply returns the drop
# probability. Otherwise it works in four steps:
# - takes the DropoutContext at the current counter position, appending a
#   fresh one if the stack is exactly one short;
# - stores this module's drop probability on that context;
# - advances the counter;
# - returns the context.
#
# @return [DropoutContext, Numeric] a context object, or the bare drop probability
# @raise [IndexError] if the counter has somehow run more than one past the stack end
def get_context
  return @drop_prob if @context_stack.nil?

  @context_stack << DropoutContext.new if @count >= @context_stack.length
  ctx = @context_stack.fetch(@count)
  # The probability must live on the context object itself, mirroring the
  # Python reference (ctx.dropout = self.drop_prob). The old code wrote it to
  # an @dropout ivar on the module, which nothing visible here ever reads.
  # NOTE(review): assumes DropoutContext exposes a `dropout=` writer — confirm.
  ctx.dropout = @drop_prob
  @count += 1
  ctx
end
#init_context(reuse_mask: true, scale: 1) ⇒ Object
122 123 124 125 126 127 128 129 130 131 |
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 122
# Enables (or re-arms) the context stack.
#
# After this call, #get_context hands out DropoutContext objects instead of the
# bare drop probability. The counter is rewound to zero, and every context
# already on the stack is configured with the given +reuse_mask+ and +scale+.
# Contexts that #get_context creates later are built with DropoutContext.new
# and do not receive these values here.
#
# @param reuse_mask [Boolean] value assigned to each existing context's reuse_mask
# @param scale [Numeric] value assigned to each existing context's scale
# @return [Array] the context stack
def init_context(reuse_mask: true, scale: 1)
  @context_stack ||= []
  @count = 0
  # Configure each context object itself, as the Python reference does
  # (c.reuse_mask = ..., c.scale = ...). The old code set @reuse_mask/@scale
  # on the module and left every context unchanged.
  # NOTE(review): assumes DropoutContext exposes reuse_mask=/scale= writers — confirm.
  @context_stack.each do |ctx|
    ctx.reuse_mask = reuse_mask
    ctx.scale = scale
  end
end