TransformerConfig
- class tfts.models.transformer.TransformerConfig(hidden_size: int = 256, num_layers: int = 2, num_decoder_layers: int = 4, num_attention_heads: int = 4, num_kv_heads: int = 4, ffn_intermediate_size: int = 256, hidden_act: str = 'gelu', hidden_dropout_prob: float = 0.0, attention_probs_dropout_prob: float = 0.0, scheduled_sampling: float = 1, max_position_embeddings: int = 512, initializer_range: float = 0.02, positional_type: str = 'positional encoding', use_cache: bool = True, classifier_dropout: float | None = None, layer_norm_eps: float = 1e-12, pad_token_id: int = 0, **kwargs: Dict[str, object])
Bases: BaseConfig
Initializes the configuration for the Transformer model with the specified parameters.
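Because every argument in the signature has a default, a config is typically built by overriding only the fields that differ, for example `TransformerConfig(hidden_size=128, num_layers=3)`. So that it runs without tfts installed, the sketch below mirrors the documented defaults in a hypothetical `TransformerConfigSketch` dataclass with a hypothetical `make_config` helper. Collecting unknown keyword arguments into an `extra` dict is an assumption that stands in for "passed to the parent class". It is not the library's actual mechanism.

```python
from dataclasses import dataclass, field, fields
from typing import Dict, Optional


@dataclass
class TransformerConfigSketch:
    """Hypothetical stand-in that copies the documented TransformerConfig defaults.

    This is NOT the tfts class; it only reproduces the signature's defaults.
    """

    hidden_size: int = 256
    num_layers: int = 2
    num_decoder_layers: int = 4
    num_attention_heads: int = 4
    num_kv_heads: int = 4
    ffn_intermediate_size: int = 256
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.0
    attention_probs_dropout_prob: float = 0.0
    scheduled_sampling: float = 1
    max_position_embeddings: int = 512
    initializer_range: float = 0.02
    positional_type: str = "positional encoding"
    use_cache: bool = True
    classifier_dropout: Optional[float] = None
    layer_norm_eps: float = 1e-12
    pad_token_id: int = 0
    # Assumption: unrecognized keyword arguments are kept here instead of being
    # forwarded to a parent class.
    extra: Dict[str, object] = field(default_factory=dict)


def make_config(**kwargs: object) -> TransformerConfigSketch:
    """Hypothetical helper: split known fields from extra keyword arguments."""
    known = {f.name for f in fields(TransformerConfigSketch)} - {"extra"}
    params = {k: v for k, v in kwargs.items() if k in known}
    extra = {k: v for k, v in kwargs.items() if k not in known}
    return TransformerConfigSketch(**params, extra=extra)


# Override only what differs; every other field keeps its documented default.
config = make_config(hidden_size=128, num_layers=3, custom_flag=True)
print(config.hidden_size, config.num_layers, config.num_decoder_layers)  # 128 3 4
print(config.extra)  # {'custom_flag': True}
```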
- Parameters:
hidden_size – The size of the hidden layers.
num_layers – The number of encoder layers.
num_decoder_layers – The number of decoder layers.
num_attention_heads – The number of attention heads.
num_kv_heads – The number of key-value heads.
ffn_intermediate_size – The size of the intermediate feed-forward layers.
hidden_act – The activation function for hidden layers.
hidden_dropout_prob – The dropout probability for hidden layers.
attention_probs_dropout_prob – The dropout probability for attention probabilities.
scheduled_sampling – Controls the mix between teacher forcing (feeding the ground-truth value) and feeding back the model's last prediction as the next decoder input.
max_position_embeddings – The maximum length of input sequences.
initializer_range – The standard deviation for weight initialization.
layer_norm_eps – The epsilon for layer normalization.
pad_token_id – The ID for the padding token.
positional_type – The type of position embeddings (absolute or relative); the default is 'positional encoding'.
use_cache – Whether to use cache during inference.
classifier_dropout – Dropout rate for classifier layers.
**kwargs – Additional parameters for further customization passed to the parent class.
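The description of scheduled_sampling does not state its exact semantics, so the toy decoding loop below assumes it is the per-step probability of teacher forcing. Under that reading, the default of 1 would mean the ground truth is always fed back. The function `decode_with_scheduled_sampling` and the scalar "model" `predict` are hypothetical illustrations and do not reflect how tfts implements its decoder.

```python
import random
from typing import Callable, List


def decode_with_scheduled_sampling(
    start: float,
    targets: List[float],
    predict: Callable[[float], float],
    scheduled_sampling: float,
    rng: random.Random,
) -> List[float]:
    """Return the predictions of a toy step-by-step decoder.

    Assumption: scheduled_sampling is the probability of teacher forcing at each
    step (1.0 = always feed the ground truth, 0.0 = always feed the last prediction).
    """
    predictions: List[float] = []
    decoder_input = start
    for target in targets:
        prediction = predict(decoder_input)
        predictions.append(prediction)
        # rng.random() is in [0, 1), so 1.0 always teacher-forces and 0.0 never does.
        if rng.random() < scheduled_sampling:
            decoder_input = target  # teacher forcing: feed the true value
        else:
            decoder_input = prediction  # feed back the model's own output
    return predictions


def biased_step(x: float) -> float:
    """A deliberately biased toy model: it always overshoots by 0.5."""
    return x + 1.5


targets = [1.0, 2.0, 3.0]
print(decode_with_scheduled_sampling(0.0, targets, biased_step, 1.0, random.Random(0)))
# [1.5, 2.5, 3.5]  (errors stay bounded with teacher forcing)
print(decode_with_scheduled_sampling(0.0, targets, biased_step, 0.0, random.Random(0)))
# [1.5, 3.0, 4.5]  (errors compound when the model consumes its own outputs)
```

Values between 0 and 1 interpolate between the two regimes, which is the usual motivation for scheduled sampling: training sees some of the model's own errors, as inference does.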
- Inherited members:
Methods
from_dict(config_dict)
from_json(json_file)
from_pretrained(pretrained_model_name_or_path)
save_pretrained(save_directory)
to_dict()
to_json(json_file)
update(config_dict)
Attributes
attribute_map
model_type
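The inherited helpers suggest a dict/JSON round trip: to_dict and to_json serialize a config, from_dict and from_json rebuild it, and update changes fields in place. The sketch below imitates those method names on a hypothetical `ConfigSketch` class so it runs without tfts. Its internals are assumptions, not the BaseConfig implementation. from_pretrained, save_pretrained, attribute_map and model_type are not modeled.

```python
import json
import os
import tempfile
from typing import Any, Dict


class ConfigSketch:
    """Hypothetical stand-in for the inherited BaseConfig helpers (not the tfts code)."""

    def __init__(self, hidden_size: int = 256, num_layers: int = 2, **kwargs: Any) -> None:
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Assumption: extra keyword arguments simply become attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        return dict(vars(self))

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "ConfigSketch":
        return cls(**config_dict)

    def update(self, config_dict: Dict[str, Any]) -> None:
        for key, value in config_dict.items():
            setattr(self, key, value)

    def to_json(self, json_file: str) -> None:
        with open(json_file, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_json(cls, json_file: str) -> "ConfigSketch":
        with open(json_file, encoding="utf-8") as f:
            return cls.from_dict(json.load(f))


config = ConfigSketch(hidden_size=128)
config.update({"num_layers": 6})  # change a field after construction

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    config.to_json(path)  # write to disk
    restored = ConfigSketch.from_json(path)  # read it back

print(restored.to_dict())  # {'hidden_size': 128, 'num_layers': 6}
```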