WaveNetConfig
- class tfts.models.wavenet.WaveNetConfig(dilation_rates: List[int] = None, kernel_sizes: List[int] = None, filters: int = 128, dense_hidden_size: int = 64, scheduled_sampling: float = 1.0, use_attention: bool = False, attention_size: int = 64, num_attention_heads: int = 2, attention_probs_dropout_prob: float = 0.0, **kwargs)
Bases: BaseConfig
Initializes the configuration for the WaveNet model with the specified parameters.
- Parameters:
dilation_rates – List of dilation rates for the convolutional layers.
kernel_sizes – List of kernel sizes for the convolutional layers.
filters – The number of filters in the convolutional layers.
dense_hidden_size – The size of the dense hidden layer following the convolutional layers.
scheduled_sampling – Scheduled sampling ratio: 0 means teacher forcing (the ground-truth value is fed to the next decoding step), 1 means the model's last prediction is used instead.
use_attention – Whether to use attention mechanism in the model.
attention_size – The size of the attention mechanism.
num_attention_heads – The number of attention heads.
attention_probs_dropout_prob – Dropout probability for attention probabilities.
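Together, dilation_rates and kernel_sizes determine how many past time steps each output can see. The sketch below is a minimal illustration that assumes each convolutional layer pairs kernel_sizes[i] with dilation_rates[i] and uses stride 1. This page does not confirm that assumption. The helper receptive_field and the example values are hypothetical and are not the tfts defaults.

```python
from typing import List


def receptive_field(dilation_rates: List[int], kernel_sizes: List[int]) -> int:
    """Receptive field (in time steps) of stacked stride-1 dilated convolutions.

    Hypothetical helper, not part of tfts. Assumes layer i uses
    kernel_sizes[i] together with dilation_rates[i].
    """
    if len(dilation_rates) != len(kernel_sizes):
        raise ValueError("dilation_rates and kernel_sizes must have the same length")
    # Each layer widens the field by (kernel_size - 1) * dilation_rate steps.
    return 1 + sum((k - 1) * d for k, d in zip(kernel_sizes, dilation_rates))


# Illustrative values only, not the library defaults.
dilations = [1, 2, 4, 8]
kernels = [2, 2, 2, 2]
print(receptive_field(dilations, kernels))  # 16
```

With kernel size 2, the field doubles with each added layer (2, 4, 8, 16). This is why WaveNet-style models typically use exponentially growing dilation rates to cover long histories with few layers.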
Methods
from_dict(config_dict)
from_json(json_file)
from_pretrained(pretrained_model_name_or_path)
save_pretrained(save_directory)
to_dict()
to_json(json_file)
update(config_dict)
Attributes
attribute_map
model_type
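The method names above suggest a dictionary-based round trip: to_dict() exports the hyperparameters, from_dict() rebuilds a config from them, and update() overrides selected fields. The following is a hypothetical, dependency-free stand-in (ToyWaveNetConfig) that mimics that pattern with a dataclass and the json module. It copies the defaults from the signature above but is not the tfts implementation. Rejecting unknown keys in update() is a choice made by this sketch, not documented behavior.

```python
import json
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, List, Optional


@dataclass
class ToyWaveNetConfig:
    """Hypothetical stand-in for WaveNetConfig; NOT the tfts implementation."""

    dilation_rates: Optional[List[int]] = None
    kernel_sizes: Optional[List[int]] = None
    filters: int = 128
    dense_hidden_size: int = 64
    scheduled_sampling: float = 1.0
    use_attention: bool = False
    attention_size: int = 64
    num_attention_heads: int = 2
    attention_probs_dropout_prob: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        # Plain dict of all fields; asdict() deep-copies the lists.
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "ToyWaveNetConfig":
        # Unknown keys raise TypeError from the generated __init__.
        return cls(**config_dict)

    def update(self, config_dict: Dict[str, Any]) -> None:
        # Assumption of this sketch: only existing fields may be overridden.
        known = {f.name for f in fields(self)}
        for key, value in config_dict.items():
            if key not in known:
                raise KeyError(f"unknown config field: {key}")
            setattr(self, key, value)


config = ToyWaveNetConfig(dilation_rates=[1, 2, 4, 8], kernel_sizes=[2, 2, 2, 2])
config.update({"use_attention": True})

# Serialize to JSON text, then rebuild an equal config from it.
restored = ToyWaveNetConfig.from_dict(json.loads(json.dumps(config.to_dict())))
print(restored == config)  # True
```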