Class: Transformers::Mpnet::MPNetPooler

Inherits:
Torch::NN::Module (which in turn inherits from Object)
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ MPNetPooler

Returns a new instance of MPNetPooler.



367
368
369
370
371
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 367

# Sets up the pooler's two sub-modules: a square linear projection whose
# input and output widths both equal +config.hidden_size+, and a tanh
# activation applied after it.
#
# @param config [Object] model configuration; only +hidden_size+ is read here
def initialize(config)
  super()
  width = config.hidden_size
  @dense = Torch::NN::Linear.new(width, width)
  @activation = Torch::NN::Tanh.new
end

Instance Method Details

#forward(hidden_states) ⇒ Object



373
374
375
376
377
378
379
380
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 373

# Reduces +hidden_states+ to a pooled representation. It selects index 0
# along the second axis (every entry of the first axis is kept), which is
# the first token's hidden state. That slice goes through the dense layer
# and then the tanh activation.
#
# @param hidden_states [Torch::Tensor] assumed (batch, seq_len, hidden_size) — TODO confirm
# @return the activation output for the first-token slice
def forward(hidden_states)
  # Endless range keeps the whole first axis; 0 picks the first token.
  first_token = hidden_states[0.., 0]
  @activation.call(@dense.call(first_token))
end