RNNConfig

class tfts.models.rnn.RNNConfig(rnn_hidden_size: int = 64, rnn_type: Literal['gru', 'lstm'] = 'gru', bi_direction: bool = False, dense_hidden_size: int = 128, num_stacked_layers: int = 1, scheduled_sampling: float = 0.0, use_attention: bool = False)

Bases: BaseConfig

Initializes the configuration for the RNN model with the specified parameters.

Parameters:
  • rnn_hidden_size – Number of units in the RNN hidden layer (default 64).

  • rnn_type – Type of recurrent cell: 'gru' or 'lstm' (default 'gru').

  • bi_direction – Whether to use a bidirectional RNN (default False).

  • dense_hidden_size – Size of the dense hidden layer that follows the RNN (default 128).

  • num_stacked_layers – Number of stacked RNN layers (default 1).

  • scheduled_sampling – Scheduled sampling ratio (default 0.0).

  • use_attention – Whether to use an attention mechanism (default False).
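To make the parameter list concrete, here is a minimal stand-in, not the tfts class itself: a plain dataclass named `RNNConfigSketch` (hypothetical) that mirrors the documented constructor signature and defaults. The checks in `__post_init__` are assumptions for illustration, in particular the assumption that `scheduled_sampling` is a probability in [0, 1]; the real `RNNConfig` may validate differently or not at all.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class RNNConfigSketch:
    """Hypothetical stand-in mirroring the documented RNNConfig signature."""

    rnn_hidden_size: int = 64
    rnn_type: Literal["gru", "lstm"] = "gru"
    bi_direction: bool = False
    dense_hidden_size: int = 128
    num_stacked_layers: int = 1
    scheduled_sampling: float = 0.0
    use_attention: bool = False

    def __post_init__(self) -> None:
        # Assumed checks for illustration only; not taken from tfts.
        # Literal[...] is a static type hint, so it must be checked by hand at runtime.
        if self.rnn_type not in ("gru", "lstm"):
            raise ValueError(f"rnn_type must be 'gru' or 'lstm', got {self.rnn_type!r}")
        if self.rnn_hidden_size <= 0 or self.dense_hidden_size <= 0:
            raise ValueError("hidden sizes must be positive")
        if self.num_stacked_layers < 1:
            raise ValueError("num_stacked_layers must be at least 1")
        # Assumption: the scheduled sampling ratio is treated as a probability.
        if not 0.0 <= self.scheduled_sampling <= 1.0:
            raise ValueError("scheduled_sampling must lie in [0, 1]")


# All defaults, matching the documented signature.
default = RNNConfigSketch()

# Override only what differs, e.g. a two-layer bidirectional LSTM.
custom = RNNConfigSketch(rnn_type="lstm", bi_direction=True, num_stacked_layers=2)
```

Because `Literal['gru', 'lstm']` is not enforced by Python at runtime, passing an unsupported cell type to a plain dataclass would otherwise be accepted silently; the explicit check in the sketch turns that into an immediate `ValueError`.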

Methods

from_dict(config_dict)

from_json(json_file)

from_pretrained(pretrained_model_name_or_path)

save_pretrained(save_directory)

to_dict()

to_json(json_file)

update(config_dict)
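The method names above suggest a dictionary and JSON round trip. The sketch below illustrates that pattern with the standard-library `json` module. The class `ConfigSketch` and every method body are hypothetical: they only mirror the listed names (`from_pretrained`, `save_pretrained`, `attribute_map`, and `model_type` are not modeled), and details such as whether `json_file` is a path or whether `update` rejects unknown keys are assumptions to check against the tfts source.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict


@dataclass
class ConfigSketch:
    """Hypothetical config whose methods are named like RNNConfig's."""

    rnn_hidden_size: int = 64
    rnn_type: str = "gru"
    bi_direction: bool = False
    dense_hidden_size: int = 128
    num_stacked_layers: int = 1
    scheduled_sampling: float = 0.0
    use_attention: bool = False

    def to_dict(self) -> Dict[str, Any]:
        # asdict builds a fresh dict, so mutating it does not change the config.
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "ConfigSketch":
        return cls(**config_dict)

    def update(self, config_dict: Dict[str, Any]) -> None:
        # Assumption: keys that are not config fields are rejected, not added.
        field_names = {f.name for f in fields(self)}
        for key, value in config_dict.items():
            if key not in field_names:
                raise KeyError(f"unknown config key: {key!r}")
            setattr(self, key, value)

    def to_json(self, json_file: str) -> None:
        # Assumption: json_file is a filesystem path.
        with open(json_file, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_json(cls, json_file: str) -> "ConfigSketch":
        with open(json_file, encoding="utf-8") as f:
            return cls.from_dict(json.load(f))


config = ConfigSketch(rnn_type="lstm")
config.update({"use_attention": True})

# Write the config to a temporary JSON file and read it back.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    config.to_json(path)
    restored = ConfigSketch.from_json(path)
```

Keeping every hyperparameter as a plain JSON-serializable field is what makes this round trip lossless: `restored` compares equal to `config` field by field.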

Attributes

attribute_map

model_type