WaveNetConfig

class tfts.models.wavenet.WaveNetConfig(dilation_rates: List[int] = None, kernel_sizes: List[int] = None, filters: int = 128, dense_hidden_size: int = 64, scheduled_sampling: float = 1.0, use_attention: bool = False, attention_size: int = 64, num_attention_heads: int = 2, attention_probs_dropout_prob: float = 0.0, **kwargs)

Bases: BaseConfig

Initializes the configuration for the WaveNet model with the specified parameters.

Parameters:
  • dilation_rates – List of dilation rates for the convolutional layers.

  • kernel_sizes – List of kernel sizes for the convolutional layers.

  • filters – The number of filters in the convolutional layers.

  • dense_hidden_size – The size of the dense hidden layer following the convolutional layers.

  • scheduled_sampling – Scheduled sampling ratio. 0 means teacher forcing (the ground-truth value is fed back at each decoding step); 1 means the model's last prediction is fed back.

  • use_attention – Whether to use attention mechanism in the model.

  • attention_size – The hidden size of the attention layer.

  • num_attention_heads – The number of attention heads.

  • attention_probs_dropout_prob – Dropout probability for attention probabilities.
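The sketch below illustrates two of the parameters above; neither function is part of tfts. `receptive_field` assumes the model applies one dilated causal 1-D convolution per (`dilation_rates[i]`, `kernel_sizes[i]`) pair, in sequence, so each layer widens the receptive field by `(kernel_size - 1) * dilation`. `decode` is a hypothetical helper showing one common reading of `scheduled_sampling`: a per-step probability of feeding back the last prediction instead of the ground truth. The tfts decoder may apply the ratio differently (for example, per batch rather than per step).

```python
import random
from typing import Callable, List, Sequence


def receptive_field(dilation_rates: Sequence[int], kernel_sizes: Sequence[int]) -> int:
    """Receptive field (in time steps) of stacked dilated causal 1-D convolutions.

    Assumption: one convolution per (dilation, kernel) pair, applied in sequence.
    Each layer adds (kernel_size - 1) * dilation steps of history.
    """
    if len(dilation_rates) != len(kernel_sizes):
        raise ValueError("dilation_rates and kernel_sizes must have the same length")
    return 1 + sum((k - 1) * d for d, k in zip(dilation_rates, kernel_sizes))


def decode(
    step_fn: Callable[[float], float],
    first_input: float,
    targets: List[float],
    scheduled_sampling: float,
    rng: random.Random,
) -> List[float]:
    """Hypothetical autoregressive loop with scheduled sampling.

    With probability `scheduled_sampling`, the next input is the model's own
    previous prediction; otherwise it is the ground-truth target (teacher forcing).
    0.0 -> always teacher forcing, 1.0 -> always the last prediction.
    """
    predictions = []
    next_input = first_input
    for target in targets:
        prediction = step_fn(next_input)
        predictions.append(prediction)
        # rng.random() is in [0, 1), so 0.0 never picks the prediction and 1.0 always does.
        use_prediction = rng.random() < scheduled_sampling
        next_input = prediction if use_prediction else target
    return predictions


# Illustrative values, not necessarily the tfts defaults.
print(receptive_field([1, 2, 4, 8], [2, 2, 2, 2]))  # 16

toy_step = lambda x: x + 1.0  # stand-in for a one-step model
print(decode(toy_step, 0.0, [10.0, 20.0, 30.0], 1.0, random.Random(0)))  # [1.0, 2.0, 3.0]
print(decode(toy_step, 0.0, [10.0, 20.0, 30.0], 0.0, random.Random(0)))  # [1.0, 11.0, 21.0]
```

With doubling dilations and kernel size 2, each additional layer (with the next doubled dilation) doubles the receptive field, which is why dilated stacks can cover long histories with few layers.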

Methods

from_dict(config_dict)

from_json(json_file)

from_pretrained(pretrained_model_name_or_path)

save_pretrained(save_directory)

to_dict()

to_json(json_file)

update(config_dict)

Attributes

attribute_map

model_type