Class: Transformers::Distilbert::DistilBertPreTrainedModel
- Inherits: PreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - Transformers::Distilbert::DistilBertPreTrainedModel
- Defined in: lib/transformers/models/distilbert/modeling_distilbert.rb
Direct Known Subclasses
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DistilBertModel
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #initialize, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
This class inherits a constructor from Transformers::PreTrainedModel
Instance Method Details
#_init_weights(mod) ⇒ Object
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 315

    def _init_weights(mod)
      if mod.is_a?(Torch::NN::Linear)
        mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
        if !mod.bias.nil?
          mod.bias.data.zero!
        end
      elsif mod.is_a?(Torch::NN::Embedding)
        mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
        if !mod.instance_variable_get(:@padding_idx).nil?
          mod.weight.data[mod.instance_variable_get(:@padding_idx)].zero!
        end
      elsif mod.is_a?(Torch::NN::LayerNorm)
        mod.bias.data.zero!
        mod.weight.data.fill!(1.0)
      elsif mod.is_a?(Embeddings) && @config.sinusoidal_pos_embds
        create_sinusoidal_embeddings(
          @config.max_position_embeddings, @config.dim, mod.position_embeddings.weight
        )
      end
    end
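The method above dispatches on the module's class and applies a different initialization to each: normal-distributed weights for `Linear` and `Embedding` (with the padding row zeroed), and zeros/ones for `LayerNorm`. The following is a minimal plain-Ruby sketch of that dispatch pattern; `Linear`, `Embedding`, and `LayerNorm` here are hypothetical `Struct` stand-ins for the `Torch::NN` classes, and arrays stand in for tensors, so this is an illustration of the control flow, not the gem's implementation.

```ruby
# Plain-Ruby stand-ins (assumptions) for the Torch::NN module classes.
Linear    = Struct.new(:weight, :bias)
Embedding = Struct.new(:weight, :padding_idx)
LayerNorm = Struct.new(:weight, :bias)

# Sketch of the per-module-type initialization dispatch in _init_weights.
def init_weights_sketch(mod, std: 0.02)
  case mod
  when Linear
    # Uniform placeholder for Torch's in-place normal!(mean: 0.0, std:) draw.
    mod.weight = Array.new(mod.weight.size) { rand * std }
    mod.bias   = Array.new(mod.bias.size, 0.0) unless mod.bias.nil?
  when Embedding
    mod.weight = mod.weight.map { rand * std }
    # Zero the padding row, mirroring weight.data[padding_idx].zero!
    mod.weight[mod.padding_idx] = 0.0 unless mod.padding_idx.nil?
  when LayerNorm
    # LayerNorm starts as the identity transform: bias 0, scale 1.
    mod.bias   = mod.bias.map { 0.0 }
    mod.weight = mod.weight.map { 1.0 }
  end
  mod
end

ln = init_weights_sketch(LayerNorm.new([0.5, 0.5], [0.3, 0.3]))
# ln.weight == [1.0, 1.0] and ln.bias == [0.0, 0.0]
```

In the real method the same branching is applied to every submodule during `init_weights` / `post_init`, with the extra `Embeddings` branch overwriting position embeddings with fixed sinusoidal values when `sinusoidal_pos_embds` is set in the config.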