Class: Transformers::Bert::BertForTokenClassification
- Inherits: BertPreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - BertPreTrainedModel
  - Transformers::Bert::BertForTokenClassification
- Defined in: lib/transformers/models/bert/modeling_bert.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ BertForTokenClassification (constructor)
  A new instance of BertForTokenClassification.
Methods inherited from BertPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ BertForTokenClassification
Returns a new instance of BertForTokenClassification.
# File 'lib/transformers/models/bert/modeling_bert.rb', line 768

def initialize(config)
  super(config)
  @num_labels = config.num_labels

  @bert = BertModel.new(config, add_pooling_layer: false)
  classifier_dropout = (
    !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
  )
  @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
  @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
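The head built here is simply dropout followed by a per-token linear projection: add_pooling_layer: false skips BERT's [CLS] pooler because token classification scores every position rather than the whole sequence, and the linear layer maps each hidden_size vector to num_labels logits. A minimal usage sketch, assuming Hub access and a config reader inherited from PreTrainedModel; the checkpoint id below is illustrative, not prescribed by this class:

# Illustrative NER checkpoint; from_pretrained is inherited from PreTrainedModel.
model = Transformers::Bert::BertForTokenClassification.from_pretrained(
  "dslim/bert-base-NER"
)
model.config.num_labels  # label count comes from the checkpoint's config

Constructing directly with new(config) also works, e.g. when training a head from scratch on a freshly built config.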
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/bert/modeling_bert.rb', line 783

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @bert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  sequence_output = outputs[0]

  sequence_output = @dropout.(sequence_output)
  logits = @classifier.(sequence_output)

  loss = nil
  if !labels.nil?
    loss_fct = CrossEntropyLoss.new
    loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
  end

  if !return_dict
    raise Todo
  end

  TokenClassifierOutput.new(
    loss: loss,
    logits: logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
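For reference, forward returns a TokenClassifierOutput whose logits have shape [batch_size, sequence_length, num_labels]; when labels are given, the loss is cross-entropy over all token positions flattened together via view(-1, num_labels). A hedged call sketch using torch-rb tensors directly; the token and label ids are made up, and real inputs would come from a tokenizer:

# Fake ids for illustration only; shapes are [1, 4].
input_ids      = Torch.tensor([[101, 7592, 2088, 102]])
attention_mask = Torch.tensor([[1, 1, 1, 1]])
labels         = Torch.tensor([[0, 3, 4, 0]])  # one label id per token

output = model.forward(
  input_ids: input_ids,
  attention_mask: attention_mask,
  labels: labels
)

output.logits.shape  # => [1, 4, num_labels]
output.loss          # scalar cross-entropy loss tensor
# Assuming torch-rb's Tensor#argmax takes a dim argument, as in PyTorch:
preds = output.logits.argmax(-1)  # predicted label id per token

Note that the non-return_dict path currently raises Todo, so callers should leave return_dict at its default (or pass true) and consume the TokenClassifierOutput struct.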