"""
`WaveNet: A Generative Model for Raw Audio
<https://arxiv.org/abs/1609.03499>`_
"""
import logging
from typing import Dict, List, Optional
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Concatenate, Dense, Lambda, ReLU
from tfts.layers.cnn_layer import ConvTemp
from tfts.layers.dense_layer import DenseTemp
from .base import BaseConfig, BaseModel
# Module-level logger; DecoderV1 uses it to warn when a dilation rate exceeds the available context length.
logger = logging.getLogger(__name__)
class WaveNetConfig(BaseConfig):
    """Configuration for the WaveNet time-series model."""

    model_type: str = "wavenet"

    def __init__(
        self,
        dilation_rates: Optional[List[int]] = None,
        kernel_sizes: Optional[List[int]] = None,
        filters: int = 128,
        dense_hidden_size: int = 64,
        scheduled_sampling: float = 1.0,
        use_attention: bool = False,
        attention_size: int = 64,
        num_attention_heads: int = 2,
        attention_probs_dropout_prob: float = 0.0,
        **kwargs,
    ) -> None:
        """
        Initializes the configuration for the WaveNet model with the specified parameters.

        Args:
            dilation_rates: List of dilation rates for the convolutional layers.
                Defaults to ``[1, 2, 4, 8]`` when None or empty.
            kernel_sizes: List of kernel sizes for the convolutional layers, one per dilation rate.
                Defaults to ``2`` for every entry of ``dilation_rates`` when None or empty.
            filters: The number of filters in the convolutional layers.
            dense_hidden_size: The size of the dense hidden layer following the convolutional layers.
            scheduled_sampling: Scheduled sampling ratio. 0 means teacher forcing, 1 means use last prediction.
            use_attention: Whether to use attention mechanism in the model.
            attention_size: The size of the attention mechanism.
            num_attention_heads: The number of attention heads.
            attention_probs_dropout_prob: Dropout probability for attention probabilities.
            **kwargs: Accepted for call compatibility; not used by this constructor.
        """
        super(WaveNetConfig, self).__init__()
        self.dilation_rates: List[int] = dilation_rates or [2**i for i in range(4)]
        # Default to one kernel size per dilation rate. A fixed-length default would silently
        # truncate the encoder (it zips the two lists) while the decoder still iterates over every
        # dilation rate and indexes one encoder output per rate.
        self.kernel_sizes: List[int] = kernel_sizes or [2] * len(self.dilation_rates)
        self.filters: int = filters
        self.dense_hidden_size: int = dense_hidden_size
        self.scheduled_sampling: float = scheduled_sampling
        self.use_attention: bool = use_attention
        self.attention_size: int = attention_size
        self.num_attention_heads: int = num_attention_heads
        self.attention_probs_dropout_prob: float = attention_probs_dropout_prob
