Source code for tfts.models.unet

"""
`U-Net: Convolutional Networks for Biomedical Image Segmentation
<https://arxiv.org/abs/1505.04597>`_
"""

from typing import List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras.layers import (
    AveragePooling1D,
    Concatenate,
    Conv1D,
    Dense,
    Dropout,
    Lambda,
    LayerNormalization,
    MultiHeadAttention,
    UpSampling1D,
)

from tfts.layers.embed_layer import DataEmbedding
from tfts.layers.unet_layer import ConvbrLayer, ReBlock, SeBlock

from ..layers.util_layer import ShapeLayer
from .base import BaseConfig, BaseModel


class UnetConfig(BaseConfig):
    """Hyper-parameters for the :class:`Unet` model.

    Every constructor argument is stored on the instance under the same name.
    Any extra keyword arguments are handed to ``self.update`` after the named
    settings have been assigned.

    Args:
        units: Base channel width passed to the embedding, encoder and decoder.
        kernel_size: Kernel size passed to the encoder and decoder blocks.
        depth: Number of residual blocks per encoder level.
        pool_sizes: Pool sizes of the two average-pooling layers.
        upsampling_factors: Factors of the three decoder upsampling layers.
        num_attention_heads: Number of heads for the optional attention layers.
        attention_probs_dropout_prob: Dropout rate inside the attention layers.
        hidden_dropout_prob: Dropout rate of the encoder/decoder dropout layer.
        use_residual: Whether to add skip connections around blocks.
        use_attention: Whether to apply self-attention after blocks.
        use_se: Forwarded to the residual blocks as ``use_se``.
        use_layer_norm: Whether to apply layer normalization after blocks.
        **kwargs: Additional settings applied through ``self.update``.
    """

    model_type: str = "unet"

    def __init__(
        self,
        units: int = 64,
        kernel_size: int = 2,
        depth: int = 2,
        pool_sizes: Tuple[int, int] = (2, 4),
        upsampling_factors: Tuple[int, int, int] = (2, 2, 2),
        num_attention_heads: int = 4,
        attention_probs_dropout_prob: float = 0.1,
        hidden_dropout_prob: float = 0.1,
        use_residual: bool = False,
        use_attention: bool = False,
        use_se: bool = False,
        use_layer_norm: bool = False,
        **kwargs,
    ):
        super().__init__()

        # Insertion order of this mapping fixes the attribute assignment order.
        named_settings = {
            "units": units,
            "kernel_size": kernel_size,
            "depth": depth,
            "pool_sizes": pool_sizes,
            "upsampling_factors": upsampling_factors,
            "num_attention_heads": num_attention_heads,
            "attention_probs_dropout_prob": attention_probs_dropout_prob,
            "hidden_dropout_prob": hidden_dropout_prob,
            "use_residual": use_residual,
            "use_attention": use_attention,
            "use_se": use_se,
            "use_layer_norm": use_layer_norm,
        }
        for attr_name, attr_value in named_settings.items():
            setattr(self, attr_name, attr_value)

        # Extra keyword arguments are applied last, after the named settings.
        self.update(kwargs)
