Source code for tfts.models.rwkv

"""
`RWKV: Reinventing RNNs for the Transformer Era
<https://arxiv.org/abs/2305.13048>`_
"""

from typing import Dict, Optional

import tensorflow as tf
from tensorflow.keras.layers import GRU, Dense

from tfts.layers.embed_layer import DataEmbedding
from tfts.layers.rwkv_layer import ChannelMixing, TimeMixing

from .base import BaseConfig, BaseModel


[docs] class RWKVConfig(BaseConfig): model_type: str = "rwkv" def __init__( self, num_layers: int = 25, hidden_size: int = 64, dense_hidden_size: int = 32, dropout: float = 0.0, max_position_embeddings: int = 512, initializer_range: float = 0.02, layer_norm_eps: float = 1e-12, pad_token_id: int = 0, **kwargs ) -> None: """ Initializes the configuration for the RWKV model with the specified parameters. Args: num_layers: The number of stacked RWKV layers. hidden_size: Size of each attention head. dense_hidden_size: The size of the dense hidden layer following the RWKV. dropout: Dropout rate for regularization. max_position_embeddings: Maximum sequence length for positional embeddings. initializer_range: Standard deviation for weight initialization. layer_norm_eps: Epsilon for layer normalization. pad_token_id: ID for padding token. """ super().__init__() self.num_layers: int = num_layers self.hidden_size: int = hidden_size self.dense_hidden_size: int = dense_hidden_size self.dropout: float = dropout self.max_position_embeddings: int = max_position_embeddings self.initializer_range: float = initializer_range self.layer_norm_eps: float = layer_norm_eps self.pad_token_id: int = pad_token_id self.update(kwargs)
[docs] class RWKV(BaseModel): """TensorFlow RWKV model for time series forecasting""" def __init__(self, predict_sequence_length: int = 1, config: Optional[RWKVConfig] = None): super().__init__() self.config = config or RWKVConfig() self.predict_sequence_length = predict_sequence_length # Embedding layer self.embedding = DataEmbedding(self.config.hidden_size, positional_type="positional encoding") self.ln0 = tf.keras.layers.LayerNormalization(epsilon=self.config.layer_norm_eps) self.blocks = [RWKVBlock(self.config) for _ in range(self.config.num_layers)] self.ln_out = tf.keras.layers.LayerNormalization(epsilon=self.config.layer_norm_eps) self.output_projection = Dense(1) def init_state(self, batch_size=1): states = [] for _ in range(self.config.num_layers): # States for attention att_states = [ tf.zeros((batch_size, self.config.hidden_size)), # aa tf.zeros((batch_size, self.config.hidden_size)), # bb tf.zeros((batch_size, self.config.hidden_size)) - 1e30, # pp ] # State for FFN ffn_state = tf.zeros((batch_size, self.config.hidden_size)) states.append((att_states, ffn_state)) return states def __call__( self, x, states=None, teacher=None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): """RWKV model call for time series forecasting""" if states is None: states = self.init_state() # Prepare inputs x, encoder_feature, decoder_feature = self._prepare_3d_inputs(x, ignore_decoder_inputs=False) # Embed inputs x = self.embedding(encoder_feature) x = self.ln0(x) new_states = [] for i, block in enumerate(self.blocks): x, block_states = block(x, states[i]) new_states.append(block_states) x = self.ln_out(x) x = self.output_projection(x) # Slice the output to only include the last predict_sequence_length steps x = x[:, -self.predict_sequence_length :, :] return x, new_states def _prepare_inputs(self, inputs): """Prepare the inputs for the encoder.""" if isinstance(inputs, (list, tuple)): x, encoder_feature, _ = inputs encoder_feature = tf.concat([x, encoder_feature], axis=-1) elif isinstance(inputs, dict): x = inputs["x"] encoder_feature = inputs["encoder_feature"] encoder_feature = tf.concat([x, encoder_feature], axis=-1) else: x = inputs encoder_feature = x return x, encoder_feature def _prepare_3d_inputs(self, x, ignore_decoder_inputs=True): """Prepare the inputs for the encoder and decoder.""" if isinstance(x, (list, tuple)): encoder_feature = x[1] if len(x) > 1 else x[0] decoder_feature = x[2] if len(x) > 2 else None elif isinstance(x, dict): encoder_feature = x["encoder_feature"] decoder_feature = x["decoder_feature"] if "decoder_feature" in x else None else: encoder_feature = x decoder_feature = None if ignore_decoder_inputs: return encoder_feature, encoder_feature, None else: return encoder_feature, encoder_feature, decoder_feature
[docs] class RWKVBlock(tf.keras.layers.Layer): """TensorFlow RWKV block""" def __init__(self, config, **kwargs): super().__init__(**kwargs) self.config = config self.ln1 = tf.keras.layers.LayerNormalization() self.attention = TimeMixing(config) self.ln2 = tf.keras.layers.LayerNormalization() self.feed_forward = ChannelMixing(config) def build(self, input_shape): super().build(input_shape) # Ensure the attention and feed_forward layers are built self.attention.build(input_shape) self.feed_forward.build(input_shape)
[docs] def call(self, x, states): """block Parameters ---------- x : tf.Tensor The input tensor of shape (batch_size, seq_length, embed_dim). """ att_states, ffn_state = states # Attention h, new_att_states = self.attention(self.ln1(x), att_states) x = x + h # Feed-forward h, new_ffn_state = self.feed_forward(self.ln2(x), ffn_state) x = x + h return x, (new_att_states, new_ffn_state)