Source code for tfts.layers.unet_layer

"""Layer for :py:class:`~tfts.models.unet`"""

from typing import Any, Dict, Optional, Tuple

import tensorflow as tf
from tensorflow.keras.layers import Activation, Add, BatchNormalization, Conv1D, Dense, GlobalAveragePooling1D, Multiply


class ConvbrLayer(tf.keras.layers.Layer):
    """Conv1D followed by BatchNormalization and a ReLU activation.

    Args:
        units: Number of filters of the convolution.
        kernel_size: Kernel size of the convolution.
        strides: Stride of the convolution.
        dilation: Forwarded to ``Conv1D`` as ``dilation_rate``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units: int, kernel_size: int, strides: int, dilation: int, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_size = kernel_size
        self.strides = strides
        self.dilation = dilation

    def build(self, input_shape: Tuple[Optional[int], ...]):
        """Create the convolution, batch-norm and activation sublayers."""
        conv_options = {
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "dilation_rate": self.dilation,
            "padding": "same",
        }
        self.conv1 = Conv1D(self.units, **conv_options)
        self.bn = BatchNormalization()
        self.relu = Activation("relu")
        super().build(input_shape)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Apply convolution, then batch normalization, then ReLU.

        Args:
            x: Input tensor of shape (batch, seq_len, features).

        Returns:
            The activated tensor.
        """
        return self.relu(self.bn(self.conv1(x)))

    def get_config(self) -> Dict[str, Any]:
        """Return the base layer config extended with the constructor arguments."""
        return {
            **super().get_config(),
            "units": self.units,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "dilation": self.dilation,
        }
class SeBlock(tf.keras.layers.Layer):
    """Squeeze-and-Excitation block for channel-wise attention.

    The input is average-pooled by ``GlobalAveragePooling1D``, passed through
    a two-layer bottleneck (``units // 8`` ReLU units, then ``units`` sigmoid
    units), and the resulting per-channel gate is multiplied back onto the
    input.

    Args:
        units: Width of the gate. It must be compatible with the channel
            dimension of the input for the final multiplication.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units: int, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape: Tuple[Optional[int], ...]):
        """Create the pooling, bottleneck and merge sublayers."""
        self.pool = GlobalAveragePooling1D()
        # `units // 8` is 0 for units < 8, which would give a zero-width
        # bottleneck; keep at least one hidden unit. Unchanged for units >= 8.
        self.fc1 = Dense(max(1, self.units // 8), activation="relu")
        self.fc2 = Dense(self.units, activation="sigmoid")
        # Created once here instead of instantiating a new layer on every call.
        self.multiply = Multiply()
        super().build(input_shape)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Forward pass for SeBlock.

        Args:
            x: Input tensor of shape (batch, seq_len, channels).

        Returns:
            The input scaled by the learned per-channel gate.
        """
        gate = self.fc2(self.fc1(self.pool(x)))
        # Insert a length-1 axis so the gate broadcasts along axis 1 of the input.
        gate = tf.expand_dims(gate, axis=1)
        return self.multiply([x, gate])

    def get_config(self) -> Dict[str, Any]:
        """Return the base layer config extended with ``units``."""
        return {**super().get_config(), "units": self.units}
class ReBlock(tf.keras.layers.Layer):
    """Residual block with two Convbr layers and an optional SE block.

    NOTE(review): the residual addition needs the branch output to be
    shape-compatible with the block input. That presumably means callers use
    ``strides=1`` and ``units`` equal to the input channel count; confirm
    against the model code.

    Args:
        units: Number of filters for both ``ConvbrLayer`` instances, and the
            width of the SE block when enabled.
        kernel_size: Kernel size for both convolutions.
        strides: Stride for both convolutions.
        dilation: Dilation rate for both convolutions.
        use_se: Whether to apply a ``SeBlock`` before the residual addition.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units: int, kernel_size: int, strides: int, dilation: int, use_se: bool, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_size = kernel_size
        self.strides = strides
        self.dilation = dilation
        self.use_se = use_se

    def build(self, input_shape: Tuple[Optional[int], ...]):
        """Create the two conv blocks, the optional SE block and the merge layer."""
        self.conv_br1 = ConvbrLayer(self.units, self.kernel_size, self.strides, self.dilation)
        self.conv_br2 = ConvbrLayer(self.units, self.kernel_size, self.strides, self.dilation)
        if self.use_se:
            self.se_block = SeBlock(units=self.units)
        # Created once here instead of instantiating a new layer on every call.
        self.add = Add()
        super().build(input_shape)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Forward pass for ReBlock.

        Args:
            x: Input tensor.

        Returns:
            Sum of the input and the output of the two Convbr layers
            (passed through the SE block when ``use_se`` is set).
        """
        x_re = self.conv_br1(x)
        x_re = self.conv_br2(x_re)
        if self.use_se:
            x_re = self.se_block(x_re)
        return self.add([x, x_re])

    def get_config(self) -> Dict[str, Any]:
        """Return the base layer config extended with the constructor arguments."""
        return {
            **super().get_config(),
            "units": self.units,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "dilation": self.dilation,
            "use_se": self.use_se,
        }