Class: Transformers::Bert::BertForTokenClassification

Inherits:
BertPreTrainedModel
Defined in:
lib/transformers/models/bert/modeling_bert.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

#initialize(config) ⇒ BertForTokenClassification constructor
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object

Methods inherited from BertPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ BertForTokenClassification

Returns a new instance of BertForTokenClassification.



# File 'lib/transformers/models/bert/modeling_bert.rb', line 768

def initialize(config)
  super(config)
  @num_labels = config.num_labels

  # Encoder without the pooling head; token classification uses per-token hidden states
  @bert = BertModel.new(config, add_pooling_layer: false)
  classifier_dropout = (
    !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
  )
  @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
  @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
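
A minimal construction sketch (not part of the generated source above), assuming Transformers::Bert::BertConfig is available and accepts a num_labels option; in practice the model is usually loaded with the inherited from_pretrained class method instead.

# Hypothetical example: build the token-classification head directly from a config.
# The BertConfig class name and num_labels keyword are assumptions from the upstream library.
config = Transformers::Bert::BertConfig.new(num_labels: 9)
model = Transformers::Bert::BertForTokenClassification.new(config)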

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/bert/modeling_bert.rb', line 783

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @bert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  # Per-token hidden states from the encoder: [batch, seq_len, hidden_size]
  sequence_output = outputs[0]

  sequence_output = @dropout.(sequence_output)
  logits = @classifier.(sequence_output)

  # Token-level cross-entropy, computed over flattened positions when labels are provided
  loss = nil
  if !labels.nil?
    loss_fct = CrossEntropyLoss.new
    loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
  end

  if !return_dict
    raise Todo
  end

  TokenClassifierOutput.new(
    loss: loss,
    logits: logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
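
A hedged end-to-end sketch of calling #forward. The checkpoint name is hypothetical, the example token ids are illustrative, and the tensor calls assume torch-rb's PyTorch-like API; only from_pretrained (inherited from PreTrainedModel) and #forward above are documented on this page.

# Hypothetical usage sketch; real input ids would come from a matching tokenizer.
model = Transformers::Bert::BertForTokenClassification.from_pretrained("path/to/ner-checkpoint")

input_ids = Torch.tensor([[101, 7592, 2088, 102]])    # [batch, seq_len] token ids (values illustrative)
attention_mask = Torch.tensor([[1, 1, 1, 1]])          # 1 = real token, 0 = padding

output = model.forward(input_ids: input_ids, attention_mask: attention_mask)
# output is a TokenClassifierOutput; logits has shape [batch, seq_len, num_labels],
# so taking the argmax over the last dimension yields one predicted label id per token.
predictions = output.logits.argmax(-1)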