Class: Transformers::DebertaV2::XDropout

Inherits:
Object
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Overview

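TODO: implement as a Torch::Autograd::Function. Until then, only the forward pass (.apply) is provided and the mask is not saved for a backward pass.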

Class Method Summary

Class Method Details

.apply(input, local_ctx) ⇒ Object



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 90

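def self.apply(input, local_ctx)
  # get_mask returns the boolean drop mask and the dropout probability.
  mask, dropout = get_mask(input, local_ctx)
  # Scale surviving elements by 1 / (1 - p) so the expected activation is unchanged.
  @scale = 1.0 / (1 - dropout)
  if dropout > 0
    # ctx.save_for_backward(mask)
    # No autograd ctx exists in this method, so use the scale stored above.
    input.masked_fill(mask, 0) * @scale
  else
    input
  end
end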
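
The scaling step in .apply can be illustrated without Torch. The following is a minimal sketch in plain Ruby, assuming a flat Array of Floats and a precomputed boolean mask in which true means "drop this element". The helper name xdropout_sketch and its three-argument signature are hypothetical and are not part of this library.

```ruby
# Hypothetical helper (not part of the library): inverted dropout on a
# flat Array of Floats with a precomputed boolean mask.
def xdropout_sketch(input, mask, dropout)
  # With no dropout, the input passes through unchanged.
  return input unless dropout > 0

  # Surviving elements are scaled by 1 / (1 - p).
  scale = 1.0 / (1 - dropout)

  # A true mask entry zeroes the element, mirroring masked_fill(mask, 0).
  input.zip(mask).map { |x, dropped| dropped ? 0.0 : x * scale }
end

input  = [1.0, 2.0, 3.0, 4.0]
mask   = [true, false, false, true]
result = xdropout_sketch(input, mask, 0.5)
# result is [0.0, 4.0, 6.0, 0.0]
```

With a dropout probability of 0.5 the scale factor is 2.0, so the two kept elements are doubled and the two masked elements become zero. The real method operates on tensors and obtains both the mask and the probability from get_mask.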