class Unet(BaseModel):
    """Unet model for sequence-to-sequence prediction tasks.

    The model embeds the encoder features, derives two average-pooled copies
    of the embedding, runs the encoder/decoder pair and projects the decoder
    output to a single channel.
    """

    def __init__(self, predict_sequence_length: int = 1, config: Optional[UnetConfig] = None):
        """Build the sub-layers of the model.

        Args:
            predict_sequence_length: Number of trailing time steps kept from
                the projected decoder output.
            config: Model hyper-parameters; a default ``UnetConfig`` is used
                when this is falsy.

        Raises:
            ValueError: If ``predict_sequence_length`` exceeds the product of
                the two pool sizes and the three upsampling factors.
        """
        super().__init__()
        self.config = config or UnetConfig()
        self.predict_sequence_length = predict_sequence_length
        cfg = self.config

        # Validate sequence length requirements
        pools = cfg.pool_sizes
        factors = cfg.upsampling_factors
        min_sequence_length = pools[0] * pools[1] * factors[0] * factors[1] * factors[2]
        if predict_sequence_length > min_sequence_length:
            raise ValueError(
                f"predict_sequence_length ({predict_sequence_length}) must be less than or equal to "
                f"the minimum sequence length ({min_sequence_length}) determined by pooling and upsampling factors. "
                f"Current pool_sizes={self.config.pool_sizes}, upsampling_factors={self.config.upsampling_factors}"
            )

        # Input embedding
        self.embedding = DataEmbedding(cfg.units, positional_type="positional encoding")

        # Pooling layers
        self.avg_pool1 = AveragePooling1D(pool_size=cfg.pool_sizes[0])
        self.avg_pool2 = AveragePooling1D(pool_size=cfg.pool_sizes[1])

        # Settings that the encoder and the decoder receive identically.
        shared_kwargs = dict(
            units=cfg.units,
            kernel_size=cfg.kernel_size,
            use_attention=cfg.use_attention,
            use_residual=cfg.use_residual,
            use_se=cfg.use_se,
            use_layer_norm=cfg.use_layer_norm,
            num_attention_heads=cfg.num_attention_heads,
            attention_probs_dropout_prob=cfg.attention_probs_dropout_prob,
            hidden_dropout_prob=cfg.hidden_dropout_prob,
        )

        # Encoder and decoder
        self.encoder = Encoder(depth=cfg.depth, **shared_kwargs)
        self.decoder = Decoder(
            upsampling_factors=cfg.upsampling_factors,
            predict_seq_length=predict_sequence_length,
            **shared_kwargs,
        )

        # Output projection
        self.output_projection = Dense(1)

    def __call__(
        self,
        x: tf.Tensor,
        training: bool = True,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Forward pass through the model.

        Args:
            x: Model input; it is split by ``self._prepare_3d_inputs`` and only
                the encoder feature part is used here.
            training: Boolean flag for training mode, forwarded to the encoder
                and the decoder.
            output_hidden_states: Accepted for interface compatibility; not
                used by this method.
            return_dict: When truthy, wrap the prediction in a dictionary
                under the key ``"output"``.

        Returns:
            The projected decoder output restricted to its last
            ``predict_sequence_length`` time steps (one output channel), or a
            dictionary holding that tensor when ``return_dict`` is truthy.
        """
        # NOTE: the input sequence length is not validated here; only
        # predict_sequence_length is checked, in the constructor.

        # Prepare inputs: only the encoder features are consumed below.
        _, encoder_feature, _ = self._prepare_3d_inputs(x, ignore_decoder_inputs=False)

        # Embed inputs
        embedded = self.embedding(encoder_feature)

        # Both pooled views are computed from the embedding itself.
        pooled_views = [self.avg_pool1(embedded), self.avg_pool2(embedded)]

        # Encode, decode, then project to a single channel.
        encoder_output = self.encoder([embedded, *pooled_views], training=training)
        decoder_output = self.decoder(encoder_output, training=training)
        projected = self.output_projection(decoder_output)

        # Slice to prediction length
        output = projected[:, -self.predict_sequence_length :, :]

        return {"output": output} if return_dict else output
class Encoder(tf.keras.layers.Layer):
    """Encoder component for the Unet model.

    The encoder has four levels. Each level applies a ``ConvbrLayer`` followed
    by ``depth`` ``ReBlock`` layers; after every block the optional attention,
    layer normalization and residual addition are applied. Levels three and
    four first concatenate the running tensor with a pooled copy of the input.

    The levels use different channel widths (``units`` up to ``units * 4``).
    Keras attention and layer-normalization layers create their weights for
    the first input width they see, so each (level, block) pair gets its own
    attention and normalization layer instead of sharing one set across levels.
    """

    # Number of encoder levels; used to size the per-level layer lists.
    _NUM_LEVELS = 4

    def __init__(
        self,
        units: int = 64,
        kernel_size: int = 2,
        depth: int = 1,
        use_attention: bool = False,
        use_residual: bool = False,
        use_se: bool = False,
        use_layer_norm: bool = False,
        num_attention_heads: int = 4,
        attention_probs_dropout_prob: float = 0.1,
        hidden_dropout_prob: float = 0.1,
        **kwargs,
    ):
        """Create the layers of all four encoder levels.

        Args:
            units: Base channel width; level ``k`` (1-based) uses ``units * k``.
            kernel_size: Kernel size passed to every conv and residual block.
            depth: Number of residual blocks per level.
            use_attention: Apply self-attention after each residual block.
            use_residual: Add the block input back onto the block output.
            use_se: Forwarded to each ``ReBlock`` as ``use_se``.
            use_layer_norm: Apply layer normalization after each block.
            num_attention_heads: Number of heads of each attention layer.
            attention_probs_dropout_prob: Dropout rate of each attention layer.
            hidden_dropout_prob: Rate of the shared dropout layer.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.units = units
        self.kernel_size = kernel_size
        self.depth = depth
        self.use_attention = use_attention
        self.use_residual = use_residual
        self.use_se = use_se
        self.use_layer_norm = use_layer_norm

        # First level layers
        self.conv_br1 = ConvbrLayer(units, kernel_size, 1, 1)
        self.re_blocks1 = [ReBlock(units, kernel_size, 1, 1, use_se=use_se) for _ in range(depth)]

        # Second level layers
        self.conv_br2 = ConvbrLayer(units * 2, kernel_size, 2, 1)
        self.re_blocks2 = [ReBlock(units * 2, kernel_size, 1, 1, use_se=use_se) for _ in range(depth)]

        # Third level layers
        self.conv_br3 = ConvbrLayer(units * 3, kernel_size, 2, 1)
        self.re_blocks3 = [ReBlock(units * 3, kernel_size, 1, 1, use_se=use_se) for _ in range(depth)]

        # Fourth level layers
        self.conv_br4 = ConvbrLayer(units * 4, kernel_size, 2, 1)
        self.re_blocks4 = [ReBlock(units * 4, kernel_size, 1, 1, use_se=use_se) for _ in range(depth)]

        # Attention layers: a flat list with one layer per (level, block),
        # indexed as ``level * depth + block``. The first ``depth`` entries
        # serve level one.
        if use_attention:
            self.attention_layers = [
                MultiHeadAttention(num_heads=num_attention_heads, key_dim=units, dropout=attention_probs_dropout_prob)
                for _ in range(self._NUM_LEVELS * depth)
            ]

        # Layer normalization: same flat (level, block) layout as above.
        if use_layer_norm:
            self.layer_norms = [LayerNormalization() for _ in range(self._NUM_LEVELS * depth)]

        # Dropout (a single layer shared by the levels that use it)
        self.dropout = Dropout(hidden_dropout_prob)

    def _run_level(self, x, re_blocks, level: int, training: bool, apply_dropout: bool):
        """Run the residual blocks of one encoder level on ``x``.

        Args:
            x: Output of the level's ``ConvbrLayer``.
            re_blocks: The level's list of residual blocks.
            level: Zero-based level index, used to select that level's own
                attention and normalization layers.
            training: Training flag forwarded to the dropout layer.
            apply_dropout: Whether dropout follows each block on this level.

        Returns:
            The tensor after all blocks of the level have been applied.
        """
        offset = level * self.depth
        for block_index, re_block in enumerate(re_blocks):
            residual = x
            x = re_block(x)
            if self.use_attention:
                x = self.attention_layers[offset + block_index](x, x, x)
            if self.use_layer_norm:
                x = self.layer_norms[offset + block_index](x)
            if self.use_residual:
                x = x + residual
            if apply_dropout:
                x = self.dropout(x, training=training)
        return x

    def call(self, inputs: List[tf.Tensor], training: bool = True) -> List[tf.Tensor]:
        """Forward pass through the encoder.

        Args:
            inputs: Sequence of three tensors ``(x, pool1, pool2)``: the input
                tensor and the two pooled tensors that are concatenated in
                before levels three and four.
            training: Whether the model is in training mode.

        Returns:
            List of the four level outputs ``[out_0, out_1, out_2, out_3]``,
            from the first level to the last.
        """
        x, pool1, pool2 = inputs

        # NOTE(review): the shape comments below assume the third positional
        # argument of ConvbrLayer is the stride -- confirm in unet_layer.

        # First level (dropout after every block)
        x = self.conv_br1(x)
        x = self._run_level(x, self.re_blocks1, level=0, training=training, apply_dropout=True)
        out_0 = x  # presumably batch_size * sequence_length * units

        # Second level (no dropout); the residual is applied here as on the
        # other levels.
        x = self.conv_br2(x)
        x = self._run_level(x, self.re_blocks2, level=1, training=training, apply_dropout=False)
        out_1 = x  # presumably batch_size * (sequence/2) * (units * 2)

        # Third level with pool1 (no dropout)
        x = Concatenate()([x, pool1])
        x = self.conv_br3(x)
        x = self._run_level(x, self.re_blocks3, level=2, training=training, apply_dropout=False)
        out_2 = x  # presumably batch_size * (sequence/4) * (units * 3)

        # Fourth level with pool2 (dropout after every block)
        x = Concatenate()([x, pool2])
        x = self.conv_br4(x)
        x = self._run_level(x, self.re_blocks4, level=3, training=training, apply_dropout=True)
        out_3 = x  # presumably batch_size * (sequence/8) * (units * 4)

        return [out_0, out_1, out_2, out_3]
class Decoder(tf.keras.layers.Layer):
    """Decoder component for the Unet model.

    The decoder runs three stages. Each stage upsamples the running tensor,
    concatenates it with one encoder output, applies a ``ConvbrLayer``, then
    the optional attention and layer normalization, and finally dropout.
    """

    def __init__(
        self,
        upsampling_factors: Tuple[int, int, int],
        units: int = 64,
        kernel_size: int = 2,
        predict_seq_length: int = 1,
        use_attention: bool = True,
        use_residual: bool = True,
        use_se: bool = True,
        use_layer_norm: bool = True,
        num_attention_heads: int = 4,
        attention_probs_dropout_prob: float = 0.1,
        hidden_dropout_prob: float = 0.1,
        **kwargs,
    ):
        """Create the layers of the three decoder stages.

        Args:
            upsampling_factors: Upsampling size of stages one to three.
            units: Base channel width; the stage convolutions use
                ``units * 3``, ``units * 2`` and ``units``.
            kernel_size: Kernel size of the stage convolutions.
            predict_seq_length: Stored on the instance; not used when
                building the layers.
            use_attention: Apply self-attention after each stage convolution.
            use_residual: Stored on the instance; not used by ``call``.
            use_se: Stored on the instance; not used when building the layers.
            use_layer_norm: Apply layer normalization in each stage.
            num_attention_heads: Number of heads of each attention layer.
            attention_probs_dropout_prob: Dropout rate of each attention layer.
            hidden_dropout_prob: Rate of the shared dropout layer.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)

        self.upsampling_factors = upsampling_factors
        self.units = units
        self.kernel_size = kernel_size
        self.predict_seq_length = predict_seq_length
        self.use_attention = use_attention
        self.use_residual = use_residual
        self.use_se = use_se
        self.use_layer_norm = use_layer_norm

        # Upsampling layers, one per stage
        self.upsampling1 = UpSampling1D(upsampling_factors[0])
        self.upsampling2 = UpSampling1D(upsampling_factors[1])
        self.upsampling3 = UpSampling1D(upsampling_factors[2])

        # Convolution layers with channel width shrinking towards the output
        self.conv_br1 = ConvbrLayer(units * 3, kernel_size, 1, 1)
        self.conv_br2 = ConvbrLayer(units * 2, kernel_size, 1, 1)
        self.conv_br3 = ConvbrLayer(units, kernel_size, 1, 1)

        # Attention layers: one for each upsampling level
        if use_attention:
            self.attention_layers = [
                MultiHeadAttention(
                    num_heads=num_attention_heads,
                    key_dim=units,
                    dropout=attention_probs_dropout_prob,
                )
                for _ in range(3)
            ]

        # Layer normalization: one for each upsampling level
        if use_layer_norm:
            self.layer_norms = [LayerNormalization() for _ in range(3)]

        # Dropout (one layer shared by all stages)
        self.dropout = Dropout(hidden_dropout_prob)

    def call(self, inputs: Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor], training: bool = True) -> tf.Tensor:
        """Forward pass through the decoder.

        Args:
            inputs: The four encoder outputs ``(out_0, out_1, out_2, x)``; the
                last one is the tensor that gets upsampled, the others are
                the skip connections.
            training: Whether the model is in training mode.

        Returns:
            Tensor: Decoder output after the third stage.
        """
        out_0, out_1, out_2, x = inputs

        # Skip connections are consumed deepest-first: out_2, out_1, out_0.
        stages = (
            (self.upsampling1, out_2, self.conv_br1),
            (self.upsampling2, out_1, self.conv_br2),
            (self.upsampling3, out_0, self.conv_br3),
        )

        for stage_index, (upsample, skip, conv_br) in enumerate(stages):
            x = upsample(x)
            x = Concatenate()([x, skip])
            x = conv_br(x)
            if self.use_attention:
                x = self.attention_layers[stage_index](x, x, x)
            if self.use_layer_norm:
                x = self.layer_norms[stage_index](x)
            x = self.dropout(x, training=training)

        return x