Source code for tfts.layers.mask_layer
"""Layer for :py:class:`~tfts.models.transformer`"""
import tensorflow as tf
from tensorflow.keras import activations, constraints, initializers, regularizers
[docs]
class CausalMask(tf.keras.layers.Layer):
    """Causal mask for the transformer decoder's first (self-attention) block.

    Given ``inputs`` whose first two dimensions are treated as
    ``(batch, seq_length)``, returns a float32 tensor of shape
    ``[batch, seq_length, seq_length]``. Entry ``[b, i, j]`` is ``1.0`` when
    ``j > i`` (key position after the query position) and ``0.0`` otherwise.

    Args:
        num_attention_heads: Stored on the layer and serialized by
            ``get_config``; it is not used when building the mask.
    """

    def __init__(self, num_attention_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_attention_heads = num_attention_heads

    def call(self, inputs):
        """Build the causal mask for ``inputs``.

        Args:
            inputs: Tensor; only its dynamic shape is read (dims 0 and 1).

        Returns:
            float32 tensor ``[batch, seq_length, seq_length]`` with ``1.0``
            strictly above the diagonal and ``0.0`` on and below it.
        """
        batch_size = tf.shape(inputs)[0]
        seq_length = tf.shape(inputs)[1]
        # NOTE(review): for a multi-head split, callers presumably expand this
        # to [B, 1, L, L] — confirm against the attention layer.
        mask_shape = [batch_size, seq_length, seq_length]
        # band_part(x, -1, 0) keeps the lower triangle *including* the diagonal,
        # so subtracting it from 1 leaves exactly the strictly-upper triangle
        # (equivalent to "upper incl. diagonal minus diagonal").
        lower_incl_diag = tf.linalg.band_part(tf.ones(mask_shape, dtype=tf.float32), -1, 0)
        return 1.0 - lower_incl_diag

    def get_config(self):
        """Return the base Keras config extended with ``num_attention_heads``."""
        base_config = super().get_config()
        return {**base_config, "num_attention_heads": self.num_attention_heads}

    def compute_output_shape(self, input_shape):
        """Return ``(batch, seq_length, seq_length)`` from the input's first two dims."""
        batch_size, seq_length = input_shape[0], input_shape[1]
        return (batch_size, seq_length, seq_length)
[docs]
class ProbMask:
    """Boolean attention mask for Informer's ProbSparse attention.

    Builds a ``[L, L_K]`` mask that is set strictly above the diagonal (key
    position after the query position) and broadcasts it to
    ``[B, H, L, L_K]``. It then selects the rows addressed by ``index`` with
    ``tf.gather_nd`` and reshapes the result to the shape of ``scores``. The
    result is exposed as :attr:`mask` with dtype ``tf.bool``.
    """

    def __init__(self, B, H, L, index, scores):
        """
        Args:
            B: batch size.
            H: number of attention heads.
            L: query sequence length (rows of the un-reduced mask).
            index: indices passed to ``tf.gather_nd`` on the
                ``[B, H, L, L_K]`` mask. NOTE(review): presumably shape
                ``[B, H, u, 3]`` holding (batch, head, query) triples for the
                selected top-u queries — confirm against the caller.
            scores: attention scores. Its last dimension is taken as L_K, and
                its full shape is the shape of the resulting mask.
        """
        # Prefer the static key length (original behavior); fall back to the
        # dynamic one when it is unknown, since tf.ones([L, None]) would fail.
        key_length = scores.shape[-1]
        if key_length is None:
            key_length = tf.shape(scores)[-1]
        mask = tf.ones([L, key_length], tf.float32)
        # 1 - (lower triangle incl. diagonal) -> 1.0 strictly above the diagonal.
        mask = 1 - tf.linalg.band_part(mask, -1, 0)
        mask_expanded = tf.broadcast_to(mask, [B, H, L, key_length])
        # Keep only the mask rows for the queries selected by ``index`` (reduced Q).
        mask_Q = tf.gather_nd(mask_expanded, index)
        self._mask = tf.cast(tf.reshape(mask_Q, tf.shape(scores)), tf.bool)

    @property
    def mask(self):
        """Boolean mask with the same shape as ``scores``; True strictly above the diagonal."""
        return self._mask