Class: Transformers::XlmRoberta::XLMRobertaPreTrainedModel
- Inherits: PreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - Transformers::XlmRoberta::XLMRobertaPreTrainedModel
- Defined in: lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
Direct Known Subclasses
XLMRobertaForCausalLM, XLMRobertaForMaskedLM, XLMRobertaForMultipleChoice, XLMRobertaForQuestionAnswering, XLMRobertaForSequenceClassification, XLMRobertaForTokenClassification, XLMRobertaModel
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #_init_weights(module_) ⇒ Object
  self.supports_gradient_checkpointing = true; self._no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention", "XLMRobertaSdpaSelfAttention"]; self._supports_sdpa = true.
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #initialize, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
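All of the direct subclasses listed above inherit from_pretrained from PreTrainedModel. A minimal usage sketch; the checkpoint name "xlm-roberta-base" and the require path are assumptions, not taken from this page:

require "transformers"  # gem "transformers-rb"; require path assumed

# Load the base encoder; the task-specific subclasses listed above
# (e.g. XLMRobertaForSequenceClassification) are loaded the same way
# from a compatible checkpoint.
model = Transformers::XlmRoberta::XLMRobertaModel.from_pretrained("xlm-roberta-base")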
Constructor Details
This class inherits a constructor from Transformers::PreTrainedModel
Instance Method Details
#_init_weights(module_) ⇒ Object
self.supports_gradient_checkpointing = true; self._no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention", "XLMRobertaSdpaSelfAttention"]; self._supports_sdpa = true
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 586

def _init_weights(module_)
  if module_.is_a?(Torch::NN::Linear)
    # Slightly different from the TF version which uses truncated_normal for initialization
    # cf https://github.com/pytorch/pytorch/pull/5617
    module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
    if !module_.bias.nil?
      module_.bias.data.zero!
    end
  elsif module_.is_a?(Torch::NN::Embedding)
    module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
    if !module_.padding_idx.nil?
      module_.weight.data.fetch(module_.padding_idx).zero!
    end
  elsif module_.is_a?(Torch::NN::LayerNorm)
    module_.bias.data.zero!
    module_.weight.data.fill!(1.0)
  end
end
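The inherited #init_weights / #_initialize_weights helpers (listed in the method summary) invoke this hook once per submodule. A rough illustrative sketch of the Torch::NN::Linear branch applied by hand to a standalone layer; the std value 0.02 is an assumed stand-in for a config's initializer_range:

require "torch"

# Build a throwaway layer and initialize it the same way _init_weights would:
# normally distributed weights, zeroed bias.
layer = Torch::NN::Linear.new(768, 768)
layer.weight.data.normal!(mean: 0.0, std: 0.02)  # 0.02 stands in for @config.initializer_range
layer.bias.data.zero! unless layer.bias.nil?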