"""Layer for :py:class:`~tfts.models.transformer` :py:class:`~tfts.models.autoformer`"""
from typing import Any, Dict, Optional, Tuple
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout
from tfts.layers.mask_layer import ProbMask
class Attention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention layer.

    Query, key and value are each projected to ``hidden_size`` units (no bias), split into
    ``num_attention_heads`` heads of ``hidden_size // num_attention_heads`` units, attended per
    head, and concatenated back to ``hidden_size`` units.
    """

    def __init__(
        self, hidden_size: int, num_attention_heads: int = 1, attention_probs_dropout_prob: float = 0.0, **kwargs
    ) -> None:
        """Initialize the Attention layer.

        Parameters
        ----------
        hidden_size : int
            The number of hidden units, hidden_size = attention_dim_each_head x num_attention_heads.
        num_attention_heads : int
            The number of attention heads.
        attention_probs_dropout_prob : float, optional
            Dropout rate for the attention weights. Defaults to 0.0.

        Raises
        ------
        ValueError
            If ``hidden_size`` is not divisible by ``num_attention_heads``.
        """
        super(Attention, self).__init__(**kwargs)
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by the number of heads {num_attention_heads}."
            )
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

    def build(self, input_shape: Tuple[Optional[int], ...]) -> None:
        """Create the bias-free q/k/v projections and the attention-weight dropout."""
        self.dense_q = Dense(self.hidden_size, use_bias=False)
        self.dense_k = Dense(self.hidden_size, use_bias=False)
        self.dense_v = Dense(self.hidden_size, use_bias=False)
        self.dropout = Dropout(rate=self.attention_probs_dropout_prob)
        super(Attention, self).build(input_shape)

    def call(
        self,
        q: tf.Tensor,
        k: tf.Tensor,
        v: tf.Tensor,
        mask: Optional[tf.Tensor] = None,
        past_key_value=None,
        training: Optional[bool] = None,
        return_attention_scores: bool = False,
        use_causal_mask: bool = False,
        **kwargs,
    ):
        """Attend from the query to the key/value sequence with multiple heads.

        Parameters
        ----------
        q : tf.Tensor
            Query with shape batch * seq_q * fea
        k : tf.Tensor
            Key with shape batch * seq_k * fea
        v : tf.Tensor
            Value with shape batch * seq_v * fea
        mask : tf.Tensor, optional
            Attention mask of shape ``(batch, seq_q, seq_k)`` or ``(seq_q, seq_k)`` (broadcast over
            the batch). Non-zero / True entries mark key positions a query MAY attend to; zero /
            False entries are excluded from the softmax. By default None.
        past_key_value : optional
            Unused; accepted for interface compatibility.
        training : bool, optional
            Forwarded to the attention-weight dropout.
        return_attention_scores : bool, optional
            If True, also return the attention weights (after softmax and dropout), with shape
            ``(batch * heads, seq_q, seq_k)``.
        use_causal_mask : bool, optional
            If True, query position ``i`` may only attend to key positions ``<= i``. This is
            combined (logical AND) with ``mask`` when both are given.

        Returns
        -------
        tf.Tensor
            Tensor with shape batch * seq_q * hidden_size, plus the attention weights when
            ``return_attention_scores`` is True.
        """
        # project the query/key/value to num_attention_heads * units
        q = self.dense_q(q)
        k = self.dense_k(k)
        v = self.dense_v(v)

        # Fold the heads into the batch axis, head-major: row ``h * batch + b`` is head ``h`` of sample ``b``.
        q_ = tf.concat(tf.split(q, self.num_attention_heads, axis=2), axis=0)
        k_ = tf.concat(tf.split(k, self.num_attention_heads, axis=2), axis=0)
        v_ = tf.concat(tf.split(v, self.num_attention_heads, axis=2), axis=0)

        # => (batch * heads) * seq_q * seq_k, scaled by sqrt(units per head)
        score = tf.linalg.matmul(q_, k_, transpose_b=True)
        score = score / tf.cast(tf.shape(q_)[-1], score.dtype) ** 0.5

        keep = self._attention_keep_mask(score, mask, use_causal_mask)
        if keep is not None:
            # Excluded logits must be pushed towards -inf BEFORE the softmax; multiplying them by 0
            # would leave them at exp(0) and still leak information from masked positions.
            fill_value = tf.cast(self._large_negative(score.dtype), score.dtype)
            score = tf.where(keep, score, fill_value)

        score = tf.nn.softmax(score, axis=-1)
        score = self.dropout(score, training=training)

        # (batch * heads) * seq_q * units, then unfold the heads back onto the feature axis
        outputs = tf.linalg.matmul(score, v_)
        outputs = tf.concat(tf.split(outputs, self.num_attention_heads, axis=0), axis=2)
        if return_attention_scores:
            return outputs, score
        return outputs

    def _attention_keep_mask(
        self, score: tf.Tensor, mask: Optional[tf.Tensor], use_causal_mask: bool
    ) -> Optional[tf.Tensor]:
        """Return a boolean tensor broadcastable to ``score`` (True = may attend), or None if unmasked."""
        keep = None
        if mask is not None:
            keep = tf.cast(mask, tf.bool)
            if keep.shape.rank == 3:
                # Heads were folded head-major into the batch axis, so the per-sample mask is tiled
                # (not repeated element-wise) to line up with sample ``b`` of every head.
                keep = tf.tile(keep, [self.num_attention_heads, 1, 1])
        if use_causal_mask:
            seq_q = tf.shape(score)[-2]
            seq_k = tf.shape(score)[-1]
            # Lower-triangular: query i sees keys 0..i.
            causal = tf.linalg.band_part(tf.ones((seq_q, seq_k), dtype=tf.bool), -1, 0)
            keep = causal if keep is None else tf.logical_and(keep, causal)
        return keep

    @staticmethod
    def _large_negative(dtype: tf.DType) -> float:
        """Finite fill value for excluded logits that is still representable in ``dtype``."""
        return -1e9 if dtype in (tf.float32, tf.float64) else dtype.min

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt with ``from_config``."""
        config = {
            "hidden_size": self.hidden_size,
            "num_attention_heads": self.num_attention_heads,
            "attention_probs_dropout_prob": self.attention_probs_dropout_prob,
        }
        base_config = super(Attention, self).get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        """Output shape ``(batch, seq_q, hidden_size)``.

        ``input_shape`` is either a single 3-D shape (self-attention) or a list/tuple of the three
        q/k/v shapes (cross-attention); the output always follows the query.
        """

        def _is_shape(obj) -> bool:
            return isinstance(obj, (list, tuple, tf.TensorShape))

        if (
            isinstance(input_shape, (list, tuple))
            and len(input_shape) == 3
            and all(_is_shape(shape) for shape in input_shape)
        ):
            q_shape, k_shape, v_shape = input_shape
            # Validate that all shapes have 3 dimensions
            if not all(len(shape) == 3 for shape in (q_shape, k_shape, v_shape)):
                raise ValueError(
                    "Each input shape must be a tuple of length 3 (batch_size, seq_len, features). "
                    f"Got shapes: q={q_shape}, k={k_shape}, v={v_shape}"
                )
            # Output shape is based on query sequence length
            return (q_shape[0], q_shape[1], self.hidden_size)
        if _is_shape(input_shape) and len(input_shape) == 3:
            batch_size, seq_len, _ = input_shape
            return (batch_size, seq_len, self.hidden_size)
        raise ValueError(
            "Expected input_shape to be either:\n"
            "1. A single tuple (batch_size, seq_len, features) for self-attention, or\n"
            "2. A list/tuple of 3 shapes [(q_shape), (k_shape), (v_shape)] for cross-attention.\n"
            f"Got: {input_shape}"
        )
class SelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention: an ``Attention`` layer whose query, key and value are the same input."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int = 1,
        attention_probs_dropout_prob: float = 0.0,
        **kwargs: Dict[str, Any],
    ) -> None:
        """Initialize the SelfAttention layer.

        Parameters
        ----------
        hidden_size : int
            Number of hidden units of the wrapped attention (split across the heads).
        num_attention_heads : int, optional
            Number of attention heads, by default 1.
        attention_probs_dropout_prob : float, optional
            Dropout rate for the attention weights, by default 0.0.
        """
        super(SelfAttention, self).__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

    def build(self, input_shape: Tuple[Optional[int], ...]) -> None:
        """Create the wrapped multi-head ``Attention`` layer."""
        self.attention = Attention(
            self.hidden_size,
            self.num_attention_heads,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        )
        super(SelfAttention, self).build(input_shape)

    def call(self, x: tf.Tensor, mask: Optional[tf.Tensor] = None, training: Optional[bool] = None):
        """Self attention layer

        Parameters
        ----------
        x : tf.Tensor
            3D input tensor for self-attention, (batch_size, sequence_length, feature_size)
        mask : tf.Tensor, optional
            Attention mask forwarded unchanged to the wrapped ``Attention`` layer, by default None
        training : bool, optional
            Forwarded to the wrapped layer (controls attention-weight dropout).

        Returns
        -------
        tf.Tensor
            3D self attention output, (batch_size, sequence_length, attention_hidden_size)
        """
        return self.attention(q=x, k=x, v=x, mask=mask, training=training)

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt with ``from_config``."""
        # Without these keys ``from_config`` fails, since ``hidden_size`` is a required argument.
        config = {
            "hidden_size": self.hidden_size,
            "num_attention_heads": self.num_attention_heads,
            "attention_probs_dropout_prob": self.attention_probs_dropout_prob,
        }
        base_config = super(SelfAttention, self).get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        """Output shape ``(batch, sequence_length, hidden_size)``."""
        return (input_shape[0], input_shape[1], self.hidden_size)
class ProbAttention(tf.keras.layers.Layer):
    """ProbSparse attention, as used by Informer.

    Each query's "activity" is estimated as ``max(q K_s^T) - sum(q K_s^T) / L_K`` on a random subset
    ``K_s`` of the keys. Only the ``u = factor * ceil(ln L_Q)`` highest-scoring queries get full
    attention over all keys. Every other position keeps an initial context: the cumulative sum of
    the values along the sequence while ``mask_flag`` is True, and the mean of the values otherwise.
    """

    def __init__(
        self, hidden_size: int = 128, num_attention_heads: int = 1, attention_probs_dropout_prob: float = 0.0, **kwargs
    ):
        """Initialize the ProbAttention layer.

        Parameters
        ----------
        hidden_size : int, optional
            Number of hidden units after the q/k/v projections, by default 128.
        num_attention_heads : int, optional
            Number of attention heads; must divide ``hidden_size``. By default 1.
        attention_probs_dropout_prob : float, optional
            Stored for configuration only; this layer currently applies no dropout.

        Raises
        ------
        ValueError
            If ``hidden_size`` is not divisible by ``num_attention_heads``.
        """
        super().__init__(**kwargs)
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by the number of heads {num_attention_heads}."
            )
        self.mask_flag = True  # apply ProbMask and start from the cumulative-sum context
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.factor = 5  # sampling factor c in u = c * ceil(ln L)
        self.scale = None  # None -> 1 / sqrt(head_dim)

    def build(self, input_shape: Tuple[Optional[int], ...]) -> None:
        """Create the bias-free q/k/v projections."""
        self.dense_q = Dense(self.hidden_size, use_bias=False)
        self.dense_k = Dense(self.hidden_size, use_bias=False)
        self.dense_v = Dense(self.hidden_size, use_bias=False)
        super().build(input_shape)

    def _split_heads(self, x: tf.Tensor, length: int) -> tf.Tensor:
        """Reshape ``(batch, length, hidden_size)`` to ``(batch, heads, length, hidden_size // heads)``.

        The feature axis is split per time step and then transposed, so every head still sees every
        time step. Reshaping straight to ``(batch, heads, length, head_dim)`` would instead scatter
        the time steps of one sample across heads.
        """
        head_dim = self.hidden_size // self.num_attention_heads
        x = tf.reshape(x, (-1, length, self.num_attention_heads, head_dim))
        return tf.transpose(x, (0, 2, 1, 3))

    def _prob_qk(self, q, k, sample_k, top_n):
        """Select the ``top_n`` most active queries and score them against all keys.

        Parameters
        ----------
        q : tf.Tensor
            ``(batch, heads, L_Q, head_dim)``
        k : tf.Tensor
            ``(batch, heads, L_K, head_dim)``
        sample_k : int
            Number of randomly sampled key positions used to estimate the activity.
        top_n : int
            Number of queries to keep.

        Returns
        -------
        qk : tf.Tensor
            ``(batch, heads, top_n, L_K)`` raw (unscaled) scores of the selected queries.
        m_top : tf.Tensor
            ``(batch, heads, top_n)`` positions of the selected queries.
        """
        _, H, L_K, _ = k.shape
        B = tf.shape(k)[0]
        # The same random key positions are shared by every query. Gathering them straight from k is
        # equivalent to sampling a fully broadcast (B, H, L_Q, L_K, E) copy, without materializing it.
        indx_k_seq = tf.random.uniform((sample_k,), maxval=L_K, dtype=tf.int32)
        k_sample = tf.gather(k, indx_k_seq, axis=2)  # (B, H, sample_k, E)
        q_k_sample = tf.matmul(q, k_sample, transpose_b=True)  # (B, H, L_Q, sample_k)
        # Sparsity measurement: max score minus the sum of scores divided by L_K.
        M = tf.math.reduce_max(q_k_sample, axis=-1) - tf.math.reduce_sum(q_k_sample, axis=-1) / tf.cast(
            L_K, q_k_sample.dtype
        )
        m_top = tf.math.top_k(M, top_n, sorted=False)[1]
        batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, top_n))
        head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, top_n))
        idx = tf.stack([batch_indexes, head_indexes, m_top], axis=-1)
        q_reduce = tf.gather_nd(q, idx)  # (B, H, top_n, E)
        qk = tf.matmul(q_reduce, k, transpose_b=True)
        return qk, m_top

    def _get_initial_context(self, v, L_Q):
        """Context for the positions that do not receive full attention.

        Returns ``(batch, heads, L_Q, head_dim)``. While ``mask_flag`` is True this is the cumulative
        sum of ``v`` along the sequence (which requires ``L_Q == L_V``); otherwise it is the mean of
        ``v`` over the sequence, repeated for every query position.
        """
        _, H, L_V, D = v.shape
        if not self.mask_flag:
            B = tf.shape(v)[0]
            # Mean over the sequence, as in the reference Informer implementation.
            v_mean = tf.math.reduce_mean(v, axis=-2, keepdims=True)
            return tf.broadcast_to(v_mean, [B, H, L_Q, D])
        if L_Q != L_V:
            raise ValueError(
                f"ProbAttention with mask_flag=True requires equal query and value lengths, got {L_Q} and {L_V}."
            )
        return tf.math.cumsum(v, axis=-2)

    def _update_context(self, context_in, v, scores, index, L_Q):
        """Overwrite the context rows of the selected queries with their full-attention output."""
        _, H, _, _ = v.shape
        B = tf.shape(v)[0]
        top_n = tf.shape(index)[-1]
        batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, top_n))
        head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, top_n))
        index = tf.stack([batch_indexes, head_indexes, index], axis=-1)  # (B, H, top_n, 3)
        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores).mask
            scores = tf.where(attn_mask, -np.inf, scores)
        attn = tf.nn.softmax(scores, axis=-1)
        return tf.tensor_scatter_nd_update(context_in, index, tf.matmul(attn, v))

    def call(self, q, k, v, mask: Optional[tf.Tensor] = None):
        """Prob attention.

        Parameters
        ----------
        q, k, v : tf.Tensor
            Query / key / value of shape ``(batch, seq, features)``. The sequence lengths must be
            statically known, because the sample sizes are computed with NumPy.
        mask : tf.Tensor, optional
            Unused. While ``mask_flag`` is True the mask built by ``ProbMask`` is applied instead.

        Returns
        -------
        tf.Tensor
            ``(batch, seq_q, hidden_size)``
        """
        q = self.dense_q(q)  # project the query/key/value to num_attention_heads * units
        k = self.dense_k(k)
        v = self.dense_v(v)

        seq_q = q.shape[1]
        seq_k = k.shape[1]
        batch_size = tf.shape(q)[0]

        q_ = self._split_heads(q, seq_q)
        k_ = self._split_heads(k, seq_k)
        v_ = self._split_heads(v, v.shape[1])

        # u_k: number of sampled keys; u_q: number of queries that get full attention.
        u_q = min(self.factor * int(np.ceil(np.log(seq_q))), seq_q)
        u_k = min(self.factor * int(np.ceil(np.log(seq_k))), seq_k)

        scores_top, index = self._prob_qk(q_, k_, u_k, u_q)
        head_dim = self.hidden_size // self.num_attention_heads
        scale = self.scale if self.scale is not None else 1.0 / np.sqrt(head_dim)
        scores_top = scores_top * scale

        context = self._get_initial_context(v_, seq_q)
        context = self._update_context(context, v_, scores_top, index, seq_q)
        # Merge heads back: (batch, heads, seq_q, head_dim) -> (batch, seq_q, hidden_size)
        context = tf.transpose(context, (0, 2, 1, 3))
        return tf.reshape(context, (batch_size, seq_q, self.hidden_size))

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt with ``from_config``."""
        config = {
            "hidden_size": self.hidden_size,
            "num_attention_heads": self.num_attention_heads,
            "attention_probs_dropout_prob": self.attention_probs_dropout_prob,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        """Output shape ``(batch, seq_q, hidden_size)``; accepts the query shape or a [q, k, v] shape list."""
        first = input_shape[0]
        q_shape = first if isinstance(first, (list, tuple, tf.TensorShape)) else input_shape
        return (q_shape[0], q_shape[1], self.hidden_size)
class SparseAttention(tf.keras.layers.Layer):
    """Sparse attention layer (placeholder, not implemented yet).

    The layer can be constructed and serialized, but calling it raises ``NotImplementedError``
    instead of silently returning ``None``.
    """

    def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_dropout_prob: float = 0.0, **kwargs):
        """Store the configuration.

        Parameters
        ----------
        hidden_size : int
            Number of output features.
        num_attention_heads : int
            Number of attention heads.
        attention_probs_dropout_prob : float, optional
            Dropout rate for the attention weights, by default 0.0.
        """
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

    def build(self, input_shape: Tuple[Optional[int], ...]):
        super().build(input_shape)

    def call(self, x, mask: Optional[tf.Tensor] = None):
        """Sparse attention.

        Parameters
        ----------
        x : tf.Tensor
            Input tensor.
        mask : tf.Tensor, optional
            Attention mask, by default None.

        Raises
        ------
        NotImplementedError
            Always; the sparse attention computation has not been implemented.
        """
        raise NotImplementedError("SparseAttention is not implemented yet.")

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt with ``from_config``."""
        config = {
            "hidden_size": self.hidden_size,
            "num_attention_heads": self.num_attention_heads,
            "attention_probs_dropout_prob": self.attention_probs_dropout_prob,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        """Output shape ``(batch, sequence_length, hidden_size)``."""
        batch_size = input_shape[0]
        sequence_length = input_shape[1]
        return (batch_size, sequence_length, self.hidden_size)
class FastAttention(tf.keras.layers.Layer):
    """Fast attention layer (placeholder, not implemented yet).

    The layer can be constructed and serialized, but calling it raises ``NotImplementedError``
    instead of silently returning ``None``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def build(self, input_shape: Tuple[Optional[int], ...]) -> None:
        super().build(input_shape)

    def call(self, x: tf.Tensor, mask: Optional[tf.Tensor] = None):
        """Fast attention.

        Parameters
        ----------
        x : tf.Tensor
            Input tensor.
        mask : tf.Tensor, optional
            Attention mask, by default None.

        Raises
        ------
        NotImplementedError
            Always; the fast attention computation has not been implemented.
        """
        raise NotImplementedError("FastAttention is not implemented yet.")

    def get_config(self):
        """Return the base Keras layer config (this layer has no extra constructor arguments)."""
        return super().get_config()