Class: Transformers::Mpnet::MPNetForMaskedLM

Inherits:
MPNetPreTrainedModel
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from MPNetPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from Transformers::ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ MPNetForMaskedLM

Returns a new instance of MPNetForMaskedLM.



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 460

def initialize(config)
  super(config)

  @mpnet = MPNetModel.new(config, add_pooling_layer: false)
  @lm_head = MPNetLMHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
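
In practice the model is rarely built from a bare config; the inherited from_pretrained class method constructs the config and loads pretrained weights in one step. A minimal sketch, assuming the "microsoft/mpnet-base" checkpoint is loadable by this port:

require "transformers-rb"  # gem "transformers-rb" (require name assumed)

# Load config + weights from the Hugging Face Hub (checkpoint name assumed)
model = Transformers::Mpnet::MPNetForMaskedLM.from_pretrained("microsoft/mpnet-base")

# Or build a randomly initialized model from an explicit config
# (MPNetConfig is assumed to live under the same module)
config = Transformers::Mpnet::MPNetConfig.new
model = Transformers::Mpnet::MPNetForMaskedLM.new(config)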

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 479

def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @mpnet.(
    input_ids,
    attention_mask: attention_mask,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  sequence_output = outputs[0]
  prediction_scores = @lm_head.(sequence_output)

  masked_lm_loss = nil
  if !labels.nil?
    loss_fct = Torch::NN::CrossEntropyLoss.new
    masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
  end

  if !return_dict
    output = [prediction_scores] + outputs[2..]
    return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
  end

  MaskedLMOutput.new(
    loss: masked_lm_loss,
    logits: prediction_scores,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
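
A hedged end-to-end sketch of #forward used for mask filling. It assumes AutoTokenizer can load an MPNet tokenizer in this port, that the tokenizer exposes mask_token_id and decode as in the Python library, and that "microsoft/mpnet-base" is a loadable checkpoint:

tokenizer = Transformers::AutoTokenizer.from_pretrained("microsoft/mpnet-base")
model = Transformers::Mpnet::MPNetForMaskedLM.from_pretrained("microsoft/mpnet-base")

encoding = tokenizer.("The capital of France is <mask>.", return_tensors: "pt")

output = model.(
  input_ids: encoding[:input_ids],
  attention_mask: encoding[:attention_mask]
)

# Position of the <mask> token in the single input sequence
# (mask_token_id is assumed to mirror the Python tokenizer API)
mask_index = encoding[:input_ids][0].to_a.index(tokenizer.mask_token_id)

# Highest-scoring vocabulary id at the masked position
predicted_id = Torch.argmax(output.logits[0, mask_index]).item
puts tokenizer.decode([predicted_id])

When labels: is supplied, the method additionally computes a cross-entropy loss over prediction_scores (positions labelled -100 are ignored by the default CrossEntropyLoss, as in PyTorch) and returns it as output.loss.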

#get_output_embeddings ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 470

def get_output_embeddings
  @lm_head.decoder
end
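
The returned module is the LM head's decoder projection (typically a Torch::NN::Linear mapping hidden states to vocabulary logits); its weight is what the inherited #tie_weights links to the input word embeddings. A small sketch, with shapes assumed from the usual mpnet-base config:

decoder = model.get_output_embeddings
p decoder.weight.shape  # e.g. [30527, 768], i.e. vocab_size x hidden_size (assumed)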

#set_output_embeddings(new_embeddings) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 474

def set_output_embeddings(new_embeddings)
  # Assign to the LM head (not the model itself) so the new decoder is
  # actually used and #get_output_embeddings/#tie_weights see the same module
  @lm_head.decoder = new_embeddings
  @lm_head.bias = new_embeddings.bias
end
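
A hedged sketch of swapping in a fresh output projection, for example after growing the vocabulary; it assumes the config exposes hidden_size and vocab_size readers:

config = model.config
new_decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size)
model.set_output_embeddings(new_decoder)

# The inherited #tie_weights can then re-link the new decoder's weight
# to the input embeddings if the config requests weight tying
model.tie_weights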