Class: Transformers::DebertaV2::StableDropout

Inherits:
Torch::NN::Module (ancestors: Torch::NN::Module < Object)
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Method Summary

Constructor Details

#initialize(drop_prob) ⇒ StableDropout

Returns a new instance of StableDropout.



103
104
105
106
107
108
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 103

# Builds a dropout module that starts without a context stack.
#
# @param drop_prob probability handed to dropout; #forward skips dropout
#   entirely when this is not greater than zero.
def initialize(drop_prob)
  super()
  @drop_prob = drop_prob
  # No context stack yet: #get_context will hand back the bare probability
  # until #init_context is called.
  @context_stack = nil
  @count = 0
end

Instance Method Details

#clear_context ⇒ Object



117
118
119
120
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 117

# Drops the context stack and rewinds the position counter, so that
# #get_context goes back to returning the bare drop probability.
#
# @return [nil]
def clear_context
  @context_stack = nil
  @count = 0
  nil
end

#forward(x) ⇒ Object



110
111
112
113
114
115
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 110

# Applies dropout to +x+ only while training with a positive drop
# probability; in every other case +x+ is returned untouched.
#
# @param x input forwarded to XDropout.apply (presumably a Torch::Tensor — confirm)
# @return the result of XDropout.apply, or +x+ itself when dropout is inactive
def forward(x)
  return x unless @training && @drop_prob > 0

  XDropout.apply(x, get_context)
end

#get_context ⇒ Object



133
134
135
136
137
138
139
140
141
142
143
144
145
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 133

# Returns the argument to pass to XDropout.apply for the next dropout call.
#
# Without a context stack (see #init_context) this is just the drop
# probability. With one, it returns the DropoutContext at the current
# position, appending a fresh context when the stack is not that deep yet.
# It stamps that context with this module's drop probability and advances
# the position, so the next call receives the next context.
#
# @return @drop_prob when no context stack is active, otherwise a DropoutContext
def get_context
  return @drop_prob if @context_stack.nil?

  @context_stack << DropoutContext.new if @count >= @context_stack.length
  ctx = @context_stack.fetch(@count)
  # XDropout.apply only receives the context, so the probability must live on
  # the context itself (previously it was written to an unused ivar on self).
  # NOTE(review): relies on DropoutContext exposing a `dropout=` writer — confirm.
  ctx.dropout = @drop_prob
  @count += 1
  ctx
end

#init_context(reuse_mask: true, scale: 1) ⇒ Object



122
123
124
125
126
127
128
129
130
131
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 122

# Enables (or resets) the per-call dropout context stack.
#
# Creates an empty stack on first use, rewinds the position used by
# #get_context to the beginning, and applies +reuse_mask+ and +scale+ to every
# context already on the stack. Contexts appended later by #get_context are
# not configured here.
#
# @param reuse_mask [Boolean] value assigned to each existing context's reuse_mask
# @param scale value assigned to each existing context's scale
# @return [Array] the context stack
def init_context(reuse_mask: true, scale: 1)
  @context_stack ||= []
  @count = 0
  # Configure each context itself; previously the loop ignored its block
  # variable and only overwrote ivars on this module.
  # NOTE(review): relies on DropoutContext exposing `reuse_mask=`/`scale=` writers — confirm.
  @context_stack.each do |ctx|
    ctx.reuse_mask = reuse_mask
    ctx.scale = scale
  end
end