Source code for tfts.models.tft

"""
`Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting
<https://arxiv.org/abs/1912.09363>`_
"""

from typing import Optional

import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, LayerNormalization

from ..layers.attention_layer import Attention, SelfAttention
from ..layers.dense_layer import FeedForwardNetwork
from ..layers.embed_layer import DataEmbedding
from .base import BaseConfig, BaseModel


class TFTransformerConfig(BaseConfig):
    """Hyper-parameter container for the ``"tft"`` model type.

    Every named argument is stored unchanged as an attribute of the same
    name. Any additional keyword arguments are handed to ``self.update``
    (an inherited method, presumably defined on ``BaseConfig``).

    Parameters
    ----------
    hidden_size : int, optional
        Width used by ``TFTransformer`` for its embeddings, dense
        selection/gate layers, LSTM units and attention, by default 256.
    num_layers : int, optional
        Stored on the config; not read by the model code in this file,
        by default 2.
    num_attention_heads : int, optional
        Forwarded by ``TFTransformer`` to its ``Attention`` layer, by default 4.
    attention_probs_dropout_prob : float, optional
        Forwarded by ``TFTransformer`` to its ``Attention`` layer, by default 0.0.
    hidden_dropout_prob : float, optional
        Third argument ``TFTransformer`` passes to ``FeedForwardNetwork``,
        by default 0.0.
    ffn_intermediate_size : int, optional
        Second argument ``TFTransformer`` passes to ``FeedForwardNetwork``,
        by default 256.
    max_position_embeddings : int, optional
        Stored on the config; not read by the model code in this file,
        by default 512.
    initializer_range : float, optional
        Stored on the config; not read by the model code in this file,
        by default 0.02.
    layer_norm_eps : float, optional
        Stored on the config; not read by the model code in this file,
        by default 1e-12.
    pad_token_id : int, optional
        Stored on the config; not read by the model code in this file,
        by default 0.
    **kwargs
        Extra settings passed as a dict to ``self.update``.
    """

    model_type: str = "tft"

    def __init__(
        self,
        hidden_size: int = 256,
        num_layers: int = 2,
        num_attention_heads: int = 4,
        attention_probs_dropout_prob: float = 0.0,
        hidden_dropout_prob: float = 0.0,
        ffn_intermediate_size: int = 256,
        max_position_embeddings: int = 512,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        **kwargs
    ) -> None:
        super().__init__()

        # Network dimensions.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads

        # Regularisation and feed-forward sizing.
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob
        self.ffn_intermediate_size = ffn_intermediate_size

        # Remaining settings kept on the config object.
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id

        # Applied last, after all named attributes have been assigned.
        self.update(kwargs)
class TFTransformer(BaseModel):
    """Temporal fusion transformer model.

    A simplified TFT pipeline: two feature streams are embedded, passed
    through sigmoid dense "variable selection" layers and feed-forward
    networks, encoded by LSTMs, fused with attention plus a sigmoid gate,
    and projected to a single output channel.
    """

    def __init__(self, predict_sequence_length=1, config: Optional[TFTransformerConfig] = None):
        """Create the layers used by the model.

        Parameters
        ----------
        predict_sequence_length : int, optional
            Number of trailing time steps kept in the model output,
            by default 1. Must be at least 1.
        config : TFTransformerConfig, optional
            Hyper-parameters. A default ``TFTransformerConfig()`` is used
            when omitted.

        Raises
        ------
        ValueError
            If ``predict_sequence_length`` is an integer smaller than 1.
        """
        # The output is sliced with ``[-predict_sequence_length:]``. A value of 0
        # would turn that into ``[-0:]`` (the whole sequence) and a negative value
        # would slice from the front, so both are rejected up front instead of
        # silently producing an output of the wrong length.
        if isinstance(predict_sequence_length, int) and predict_sequence_length < 1:
            raise ValueError(
                f"predict_sequence_length must be a positive integer, got {predict_sequence_length}"
            )

        super().__init__()
        self.config = config or TFTransformerConfig()
        self.predict_sequence_length = predict_sequence_length

        # Embedding layers for temporal and static features
        self.temporal_embedding = DataEmbedding(self.config.hidden_size, positional_type="positional encoding")
        self.static_embedding = DataEmbedding(self.config.hidden_size)

        # Variable selection networks (simplified as dense layers with gating)
        self.temporal_variable_selection = Dense(self.config.hidden_size, activation="sigmoid")
        self.static_variable_selection = Dense(self.config.hidden_size, activation="sigmoid")

        # Gated Residual Networks (GRN) for feature processing
        self.temporal_grn = FeedForwardNetwork(
            self.config.hidden_size, self.config.ffn_intermediate_size, self.config.hidden_dropout_prob
        )
        self.static_grn = FeedForwardNetwork(
            self.config.hidden_size, self.config.ffn_intermediate_size, self.config.hidden_dropout_prob
        )

        # Static covariate encoder (using LSTM)
        self.static_encoder = tf.keras.layers.LSTM(self.config.hidden_size, return_sequences=True)

        # Temporal fusion decoder (combining LSTM, attention, and gating)
        self.temporal_decoder = tf.keras.layers.LSTM(self.config.hidden_size, return_sequences=True)
        self.attention = Attention(
            hidden_size=self.config.hidden_size,
            num_attention_heads=self.config.num_attention_heads,
            attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
        )
        self.gate = Dense(self.config.hidden_size, activation="sigmoid")

        # Output projection
        self.output_projection = Dense(1)

    def __call__(self, x: tf.Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None):
        """Process inputs through the TFT model.

        Parameters
        ----------
        x : tf.Tensor
            Input tensor of shape (batch_size, sequence_length, features).
            It is split into encoder and decoder features by the inherited
            ``_prepare_3d_inputs`` helper.
        output_hidden_states : bool, optional
            Accepted for interface compatibility; not read by this
            implementation, by default None.
        return_dict : bool, optional
            Accepted for interface compatibility; not read by this
            implementation, by default None.

        Returns
        -------
        tf.Tensor
            The last ``predict_sequence_length`` steps (axis 1) of the
            projected sequence, with a final dimension of 1.
        """
        # Prepare inputs; the first returned value is not needed here.
        _, encoder_feature, decoder_feature = self._prepare_3d_inputs(x, ignore_decoder_inputs=False)

        # Embed temporal (encoder) and static (decoder) features
        temporal_embedded = self.temporal_embedding(encoder_feature)
        static_embedded = self.static_embedding(decoder_feature)

        # Apply variable selection
        temporal_selected = self.temporal_variable_selection(temporal_embedded)
        static_selected = self.static_variable_selection(static_embedded)

        # Process through Gated Residual Networks
        temporal_processed = self.temporal_grn(temporal_selected)
        static_processed = self.static_grn(static_selected)

        # Encode static covariates
        static_encoded = self.static_encoder(static_processed)

        # Decode temporal features
        temporal_decoded = self.temporal_decoder(temporal_processed)

        # Attention: the temporal stream is the first argument, the static
        # stream is passed as both the second and third arguments.
        attention_output = self.attention(temporal_decoded, static_encoded, static_encoded)

        # Element-wise sigmoid gate computed from the attention output itself
        gate_output = self.gate(attention_output)
        fused_output = gate_output * attention_output

        # Project to a single output channel
        output = self.output_projection(fused_output)

        # Keep only the last predict_sequence_length steps
        return output[:, -self.predict_sequence_length :, :]