Source code for tfts.layers.rwkv_layer

from typing import Dict, Optional, Tuple

import tensorflow as tf


class TimeMixing(tf.keras.layers.Layer):
    """TensorFlow RWKV time mixing.

    Owns five trainable per-channel weights of shape ``(1, hidden_size)``
    and four bias-free dense projections (key, value, receptance, output),
    all sized by ``config.hidden_size``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.n_embd = config.hidden_size

    def build(self, input_shape: Tuple[Optional[int], ...]):
        """Create the per-channel weights and the dense projections."""
        super().build(input_shape)

        # Trainable parameters: every one is a zero-initialised row vector
        # that broadcasts over the leading dimensions of the input.
        weight_spec = {"shape": (1, self.n_embd), "initializer": "zeros"}
        self.time_mix_k = self.add_weight(name="time_mix_k", **weight_spec)
        self.time_mix_v = self.add_weight(name="time_mix_v", **weight_spec)
        self.time_mix_r = self.add_weight(name="time_mix_r", **weight_spec)
        self.time_first = self.add_weight(name="time_first", **weight_spec)
        self.time_decay = self.add_weight(name="time_decay", **weight_spec)

        def projection():
            return tf.keras.layers.Dense(self.n_embd, use_bias=False)

        self.key = projection()
        self.value = projection()
        self.receptance = projection()
        self.output_layer = projection()

    @staticmethod
    def _mix(current, previous, ratio):
        """Interpolate between ``current`` and ``previous`` with weight ``ratio``."""
        return current * ratio + previous * (1 - ratio)

    @staticmethod
    def _blend(num_acc, den_acc, v, log_w_acc, log_w_new):
        """Combine the accumulators with a new value under a shared max shift.

        Both exponents are shifted by their element-wise maximum before
        ``exp`` so that neither exponential can overflow.

        Returns the blended numerator, the blended denominator and the shift.
        """
        shift = tf.maximum(log_w_acc, log_w_new)
        scale_acc = tf.exp(log_w_acc - shift)
        scale_new = tf.exp(log_w_new - shift)
        return scale_acc * num_acc + scale_new * v, scale_acc * den_acc + scale_new, shift

    def call(self, x, state):
        """time mixing

        Parameters
        ----------
        x : tf.Tensor
            The input tensor, documented upstream as having shape
            (batch_size, seq_length, embed_dim); its last dimension must
            broadcast against the ``(1, hidden_size)`` weights.
        state : sequence of three tf.Tensor
            Unpacked as ``(aa, bb, pp)``: the numerator accumulator, the
            denominator accumulator and the running exponent shift.

        Returns
        -------
        tuple
            ``(output, new_state)`` where ``output`` is the output projection
            of ``receptance * wkv`` and ``new_state`` is the list
            ``[new_aa, new_bb, new_pp]``.
        """
        aa, bb, pp = state
        # NOTE(review): the mixing term reads the first state entry, i.e. the
        # same entry unpacked as `aa` above -- confirm this is intended rather
        # than a separately stored previous-step input.
        previous = state[0]

        # Mix with previous timestep
        key_in = self._mix(x, previous, self.time_mix_k)
        value_in = self._mix(x, previous, self.time_mix_v)
        receptance_in = self._mix(x, previous, self.time_mix_r)

        r = tf.sigmoid(self.receptance(receptance_in))
        k = self.key(key_in)
        v = self.value(value_in)

        # Output for this step: the current key gets the `time_first` bonus.
        numerator, denominator, _ = self._blend(aa, bb, v, pp, self.time_first + k)
        wkv = numerator / denominator

        # Update states: the carried exponent is offset by `time_decay`.
        new_aa, new_bb, new_pp = self._blend(aa, bb, v, pp + self.time_decay, k)
        new_state = [new_aa, new_bb, new_pp]

        return self.output_layer(r * wkv), new_state
class ChannelMixing(tf.keras.layers.Layer):
    """TensorFlow RWKV channel mixing.

    Owns two trainable per-channel mixing weights of shape
    ``(1, hidden_size)`` and three bias-free dense projections
    (key, value, receptance), all sized by ``config.hidden_size``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.n_embd = config.hidden_size

    def build(self, input_shape: Tuple[Optional[int], ...]):
        """Create the per-channel mixing weights and the dense projections."""
        super().build(input_shape)

        # Zero-initialised row vectors that broadcast over the leading
        # dimensions of the input.
        weight_spec = {"shape": (1, self.n_embd), "initializer": "zeros"}
        self.time_mix_k = self.add_weight(name="time_mix_k", **weight_spec)
        self.time_mix_r = self.add_weight(name="time_mix_r", **weight_spec)

        def projection():
            return tf.keras.layers.Dense(self.n_embd, use_bias=False)

        self.key = projection()
        self.value = projection()
        self.receptance = projection()

    @staticmethod
    def _mix(current, previous, ratio):
        """Interpolate between ``current`` and ``previous`` with weight ``ratio``."""
        return current * ratio + previous * (1 - ratio)

    def call(self, x, state):
        """channel mixing

        Parameters
        ----------
        x : tf.Tensor
            The input tensor, documented upstream as having shape
            (batch_size, seq_length, embed_dim); its last dimension must
            broadcast against the ``(1, hidden_size)`` weights.
        state : tf.Tensor
            Tensor that is mixed with ``x`` before the key and receptance
            projections; it must broadcast against ``x``.

        Returns
        -------
        tuple
            ``(output, x)``: the gated value projection, and the unmodified
            input, which is handed back as the next state.
        """
        key_in = self._mix(x, state, self.time_mix_k)
        receptance_in = self._mix(x, state, self.time_mix_r)

        gate = tf.sigmoid(self.receptance(receptance_in))
        # Squared ReLU of the key projection feeds the value projection.
        activated = tf.square(tf.nn.relu(self.key(key_in)))

        return gate * self.value(activated), x