Class: Transformers::Mpnet::MPNetForMaskedLM
- Inherits: MPNetPreTrainedModel
  - Ancestry: Object > Torch::NN::Module > PreTrainedModel > MPNetPreTrainedModel > Transformers::Mpnet::MPNetForMaskedLM
- Defined in: lib/transformers/models/mpnet/modeling_mpnet.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #get_output_embeddings ⇒ Object
- #initialize(config) ⇒ MPNetForMaskedLM (constructor): A new instance of MPNetForMaskedLM.
- #set_output_embeddings(new_embeddings) ⇒ Object
Methods inherited from MPNetPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, .from_pretrained, #get_input_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from Transformers::ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ MPNetForMaskedLM
Returns a new instance of MPNetForMaskedLM.
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 460

def initialize(config)
  super(config)

  @mpnet = MPNetModel.new(config, add_pooling_layer: false)
  @lm_head = MPNetLMHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
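The constructor wires an MPNetModel built with add_pooling_layer: false to an MPNetLMHead. The plain-Ruby sketch below mimics only that wiring, using hypothetical ToyEncoder, ToyLMHead, and ToyMaskedLM classes (not part of transformers-rb) and nested arrays in place of Torch tensors. The toy encoder is simple arithmetic rather than attention, and the toy head is a single linear projection; both are assumptions, since the internals of MPNetModel and MPNetLMHead are not shown on this page. The point it illustrates is that the head consumes one hidden vector per token and emits one row of vocabulary scores per token, so no pooled sentence vector is needed.

```ruby
# Hypothetical stand-ins (NOT transformers-rb classes) that mimic the
# constructor's wiring: encoder without pooling -> per-token LM head.
# Nested arrays replace tensors: [batch][seq][dim].
class ToyEncoder
  def initialize(hidden_size)
    @hidden_size = hidden_size
  end

  # One hidden vector per token; no pooled vector (cf. add_pooling_layer: false).
  def call(input_ids)
    input_ids.map do |seq|
      seq.map { |id| Array.new(@hidden_size) { |h| (id + h).to_f } }
    end
  end
end

class ToyLMHead
  # Assumption: a single linear map with weights of shape [vocab][hidden].
  def initialize(hidden_size, vocab_size)
    @weights = Array.new(vocab_size) { |v| Array.new(hidden_size) { |h| (v - h) * 0.1 } }
  end

  # Maps each hidden vector to vocab_size scores (one dot product per vocab entry).
  def call(sequence_output)
    sequence_output.map do |seq|
      seq.map { |hidden| @weights.map { |row| row.zip(hidden).sum { |w, x| w * x } } }
    end
  end
end

class ToyMaskedLM
  def initialize(hidden_size:, vocab_size:)
    @encoder = ToyEncoder.new(hidden_size)
    @lm_head = ToyLMHead.new(hidden_size, vocab_size)
  end

  def call(input_ids)
    @lm_head.call(@encoder.call(input_ids))
  end
end

model = ToyMaskedLM.new(hidden_size: 3, vocab_size: 5)
logits = model.call([[1, 2, 7, 4]])
p [logits.size, logits[0].size, logits[0][0].size] # => [1, 4, 5]
```

The output shape is [batch, sequence length, vocabulary size], which is the shape forward later flattens for the loss.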
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 479

def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  # Fall back to the config default when return_dict is not given
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @mpnet.(
    input_ids,
    attention_mask: attention_mask,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  sequence_output = outputs[0]
  # Per-token vocabulary scores
  prediction_scores = @lm_head.(sequence_output)

  masked_lm_loss = nil
  if !labels.nil?
    loss_fct = Torch::NN::CrossEntropyLoss.new
    # Flatten to (batch * seq_len, vocab_size) logits and (batch * seq_len) labels
    masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
  end

  if !return_dict
    output = [prediction_scores] + outputs[2..]
    return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
  end

  MaskedLMOutput.new(
    loss: masked_lm_loss,
    logits: prediction_scores,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
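forward does two things beyond calling the encoder and head: it computes a cross-entropy loss over flattened logits when labels are given, and, when return_dict is false, it returns a plain array instead of a MaskedLMOutput. The plain-Ruby sketch below reproduces both steps with nested arrays instead of tensors; it is an illustration, not the transformers-rb code. Two assumptions are labeled in the comments: label -100 is skipped (mirroring the default ignore_index of torch's CrossEntropyLoss), and the encoder's array output follows the Hugging Face layout (sequence output, pooled-output slot, then optional hidden states and attentions), which is why outputs[2..] skips the first two entries. The helper names flatten_logits, cross_entropy, and assemble_tuple are hypothetical.

```ruby
# Plain-Ruby sketch of forward's loss and array-output paths.
# Arrays stand in for Torch tensors; this is NOT the transformers-rb code.

# Assumption: mirrors torch CrossEntropyLoss's default ignore_index of -100,
# commonly used to label positions that were not masked.
IGNORE_INDEX = -100

# [batch][seq][vocab] -> [batch*seq][vocab], like prediction_scores.view(-1, vocab_size)
def flatten_logits(logits, vocab_size)
  rows = logits.flatten(1)
  raise ArgumentError, "last dimension must be #{vocab_size}" unless rows.all? { |r| r.size == vocab_size }
  rows
end

# Mean cross-entropy over positions whose label is not IGNORE_INDEX.
def cross_entropy(rows, labels)
  scored = rows.zip(labels).reject { |_, y| y == IGNORE_INDEX }
  return Float::NAN if scored.empty?

  losses = scored.map do |scores, y|
    max = scores.max # subtract the max for a numerically stable log-sum-exp
    log_sum_exp = max + Math.log(scores.sum { |s| Math.exp(s - max) })
    log_sum_exp - scores[y]
  end
  losses.sum / losses.size
end

# Array path (return_dict false): [prediction_scores] + outputs[2..],
# prefixed with the loss when labels were given.
def assemble_tuple(prediction_scores, outputs, loss)
  output = [prediction_scores] + outputs[2..]
  loss.nil? ? output : [loss] + output
end

logits = [[[0.0, 0.0, 0.0, 0.0], [2.0, 0.0, 0.0, 0.0]]] # batch 1, seq 2, vocab 4
labels = [[2, IGNORE_INDEX]]                             # only position 0 is scored
loss = cross_entropy(flatten_logits(logits, 4), labels.flatten)
puts loss.round(4) # prints 1.3863 (ln 4, because the scored row is uniform)
```

A uniform row of scores gives a loss of ln(vocab_size), which is a handy sanity check when wiring up a masked LM.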
#get_output_embeddings ⇒ Object
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 470

def get_output_embeddings
  @lm_head.decoder
end
#set_output_embeddings(new_embeddings) ⇒ Object
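# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 474

def set_output_embeddings(new_embeddings)
  # Stores the new decoder and its bias in this model's own instance variables
  @decoder = new_embeddings
  @bias = new_embeddings.bias
end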
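As listed, get_output_embeddings reads @lm_head.decoder, while set_output_embeddings assigns @decoder and @bias on the MPNetForMaskedLM instance itself rather than on @lm_head. The sketch below mirrors the two method bodies with hypothetical FakeLinear, FakeHead, and FakeMaskedLM stand-ins (not transformers-rb classes) so the resulting behavior can be checked in plain Ruby: the getter returns the head's decoder object by reference, and after the setter runs the getter still returns that original decoder. This reflects the listing on this page only; the library's behavior in other versions may differ.

```ruby
# Hypothetical stand-ins (NOT transformers-rb classes) mirroring the two
# accessor bodies exactly as listed on this page.
FakeLinear = Struct.new(:name, :bias)

class FakeHead
  attr_accessor :decoder, :bias

  def initialize(decoder)
    @decoder = decoder
    @bias = decoder.bias
  end
end

class FakeMaskedLM
  def initialize(decoder)
    @lm_head = FakeHead.new(decoder)
  end

  # As listed: returns the head's decoder (the same object, not a copy)
  def get_output_embeddings
    @lm_head.decoder
  end

  # As listed: assigns instance variables on the model, not on @lm_head
  def set_output_embeddings(new_embeddings)
    @decoder = new_embeddings
    @bias = new_embeddings.bias
  end
end

original = FakeLinear.new("orig", :b0)
model = FakeMaskedLM.new(original)
replacement = FakeLinear.new("new", :b1)
model.set_output_embeddings(replacement)

p model.get_output_embeddings.equal?(original)                # => true
p model.instance_variable_get(:@decoder).equal?(replacement)  # => true
```

Because the getter returns the decoder by reference, in-place changes to the returned module are visible through @lm_head; code that relies on a set-then-get round trip should check this behavior against the library version in use.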