RNNConfig
- class tfts.models.rnn.RNNConfig(rnn_hidden_size: int = 64, rnn_type: Literal['gru', 'lstm'] = 'gru', bi_direction: bool = False, dense_hidden_size: int = 128, num_stacked_layers: int = 1, scheduled_sampling: float = 0.0, use_attention: bool = False)
Bases: BaseConfig
Initializes the configuration for the RNN model with the specified parameters.
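A minimal sketch of the configuration as a plain Python dataclass, for illustration only. `RNNConfigSketch` is a hypothetical stand-in, not part of tfts: it copies the defaults from the signature above, adds validation checks that are assumptions (the real `RNNConfig` may not enforce them), and derives the RNN output width under the assumption that a bidirectional layer concatenates its forward and backward outputs, as Keras `Bidirectional` does with its default `merge_mode="concat"`.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class RNNConfigSketch:
    """Hypothetical stand-in mirroring the documented RNNConfig defaults (not tfts code)."""

    rnn_hidden_size: int = 64
    rnn_type: Literal["gru", "lstm"] = "gru"
    bi_direction: bool = False
    dense_hidden_size: int = 128
    num_stacked_layers: int = 1
    scheduled_sampling: float = 0.0
    use_attention: bool = False

    def __post_init__(self) -> None:
        # Assumed constraints; tfts itself may or may not check these.
        if self.rnn_type not in ("gru", "lstm"):
            raise ValueError(f"rnn_type must be 'gru' or 'lstm', got {self.rnn_type!r}")
        if self.num_stacked_layers < 1:
            raise ValueError("num_stacked_layers must be >= 1")
        if not 0.0 <= self.scheduled_sampling <= 1.0:
            raise ValueError("scheduled_sampling is a ratio and must lie in [0, 1]")

    @property
    def rnn_output_size(self) -> int:
        # Assumes forward and backward outputs are concatenated when bidirectional,
        # as in Keras Bidirectional(merge_mode="concat").
        return self.rnn_hidden_size * (2 if self.bi_direction else 1)


config = RNNConfigSketch(rnn_type="lstm", bi_direction=True, num_stacked_layers=2)
print(config.rnn_output_size)  # 128
```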
- Parameters:
rnn_hidden_size – The number of units in the RNN hidden layer.
rnn_type – Type of RNN (‘gru’ or ‘lstm’).
bi_direction – Whether to use a bidirectional RNN.
dense_hidden_size – The size of the dense hidden layer following the RNN.
num_stacked_layers – The number of stacked RNN layers.
scheduled_sampling – Scheduled sampling ratio.
use_attention – Whether to use an attention mechanism.
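The `scheduled_sampling` parameter refers to scheduled sampling, a training technique for autoregressive decoders in which the decoder is sometimes fed its own previous prediction instead of the ground-truth value. The framework-free sketch below assumes the ratio is the probability of feeding the model's own prediction at each step; tfts may define or schedule the ratio differently. `decode_with_scheduled_sampling` and the toy `step` function are hypothetical.

```python
import random
from typing import Callable, List


def decode_with_scheduled_sampling(
    step_fn: Callable[[float], float],
    first_input: float,
    targets: List[float],
    sampling_ratio: float,
    rng: random.Random,
) -> List[float]:
    """Autoregressive decoding loop illustrating scheduled sampling.

    Assumption: with probability `sampling_ratio` the next input is the model's
    own previous prediction; otherwise it is the ground-truth target (teacher forcing).
    """
    predictions: List[float] = []
    next_input = first_input
    for target in targets:
        pred = step_fn(next_input)
        predictions.append(pred)
        # A coin flip decides what feeds the next decoding step.
        use_prediction = rng.random() < sampling_ratio
        next_input = pred if use_prediction else target
    return predictions


def step(x: float) -> float:
    # Toy "decoder step": predicts the next value as input + 1.
    return x + 1.0


targets = [10.0, 20.0, 30.0]
print(decode_with_scheduled_sampling(step, 0.0, targets, 0.0, random.Random(0)))  # [1.0, 11.0, 21.0]
print(decode_with_scheduled_sampling(step, 0.0, targets, 1.0, random.Random(0)))  # [1.0, 2.0, 3.0]
```

With a ratio of 0.0 every step after the first is driven by the ground truth (pure teacher forcing); with 1.0 the decoder always consumes its own outputs, which matches how it runs at inference time.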
Methods
from_dict(config_dict)
from_json(json_file)
from_pretrained(pretrained_model_name_or_path)
save_pretrained(save_directory)
to_dict()
to_json(json_file)
update(config_dict)
Attributes
attribute_map
model_type
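The listed methods suggest a dict/JSON round trip. The sketch below imitates `from_dict`, `to_dict`, `update`, `to_json`, and `from_json` on a hypothetical stand-in class so it runs without tfts installed. The real implementations may differ in details (for example, whether `update` accepts unknown keys or how `to_json` formats the file), and `save_pretrained`/`from_pretrained` are not modeled here.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict


@dataclass
class RNNConfigSketch:
    """Hypothetical stand-in imitating the serialization methods listed for RNNConfig."""

    rnn_hidden_size: int = 64
    rnn_type: str = "gru"
    bi_direction: bool = False
    dense_hidden_size: int = 128
    num_stacked_layers: int = 1
    scheduled_sampling: float = 0.0
    use_attention: bool = False

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "RNNConfigSketch":
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    def update(self, config_dict: Dict[str, Any]) -> None:
        # Assumption: only known fields may be overwritten; unknown keys raise.
        known = {f.name for f in fields(self)}
        for key, value in config_dict.items():
            if key not in known:
                raise KeyError(f"unknown config key: {key}")
            setattr(self, key, value)

    def to_json(self, json_file: str) -> None:
        with open(json_file, "w", encoding="utf-8") as fh:
            json.dump(self.to_dict(), fh, indent=2)

    @classmethod
    def from_json(cls, json_file: str) -> "RNNConfigSketch":
        with open(json_file, encoding="utf-8") as fh:
            return cls.from_dict(json.load(fh))


config = RNNConfigSketch()
config.update({"rnn_type": "lstm", "use_attention": True})
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    config.to_json(path)
    restored = RNNConfigSketch.from_json(path)
print(restored == config)  # True
```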