Class: Transformers::Mpnet::MPNetModel

Inherits:
MPNetPreTrainedModel
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from MPNetPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from Transformers::ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config, add_pooling_layer: true) ⇒ MPNetModel

Returns a new instance of MPNetModel.



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 384

def initialize(config, add_pooling_layer: true)
  super(config)
  @config = config

  @embeddings = MPNetEmbeddings.new(config)
  @encoder = MPNetEncoder.new(config)
  @pooler = add_pooling_layer ? MPNetPooler.new(config) : nil

  # Initialize weights and apply final processing
  post_init
end
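
A hedged usage sketch: in practice this constructor is usually reached through from_pretrained (inherited from PreTrainedModel, listed above); the checkpoint name and the MPNetConfig default constructor below are assumptions for illustration, not part of this page.

require "transformers-rb"

# Typical path: from_pretrained loads the stored config and calls this
# constructor internally (checkpoint name assumed for illustration)
model = Transformers::Mpnet::MPNetModel.from_pretrained("microsoft/mpnet-base")

# Direct construction from a config object, skipping the pooling head
config = Transformers::Mpnet::MPNetConfig.new
headless = Transformers::Mpnet::MPNetModel.new(config, add_pooling_layer: false)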

Instance Method Details

#_prune_heads(heads_to_prune) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 404

def _prune_heads(heads_to_prune)
  # heads_to_prune maps each layer index to the attention head indices
  # to remove from that layer's self-attention module
  heads_to_prune.each do |layer, heads|
    @encoder.layer[layer].attention.prune_heads(heads)
  end
end
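
A short sketch of the expected argument shape (layer and head indices are illustrative), using the public prune_heads entry point inherited from PreTrainedModel (listed above), which delegates to this method:

# Prune heads 0 and 2 in layer 1, and head 3 in layer 5
model.prune_heads({1 => [0, 2], 5 => [3]})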

#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil, **kwargs) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 410

def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil,
  **kwargs
)
  # Fall back to the config defaults for any option not given explicitly
  output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
  output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  if !input_ids.nil? && !inputs_embeds.nil?
    raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
  elsif !input_ids.nil?
    warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
    input_shape = input_ids.size
  elsif !inputs_embeds.nil?
    input_shape = inputs_embeds.size[...-1]
  else
    raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
  end

  device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

  # Attend to every position by default, then broadcast the 2D mask to the
  # 4D shape the self-attention layers expect
  if attention_mask.nil?
    attention_mask = Torch.ones(input_shape, device: device)
  end
  extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)

  head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
  embedding_output = @embeddings.(input_ids: input_ids, position_ids: position_ids, inputs_embeds: inputs_embeds)
  encoder_outputs = @encoder.(embedding_output, attention_mask: extended_attention_mask, head_mask: head_mask, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
  sequence_output = encoder_outputs[0]
  # The pooler summarizes the first token's hidden state; nil when the
  # model was built with add_pooling_layer: false
  pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil

  if !return_dict
    return [sequence_output, pooled_output] + encoder_outputs[1..]
  end

  BaseModelOutputWithPooling.new(last_hidden_state: sequence_output, pooler_output: pooled_output, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions)
end
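
A minimal end-to-end sketch, assuming the gem's tokenizer API mirrors the Python library (Transformers::AutoTokenizer, called with return_tensors: "pt") and using an assumed checkpoint name:

tokenizer = Transformers::AutoTokenizer.from_pretrained("microsoft/mpnet-base")
inputs = tokenizer.("MPNet combines masked and permuted language modeling.", return_tensors: "pt")

# return_dict falls back to config.use_return_dict, so by default this
# returns a BaseModelOutputWithPooling rather than an array
output = model.forward(
  input_ids: inputs[:input_ids],
  attention_mask: inputs[:attention_mask]
)

output.last_hidden_state.shape # => [1, seq_len, hidden_size]
output.pooler_output           # nil if built with add_pooling_layer: false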

#get_input_embeddings ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 396

def get_input_embeddings
  @embeddings.word_embeddings
end

#set_input_embeddings(value) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 400

def set_input_embeddings(value)
  # Assign to the embeddings module so the encoder actually uses the new
  # table (mirrors get_input_embeddings above)
  @embeddings.word_embeddings = value
end