# Source code for tfts.models.deep_ar
"""
`DeepAR: Probabilistic Forecasting with Autoregressive Recurrent Networks
<https://arxiv.org/abs/1704.04110>`_
"""
from typing import Optional
import tensorflow as tf
from tensorflow.keras.layers import Activation, BatchNormalization, Dense
from tfts.tasks.auto_task import GaussianHead
from .base import BaseConfig, BaseModel
# [docs]
class DeepARConfig(BaseConfig):
    """Configuration container for the :class:`DeepAR` model.

    Parameters
    ----------
    rnn_hidden_size : int, optional
        Number of units given to the GRU cell built by ``DeepAR``, by default 64.
    """

    # Identifier for this model family (class attribute shared by all instances).
    model_type: str = "deep_ar"

    def __init__(self, rnn_hidden_size: int = 64):
        # Let the base configuration initialise its own fields first.
        super(DeepARConfig, self).__init__()
        self.rnn_hidden_size = rnn_hidden_size
# [docs]
class DeepAR(BaseModel):
    """DeepAR network: a GRU recurrent layer, a dense projection and a Gaussian head.

    Parameters
    ----------
    predict_sequence_length : int, optional
        Number of units of the dense layer applied to the RNN outputs, by default 1.
    config : DeepARConfig, optional
        Model configuration. A default ``DeepARConfig()`` is used when ``None``
        (or any other falsy value) is passed.
    """

    def __init__(self, predict_sequence_length: int = 1, config: Optional[DeepARConfig] = None) -> None:
        super().__init__()
        self.config = config or DeepARConfig()
        self.predict_sequence_length = predict_sequence_length

        cell = tf.keras.layers.GRUCell(units=self.config.rnn_hidden_size)
        # return_sequences=True yields one output per time step; return_state=True makes
        # the layer additionally return the final cell state, which __call__ discards.
        self.rnn = tf.keras.layers.RNN(cell, return_state=True, return_sequences=True)
        # NOTE(review): ``self.bn`` is created but never used in ``__call__``. It is kept
        # so that any external code referencing ``model.bn`` keeps working.
        self.bn = BatchNormalization()
        self.dense = Dense(units=predict_sequence_length, activation="relu")
        self.gauss = GaussianHead(units=1)

    def __call__(
        self, inputs: tf.Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
    ):
        """Run the DeepAR forward pass.

        Parameters
        ----------
        inputs : tf.Tensor
            3D input tensor for the time series
            (presumably ``(batch, time, features)`` -- TODO confirm).
        output_hidden_states : bool, optional
            Accepted for interface compatibility; currently ignored.
        return_dict : bool, optional
            Accepted for interface compatibility; currently ignored.

        Returns
        -------
        tuple of tf.Tensor
            ``(loc, scale)`` exactly as returned by ``GaussianHead``; presumably the
            parameters of the predicted Gaussian distribution -- semantics are defined
            by ``GaussianHead``.
        """
        # The RNN returns [sequence_outputs, final_state]; only the per-step outputs are used.
        x, _ = self.rnn(inputs)
        x = self.dense(x)
        loc, scale = self.gauss(x)
        return loc, scale