Class: Transformers::DebertaV2::DebertaV2PredictionHeadTransform

Inherits:
Torch::NN::Module
  • Object
show all
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Method Summary collapse

Constructor Details

#initialize(config) ⇒ DebertaV2PredictionHeadTransform

Returns a new instance of DebertaV2PredictionHeadTransform.



869
870
871
872
873
874
875
876
877
878
879
880
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 869

# Builds the three sub-layers applied by #forward: a linear projection,
# an activation callable, and a layer normalization.
#
# @param config [Object] model configuration; this method reads
#   +hidden_size+, +hidden_act+, +layer_norm_eps+ and queries
#   +embedding_size+ through +getattr+.
def initialize(config)
  super()

  # Presumably falls back to hidden_size when the config does not define
  # embedding_size -- confirm against the config's getattr implementation.
  @embedding_size = config.getattr("embedding_size", config.hidden_size)
  @dense = Torch::NN::Linear.new(config.hidden_size, @embedding_size)

  # A String activation is looked up in ACT2FN; any other value is stored
  # unchanged and is later invoked with `.()` in #forward.
  activation = config.hidden_act
  @transform_act_fn = activation.is_a?(String) ? ACT2FN[activation] : activation

  # NOTE(review): the capitalised ivar name looks deliberate (likely mirrors
  # the parameter names used by pretrained checkpoints) -- do not rename
  # without confirming how weights are loaded.
  @LayerNorm = Torch::NN::LayerNorm.new(@embedding_size, eps: config.layer_norm_eps)
end

Instance Method Details

#forward(hidden_states) ⇒ Object



882
883
884
885
886
887
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 882

# Runs the input through the dense layer, then the activation callable,
# then the layer normalization, in that order.
#
# @param hidden_states [Object] value passed to the dense layer
# @return [Object] whatever the layer normalization call returns
def forward(hidden_states)
  projected = @dense.(hidden_states)
  activated = @transform_act_fn.(projected)
  @LayerNorm.(activated)
end