class WaveNet(BaseModel):
    """WaveNet model for time series.

    An ``Encoder`` (gated, dilated causal convolutions) processes the history and a ``DecoderV1``
    emits ``predict_sequence_length`` forecast steps.
    """

    def __init__(self, predict_sequence_length: int = 1, config: Optional[WaveNetConfig] = None) -> None:
        """
        Initializes the WaveNet model.

        Args:
            predict_sequence_length: Length of the prediction sequence.
            config: Configuration object containing model parameters; a default
                ``WaveNetConfig()`` is used when None.
        """
        super(WaveNet, self).__init__()
        self.config = config or WaveNetConfig()
        self.predict_sequence_length = predict_sequence_length
        self.encoder = Encoder(
            kernel_sizes=self.config.kernel_sizes,
            dilation_rates=self.config.dilation_rates,
            filters=self.config.filters,
            dense_hidden_size=self.config.dense_hidden_size,
        )
        self.decoder = DecoderV1(
            filters=self.config.filters,
            dilation_rates=self.config.dilation_rates,
            dense_hidden_size=self.config.dense_hidden_size,
            predict_sequence_length=self.predict_sequence_length,
        )

    def __call__(
        self,
        inputs: tf.Tensor,
        teacher: Optional[tf.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass for the WaveNet model.

        Args:
            inputs: Input tensor (or structure) for the model; unpacked by ``_prepare_3d_inputs``.
            teacher: Teacher tensor used for scheduled sampling.
            output_hidden_states: Currently unused.
            return_dict: Currently unused.

        Returns:
            The decoder output tensor.
        """
        x, encoder_feature, decoder_feature = self._prepare_3d_inputs(inputs, ignore_decoder_inputs=False)
        # The encoder's first output (its own one-channel projection) is not used here; the decoder
        # consumes the per-layer inputs.
        _, encoder_outputs = self.encoder(encoder_feature)
        decoder_outputs = self.decoder(
            decoder_features=decoder_feature,
            # Assumes channel 0 of x is the target: its last observed value seeds the decoder — TODO confirm.
            decoder_init_input=x[:, -1, 0:1],
            teacher=teacher,
            encoder_outputs=encoder_outputs,
            # Forward the configured ratio; without it the decoder's own default is always used and
            # WaveNetConfig.scheduled_sampling has no effect.
            scheduled_sampling=self.config.scheduled_sampling,
        )
        return decoder_outputs
class Encoder(tf.keras.layers.Layer):
    """Encoder block for the WaveNet model.

    A tanh dense projection feeds a stack of gated, dilated causal convolutions joined by residual
    connections; the ReLU-activated concatenation of all skip outputs is then projected to one unit.
    """

    def __init__(
        self, kernel_sizes: List[int], filters: int, dilation_rates: List[int], dense_hidden_size: int, **kwargs
    ) -> None:
        """
        Build the encoder sub-layers.

        Args:
            kernel_sizes: Kernel size of each dilated convolution.
            filters: Width of the residual stream; each convolution emits ``2 * filters`` channels
                that are split into a filter half and a gate half.
            dilation_rates: Dilation rate of each convolution. Paired with ``kernel_sizes`` through
                ``zip``, so the shorter of the two lists determines the number of layers.
            dense_hidden_size: Hidden size of the dense layer applied to the skip outputs.
        """
        super(Encoder, self).__init__(**kwargs)
        self.filters = filters
        self.conv_times = [
            ConvTemp(filters=2 * filters, kernel_size=size, causal=True, dilation_rate=rate)
            for size, rate in zip(kernel_sizes, dilation_rates)
        ]
        self.dense_time1 = DenseTemp(hidden_size=filters, activation="tanh", name="encoder_dense_time1")
        # Shared by every layer: projects the gated activation into a skip half and a residual half.
        self.dense_time2 = DenseTemp(hidden_size=filters + filters, name="encoder_dense_time2")
        self.dense_time3 = DenseTemp(hidden_size=dense_hidden_size, activation="relu", name="encoder_dense_time3")
        self.dense_time4 = DenseTemp(hidden_size=1, name="encoder_dense_time_4")

    def call(self, x: tf.Tensor):
        """
        Run the dilated convolution stack over the encoder input.

        Args:
            x: Encoder input; the splits below act on axis 2, so a 3-D (batch, time, channels)
                layout is presumably expected — TODO confirm.

        Returns:
            Tuple ``(y_hat, layer_inputs)``: ``y_hat`` is the one-unit projection of the
            ReLU-activated, concatenated skip outputs, and ``layer_inputs`` holds the input each
            dilated layer received (one entry per convolution).
        """
        hidden = self.dense_time1(inputs=x)
        layer_inputs = [hidden]
        skips_per_layer = []

        # Stateless wrappers that are identical for every layer, so they are created once up front.
        halve = Lambda(lambda tensor: tf.split(tensor, 2, axis=2))
        gate = Lambda(lambda pair: tf.nn.tanh(pair[0]) * tf.nn.sigmoid(pair[1]))
        skip_residual_split = Lambda(lambda tensor: tf.split(tensor, [self.filters, self.filters], axis=2))

        for conv in self.conv_times:
            filter_half, gate_half = halve(conv(hidden))
            projected = self.dense_time2(inputs=gate([filter_half, gate_half]))
            skip, residual = skip_residual_split(projected)
            # Residual connection: the updated stream is the next layer's input.
            hidden = hidden + residual
            layer_inputs.append(hidden)
            skips_per_layer.append(skip)

        concatenate = Concatenate(axis=2)
        merged_skips = concatenate(skips_per_layer)
        relu = ReLU()
        activated_skips = relu(merged_skips)
        y_hat = self.dense_time4(self.dense_time3(activated_skips))
        # Drop the last entry: it is the output of the final layer, which no layer takes as input.
        return y_hat, layer_inputs[:-1]
class DecoderV1(tf.keras.layers.Layer):
    """Decoder block for WaveNet V1.

    Produces ``predict_sequence_length`` one-step forecasts in a Python loop. For every step and
    every dilation rate, a kernel-size-2 dilated causal convolution is emulated by two dense taps:
    ``dense2`` over the stored layer state ``dilation`` steps back and ``dense3`` over the current
    input. Each step's prediction is fed back as the next step's input.
    """

    def __init__(
        self,
        filters: int,
        dilation_rates: List[int],
        dense_hidden_size: int,
        predict_sequence_length: int = 24,
        **kwargs,
    ) -> None:
        """
        Initializes the decoder block.

        Args:
            filters: Number of filters for convolutional layers.
            dilation_rates: Dilation rates for convolutions; one stored encoder sequence is expected per rate.
            dense_hidden_size: Size of the dense hidden layer.
            predict_sequence_length: Length of the predicted sequence.
        """
        super().__init__(**kwargs)
        self.filters: int = filters
        self.predict_sequence_length = predict_sequence_length
        self.dilation_rates = dilation_rates
        self.dense_hidden_size = dense_hidden_size

    def build(self, input_shape, **kwargs):
        """Create and build the dense projections used by ``call``.

        Args:
            input_shape: Expected to be the shape of ``decoder_features``. Its last dimension plus one
                (the fed-back previous value, assumed to be a single column) sizes ``dense1``.
        """
        batch_size = input_shape[0]
        decoder_input_size = input_shape[-1] + 1
        self.dense1 = Dense(self.filters, activation="tanh")
        self.dense1.build([batch_size, decoder_input_size])
        # dense2 / dense3: the two taps (past state, current input) of the emulated convolution.
        self.dense2 = Dense(2 * self.filters, use_bias=True)
        self.dense2.build([batch_size, self.filters])
        self.dense3 = Dense(2 * self.filters, use_bias=False)
        self.dense3.build([batch_size, self.filters])
        # dense4: projects the gated activation into skip and residual halves.
        self.dense4 = Dense(2 * self.filters)
        self.dense4.build([batch_size, self.filters])
        # dense5 consumes the concatenated skip outputs of all layers.
        total_skips = self.filters * len(self.dilation_rates)
        self.dense5 = Dense(self.dense_hidden_size, activation="relu")
        self.dense5.build([batch_size, total_skips])
        self.dense6 = Dense(1)
        self.dense6.build([batch_size, self.dense_hidden_size])
        super().build(input_shape)

    def call(
        self,
        decoder_features,
        decoder_init_input,
        encoder_outputs,
        teacher: Optional[tf.Tensor] = None,
        scheduled_sampling: float = 0.0,
        training: Optional[bool] = None,
        **kwargs: Dict,
    ):
        """
        Forward pass for the decoder block.

        Args:
            decoder_features: Known future features indexed as ``decoder_features[:, step]``, or None.
            decoder_init_input: Input for the first step. Later steps feed back the previous step's
                prediction, unless a teacher value is chosen.
            encoder_outputs: Sequence of stored per-layer states, one per dilation rate. It is copied;
                the caller's list is not modified.
            teacher: Optional tensor for teacher forcing, read as ``teacher[:, step:step + 1]``.
            scheduled_sampling: Threshold for scheduled sampling during training. At each step a uniform
                draw is made, and the teacher value is used only when the draw exceeds this threshold,
                i.e. with probability ``1 - scheduled_sampling``. 0 means teacher forcing; 1 means
                always feeding back the previous prediction.
            training: Whether the model is in training mode; scheduled sampling applies only when truthy.

        Returns:
            Decoder output tensor of shape ``(batch, predict_sequence_length, 1)``.
        """
        # Copy so that appending the decoded steps below does not grow the caller's list.
        layer_states = list(encoder_outputs)

        # Stateless wrappers, created once per call rather than once per step and layer.
        split_halves = Lambda(lambda t: tf.split(t, 2, axis=1))
        gated_activation = Lambda(lambda t: tf.nn.tanh(t[0]) * tf.nn.sigmoid(t[1]))
        split_skip_residual = Lambda(lambda t: tf.split(t, [self.filters, self.filters], axis=1))
        add_time_axis = Lambda(lambda t: tf.expand_dims(t, axis=1))

        decoder_outputs = []
        prev_output = decoder_init_input  # the initial input for decoder
        for step in range(self.predict_sequence_length):
            this_input = prev_output
            if training:
                # NOTE(review): the draw happens in Python via NumPy; inside a traced tf.function it is
                # presumably fixed at trace time rather than redrawn per batch — confirm.
                p = np.random.uniform(low=0, high=1, size=1)[0]
                if teacher is not None and p > scheduled_sampling:
                    # NOTE(review): this reads teacher[:, step]. Unless the caller passes targets shifted
                    # by one step, the value being predicted is fed as input — confirm against caller.
                    this_input = teacher[:, step : step + 1]
            if decoder_features is not None:
                this_input = Concatenate(axis=-1)([this_input, decoder_features[:, step]])
            x = self.dense1(this_input)

            skip_outputs = []
            for layer_idx, dilation in enumerate(self.dilation_rates):
                context_length = layer_states[layer_idx].shape[1]
                # Clamp so that -dilation stays a valid index; skip when the static length is unknown.
                if context_length is not None and dilation > context_length:
                    logger.warning(
                        "Dilation %s exceeds context length %s. Using %s instead.",
                        dilation,
                        context_length,
                        context_length,
                    )
                    dilation = context_length
                state = layer_states[layer_idx][:, -dilation, :]
                # use 2 dense layer to calculate a kernel=2 convolution
                dilated_conv = self.dense2(state) + self.dense3(x)
                conv_filter, conv_gate = split_halves(dilated_conv)
                dilated_conv = gated_activation([conv_filter, conv_gate])
                skips, residuals = split_skip_residual(self.dense4(dilated_conv))
                x += residuals
                # NOTE(review): x already includes this layer's residual, i.e. it is the next layer's
                # input, whereas Encoder stores each layer's own input — confirm this is intended.
                layer_states[layer_idx] = Concatenate(1)([layer_states[layer_idx], add_time_axis(x)])
                skip_outputs.append(skips)

            concatenated = Concatenate(axis=1)(skip_outputs)
            activated_skips = ReLU()(concatenated)
            hidden = self.dense5(activated_skips)
            this_output = self.dense6(hidden)
            decoder_outputs.append(this_output)
            # Feed this step's prediction back as the next step's input.
            prev_output = this_output

        decoder_outputs = Concatenate(1)(decoder_outputs)
        expand = Lambda(lambda t: tf.expand_dims(t, axis=-1))
        return expand(decoder_outputs)

    def get_config(self):
        """Return the layer config, including the constructor arguments needed to recreate it."""
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "dilation_rates": self.dilation_rates,
                "dense_hidden_size": self.dense_hidden_size,
                "predict_sequence_length": self.predict_sequence_length,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        """Return ``(batch, predict_sequence_length, 1)``, taking batch from ``input_shape[0]``."""
        batch_size = input_shape[0]
        return (batch_size, self.predict_sequence_length, 1)
class DecoderV2(tf.keras.layers.Layer):
    """WaveNet decoder that avoids leaking future data.

    With a teacher, step ``t > 0`` is fed ``teacher[:, t - 1, :]`` (presumably the previous step's
    target) instead of the value being predicted. Otherwise the previous prediction is fed back.
    """

    def __init__(
        self,
        filters: int,
        dilation_rates: List[int],
        dense_hidden_size: int,
        predict_sequence_length: int = 24,
        **kwargs,
    ):
        """
        Initializes the decoder block v2.

        Args:
            filters: Number of filters for the emulated convolutions.
            dilation_rates: Dilation rates; one stored encoder state sequence is expected per rate.
            dense_hidden_size: Size of the dense hidden layer applied to the skip outputs.
            predict_sequence_length: Number of steps to decode.
        """
        super().__init__(**kwargs)
        self.filters = filters
        self.dilation_rates = dilation_rates
        self.predict_sequence_length = predict_sequence_length
        self.dense_hidden_size = dense_hidden_size

    def build(self, input_shape):
        """Create the dense sub-layers; they build lazily on first use."""
        super().build(input_shape)
        self.dense_1 = Dense(self.filters, activation="tanh", name="decoder_dense_1")
        self.dense_2 = Dense(2 * self.filters, name="decoder_dense_2")
        self.dense_3 = Dense(2 * self.filters, use_bias=False, name="decoder_dense_3")
        self.dense_4 = Dense(2 * self.filters, name="decoder_dense_4")
        self.dense_5 = Dense(self.dense_hidden_size, activation="relu", name="decoder_dense_5")
        self.dense_6 = Dense(1, name="decoder_dense_6")

    def call(
        self,
        decoder_features: tf.Tensor,
        decoder_init_input: tf.Tensor,
        encoder_states: tf.Tensor,
        teacher: Optional[tf.Tensor] = None,
    ):
        """
        Forward pass for the decoder block v2.

        NOTE(review): ``body`` branches on ``time == 0`` with a Python ``if`` and appends to a Python
        list. This presumably only behaves as intended when ``tf.while_loop`` runs eagerly — confirm
        before using this inside a ``tf.function``.

        Args:
            decoder_features: Tensor containing decoder features indexed as ``[:, time, :]``, or None.
            decoder_init_input: Initial input for the decoder.
            encoder_states: List of encoder outputs, one per dilation rate. It is copied; the caller's
                list is not modified.
            teacher: Optional tensor for teacher forcing, read as ``teacher[:, time - 1, :]``.

        Returns:
            Decoder output tensor with the time axis second: the stacked per-step outputs are
            transposed with ``[1, 0, 2]``.
        """
        # Work on a copy: the loop body appends one step per layer, which must not grow the caller's list.
        encoder_states = list(encoder_states)

        def cond_fn(time, prev_output, decoder_output_ta):
            return time < self.predict_sequence_length

        def body(time, prev_output, decoder_output_ta):
            # Step 0 (or no teacher) uses the previous output; later steps use the teacher's
            # value from the previous step.
            if time == 0 or teacher is None:
                current_input = prev_output
            else:
                current_input = teacher[:, time - 1, :]
            if decoder_features is not None:
                current_feature = decoder_features[:, time, :]
                current_input = tf.concat([current_input, current_feature], axis=1)
            inputs = self.dense_1(current_input)
            skip_outputs = []
            for i, dilation in enumerate(self.dilation_rates):
                state = encoder_states[i][:, -dilation, :]
                # Kernel-size-2 dilated convolution emulated by two dense taps: past state + current input.
                dilated_conv = self.dense_2(state) + self.dense_3(inputs)
                conv_filter, conv_gate = tf.split(dilated_conv, 2, axis=1)
                dilated_conv = tf.nn.tanh(conv_filter) * tf.nn.sigmoid(conv_gate)
                outputs = self.dense_4(dilated_conv)
                skips, residuals = tf.split(outputs, [self.filters, self.filters], axis=1)
                inputs += residuals
                encoder_states[i] = tf.concat([encoder_states[i], tf.expand_dims(inputs, 1)], axis=1)
                skip_outputs.append(skips)
            skip_outputs = tf.nn.relu(tf.concat(skip_outputs, axis=1))
            h = self.dense_5(skip_outputs)
            y_hat = self.dense_6(h)
            decoder_output_ta = decoder_output_ta.write(time, y_hat)
            # The prediction becomes prev_output for the next iteration.
            return time + 1, y_hat, decoder_output_ta

        loop_init = [
            tf.constant(0, dtype=tf.int32),
            decoder_init_input,
            tf.TensorArray(dtype=tf.float32, size=self.predict_sequence_length),
        ]
        _, _, decoder_outputs_ta = tf.while_loop(cond=cond_fn, body=body, loop_vars=loop_init)
        decoder_outputs = decoder_outputs_ta.stack()
        # TensorArray stacks time-major; move the batch axis back to the front.
        decoder_outputs = tf.transpose(decoder_outputs, [1, 0, 2])
        return decoder_outputs