Class: Transformers::Mpnet::MPNetLMHead

Inherits:
Torch::NN::Module
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Method Summary

#_tie_weights ⇒ Object
#forward(features, **kwargs) ⇒ Object
#initialize(config) ⇒ MPNetLMHead constructor

Constructor Details

#initialize(config) ⇒ MPNetLMHead

Returns a new instance of MPNetLMHead.



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 513

def initialize(config)
  super()
  @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
  @layer_norm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)

  @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
  @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))

  # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  @decoder.instance_variable_set(:@bias, @bias)
end
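
A minimal construction sketch (not taken from the gem's docs): the Struct below is a hypothetical stand-in for the real MPNet config object and only supplies the three attributes the constructor reads; the values mirror common MPNet base defaults, and the require name is assumed to be the transformers-rb gem.

require "transformers-rb"  # gem/require name assumed; pulls in torch-rb

# hypothetical config stand-in with hidden_size, vocab_size, layer_norm_eps
config = Struct.new(:hidden_size, :vocab_size, :layer_norm_eps)
               .new(768, 30527, 1e-12)

head = Transformers::Mpnet::MPNetLMHead.new(config)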

Instance Method Details

#_tie_weights ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 525

def _tie_weights
  # Re-link the decoder's bias to the head's bias Parameter
  @decoder.instance_variable_set(:@bias, @bias)
end
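
A hypothetical check of the tie, assuming the instance-variable link shown above: after calling _tie_weights, the decoder's bias and the head's @bias refer to the same Torch::NN::Parameter object, which is what lets resize_token_embeddings keep the two in sync.

head._tie_weights
decoder_bias = head.instance_variable_get(:@decoder).instance_variable_get(:@bias)
decoder_bias.equal?(head.instance_variable_get(:@bias))  # => true: same Parameter object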

#forward(features, **kwargs) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 529

def forward(features, **kwargs)
  x = @dense.(features)
  x = Activations.gelu(x)
  x = @layer_norm.(x)

  # project back to size of vocabulary with bias
  x = @decoder.(x)

  x
end
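
A hedged usage sketch, reusing the hypothetical config and head from above: the head maps final hidden states of shape [batch, seq_len, hidden_size] to vocabulary logits of shape [batch, seq_len, vocab_size].

features = Torch.randn(1, 8, config.hidden_size)  # e.g. final hidden states from the encoder
logits = head.(features)                          # Torch::NN::Module#call dispatches to #forward
logits.shape                                      # => [1, 8, 30527], i.e. config.vocab_size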