Source code for tfts.layers.moe_layer

"""Layer for :py:class:`~tfts.models.transformer`"""

from typing import Any, Callable, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import activations, constraints, initializers, regularizers
from tensorflow.keras.layers import Dense


class MoELayer(tf.keras.layers.Layer):
    """Mixture of Experts layer for time series prediction.

    Several dense "expert" networks process the same input in parallel. A
    dense gating network produces one weight per expert, the expert outputs
    are combined as a weighted sum, and the result is projected back to the
    input feature dimension.

    Args:
        num_experts: Number of expert networks; must be at least 1.
        expert_hidden_size: Number of units in each expert's dense layer.
        gating_hidden_size: Stored and serialized for API compatibility.
            NOTE(review): it is not used to build any sub-layer in this
            class -- the gating network is a single dense layer with
            ``num_experts`` units.
        expert_activation: Activation applied by every expert.
        gating_activation: Activation applied by the gating network.
        kernel_initializer: Kernel initializer identifier shared by all
            dense sub-layers.
        kernel_regularizer: Optional kernel regularizer identifier shared by
            all dense sub-layers.
        kernel_constraint: Optional kernel constraint identifier shared by
            all dense sub-layers.
        use_bias: Whether the dense sub-layers use a bias term.
        bias_initializer: Bias initializer identifier shared by all dense
            sub-layers.
        trainable: Forwarded to the base ``Layer`` constructor.
        name: Forwarded to the base ``Layer`` constructor.
        **kwargs: Additional base ``Layer`` arguments (for example
            ``dtype``). Accepting them is what lets
            ``MoELayer.from_config(layer.get_config())`` work, because the
            base config contains keys beyond ``name`` and ``trainable``.

    Raises:
        ValueError: If ``num_experts`` is smaller than 1.
    """

    def __init__(
        self,
        num_experts: int,
        expert_hidden_size: int,
        gating_hidden_size: int,
        expert_activation: str = "relu",
        gating_activation: str = "softmax",
        kernel_initializer: str = "glorot_uniform",
        kernel_regularizer: Optional[str] = None,
        kernel_constraint: Optional[str] = None,
        use_bias: bool = True,
        bias_initializer: str = "zeros",
        trainable: bool = True,
        name: Optional[str] = None,
        **kwargs: Any,
    ):
        super().__init__(trainable=trainable, name=name, **kwargs)
        # Fail early: with no experts, ``call`` would try to stack an empty list.
        if num_experts < 1:
            raise ValueError(f"num_experts must be at least 1, got {num_experts}")
        self.num_experts = num_experts
        self.expert_hidden_size = expert_hidden_size
        self.gating_hidden_size = gating_hidden_size
        self.expert_activation = expert_activation
        self.gating_activation = gating_activation
        self.kernel_initializer = kernel_initializer
        self.kernel_regularizer = kernel_regularizer
        self.kernel_constraint = kernel_constraint
        self.use_bias = use_bias
        self.bias_initializer = bias_initializer

    def _make_dense(self, units: int, name: str, activation: Optional[str] = None) -> Dense:
        """Create a Dense sub-layer configured with this layer's shared kernel/bias options."""
        return Dense(
            units,
            activation=activation,
            kernel_initializer=self.kernel_initializer,
            # Resolve identifiers per sub-layer so each one gets its own object.
            kernel_regularizer=regularizers.get(self.kernel_regularizer),
            kernel_constraint=constraints.get(self.kernel_constraint),
            use_bias=self.use_bias,
            bias_initializer=self.bias_initializer,
            name=name,
        )

    def build(self, input_shape: Tuple[int, ...]):
        """Create the expert, gating and output-projection sub-layers.

        Args:
            input_shape: Shape of the layer input; its last entry is used as
                the number of units of the output projection.
        """
        input_dim = int(input_shape[-1])

        # One dense expert per index, all sharing the same hyper-parameters.
        self.experts = [
            self._make_dense(self.expert_hidden_size, f"expert_{i}", activation=self.expert_activation)
            for i in range(self.num_experts)
        ]

        # The gating network emits one weight per expert.
        self.gating_network = self._make_dense(
            self.num_experts, "gating_network", activation=self.gating_activation
        )

        # Linear projection (no activation) back to the input feature size.
        self.output_projection = self._make_dense(input_dim, "output_projection")

        super().build(input_shape)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """Forward pass of the MoE layer.

        Args:
            inputs: Tensor of shape (batch_size, sequence_length, input_dim)

        Returns:
            output: Tensor of shape (batch_size, sequence_length, input_dim)
        """
        # Each expert output: (batch_size, seq_length, expert_hidden_size);
        # stacked: (batch_size, seq_length, num_experts, expert_hidden_size).
        expert_outputs = tf.stack([expert(inputs) for expert in self.experts], axis=2)

        # Gating weights: (batch_size, seq_length, num_experts), then a trailing
        # axis is added so they broadcast over the expert feature dimension.
        gating_weights = tf.expand_dims(self.gating_network(inputs), axis=-1)

        # Weighted sum over the expert axis:
        # (batch_size, seq_length, expert_hidden_size).
        combined_output = tf.reduce_sum(expert_outputs * gating_weights, axis=2)

        # Project back to the input dimension: (batch_size, seq_length, input_dim).
        return self.output_projection(combined_output)

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments merged on top of the base layer config."""
        config = {
            "num_experts": self.num_experts,
            "expert_hidden_size": self.expert_hidden_size,
            "gating_hidden_size": self.gating_hidden_size,
            "expert_activation": self.expert_activation,
            "gating_activation": self.gating_activation,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "kernel_constraint": self.kernel_constraint,
            "use_bias": self.use_bias,
            "bias_initializer": self.bias_initializer,
        }
        base_config = super().get_config()
        # Layer-specific entries take precedence over base entries on key clashes.
        return {**base_config, **config}