Source code for tfts.trainer

"""tfts Trainer"""

from collections.abc import Iterable
from contextlib import nullcontext
import logging
import os
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Input

from .constants import CONFIG_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, TFTS_HOME, TFTS_HUB_CACHE
from .models.base import BaseModel
from .training_args import TrainingArguments

# Public names exported by `from tfts.trainer import *` (BaseTrainer is not listed).
__all__ = ["Trainer", "KerasTrainer", "Seq2seqKerasTrainer"]


# Module-level logger used by the trainer classes in this module.
logger = logging.getLogger(__name__)


class BaseTrainer(object):
    """Shared base for the tfts trainers.

    Holds the model, its optional ``config`` and the ``TrainingArguments``, and
    provides helpers used by concrete trainers: distribution-strategy scoping,
    Keras ``Input`` construction from a sample batch, optimizer / learning-rate
    schedule factories, and checkpoint saving.
    """

    def __init__(
        self,
        model: Union[tf.keras.Model, "BaseModel"],
        args: Optional[TrainingArguments] = None,
        strategy: Optional[tf.distribute.Strategy] = None,
        **kwargs,
    ):
        """
        Args:
            model: A ``tf.keras.Model`` or a tfts ``BaseModel``.
            args: Training arguments; a default ``TrainingArguments`` with
                ``output_dir=TFTS_HUB_CACHE`` is created when omitted.
            strategy: Optional ``tf.distribute.Strategy`` whose scope is
                returned by ``get_strategy_scope``.
            **kwargs: Accepted for subclass compatibility; ignored here.
        """
        self.model = model
        self.config = model.config if hasattr(model, "config") else None
        self.args = args or TrainingArguments(output_dir=TFTS_HUB_CACHE)
        self.strategy = strategy

    def evaluate(self):
        """Placeholder; not implemented (returns None)."""
        pass

    def get_train_dataloader(self):
        """Placeholder; not implemented (returns None)."""
        return

    def get_eval_dataloader(self):
        """Placeholder; not implemented (returns None)."""
        return

    def get_test_dataloader(self):
        """Placeholder; not implemented (returns None)."""
        return

    def get_learning_rates(self):
        """Placeholder; not implemented (returns None)."""
        return

    def create_accelerator_and_postprocess(self):
        """Placeholder; not implemented (returns None)."""
        return

    def get_strategy_scope(self):
        """Return ``self.strategy.scope()`` if a strategy is set, else a no-op context."""
        return self.strategy.scope() if self.strategy else nullcontext()

    def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
        """Create an Adam optimizer configured from ``self.args``."""
        return tf.keras.optimizers.Adam(
            learning_rate=self.args.learning_rate,
            beta_1=self.args.adam_beta1,
            beta_2=self.args.adam_beta2,
            epsilon=self.args.adam_epsilon,
            weight_decay=self.args.weight_decay,
        )

    def _create_lr_scheduler(self) -> Optional[tf.keras.optimizers.schedules.LearningRateSchedule]:
        """Create a learning-rate schedule from ``self.args``.

        Returns:
            A linear ``PolynomialDecay`` (power 1.0, decaying to 0) when
            ``lr_scheduler_type == "linear"``, otherwise ``None``.
        """
        if self.args.lr_scheduler_type == "linear":
            # NOTE(review): when max_steps <= 0 the epoch count is used as
            # decay_steps, i.e. the decay finishes after that many optimizer
            # steps, not epochs — confirm this is intended.
            return tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=self.args.learning_rate,
                decay_steps=self.args.max_steps if self.args.max_steps > 0 else self.args.num_train_epochs,
                end_learning_rate=0,
                power=1.0,
            )
        return None

    def _setup_mixed_precision(self) -> None:
        """Set the global Keras mixed-precision policy to ``mixed_float16``."""
        policy = tf.keras.mixed_precision.Policy("mixed_float16")
        tf.keras.mixed_precision.set_global_policy(policy)

    def get_inputs(self, train_dataset):
        """Build Keras ``Input`` layer(s) shaped like the features of one sample batch.

        Args:
            train_dataset: A ``tf.data.Dataset`` (the features of its first
                element are used), a ``tf.keras.utils.Sequence`` (features of
                item 0), or a list/tuple whose first element is the features.

        Returns:
            The ``Input`` layer(s) produced by ``_prepare_inputs_for_model``.

        Raises:
            ValueError: If ``train_dataset`` is none of the supported types.
        """
        if isinstance(train_dataset, tf.data.Dataset):
            # Only the first batch is materialised to read its shape.
            x = next(iter(train_dataset.take(1).as_numpy_iterator()))[0]
        elif isinstance(train_dataset, tf.keras.utils.Sequence):
            x, _ = train_dataset[0]
        elif isinstance(train_dataset, (list, tuple)):
            x = train_dataset[0]
        else:
            raise ValueError("Unsupported dataset type. Expected tf.data.Dataset, keras.utils.Sequence, or list/tuple.")
        return self._prepare_inputs_for_model(x)

    def _prepare_inputs_for_model(
        self, x: Union[np.ndarray, pd.DataFrame]
    ) -> Union[Dict[str, tf.keras.layers.Input], List[tf.keras.layers.Input], tf.keras.layers.Input]:
        """
        Prepares the input layer(s) based on the shape of the provided data.

        The leading (batch) axis is dropped from every shape.

        Args:
            x: Input data: a dict of arrays (-> dict of Inputs named by key),
               a list/tuple of arrays (-> list of Inputs named ``input_{i}``),
               or a single array-like with ``.shape`` (-> one Input named ``input``).

        Returns:
            The corresponding Keras Input layers.
        """
        if isinstance(x, dict):
            logger.debug("Preparing inputs from dict")
            return {key: Input(shape=item.shape[1:], name=key) for key, item in x.items()}
        elif isinstance(x, (list, tuple)):
            logger.debug("Preparing inputs from list or tuple")
            return [Input(shape=item.shape[1:], name=f"input_{i}") for i, item in enumerate(x)]
        else:
            logger.debug("Preparing single input")
            return Input(shape=x.shape[1:], name="input")

    def _save(self, output_dir: Optional[str] = None):
        """Save the model config (if any) and weights into a directory.

        Args:
            output_dir: Target directory; defaults to ``TFTS_HOME``.

        Failures are logged, not raised: a path pointing at an existing file is
        rejected, and an error from ``save_weights`` is reported via the logger.
        """
        save_directory = output_dir if output_dir is not None else TFTS_HOME
        # Must run before makedirs: os.makedirs(..., exist_ok=True) raises
        # FileExistsError for an existing file, which made this guard unreachable.
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return
        os.makedirs(save_directory, exist_ok=True)
        logger.info(f"Saving model checkpoint to {save_directory}")

        if self.config is not None:
            # NOTE(review): `[2:]` presumably strips a "TF" class-name prefix; after
            # build_model has replaced self.model with a plain Keras model this
            # truncates the Keras class name instead — confirm intent.
            self.config.architectures = [self.model.__class__.__name__[2:]]
            self.config.save_pretrained(save_directory)
        else:
            logger.warning("Model has no `config` attribute; saving weights only")

        weights_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        try:
            self.model.save_weights(weights_file)
            logger.info(f"Model weights successfully saved in {weights_file}")
        except Exception as e:  # best-effort save: report and continue
            logger.error(f"Failed to save model weights to {weights_file}: {e}")
        return
class KerasTrainer(BaseTrainer):
    """Trainer that delegates the training loop to ``tf.keras.Model.fit``."""

    def __init__(
        self,
        model: Union[tf.keras.Model, "BaseModel"],
        strategy: Optional[tf.distribute.Strategy] = None,
        args: Optional[TrainingArguments] = None,
        **kwargs: Dict[str, object],
    ) -> None:
        """
        Initializes the trainer.

        Args:
            model: A Keras Model, or a tfts model exposing ``build_model()``.
            strategy: Optional distribution strategy for multi-GPU or multi-node training.
            args: Optional ``TrainingArguments`` forwarded to ``BaseTrainer``.
            **kwargs: Additional arguments that are set on the instance as attributes.
        """
        # Note: the positional order here is (model, strategy, args), unlike BaseTrainer.
        super().__init__(model, args, strategy, **kwargs)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def train(
        self,
        train_dataset: Union[tf.data.Dataset, List[tf.Tensor], Tuple[tf.Tensor, tf.Tensor]],
        valid_dataset: Optional[Union[tf.data.Dataset, List[tf.Tensor], Tuple[tf.Tensor, tf.Tensor]]] = None,
        loss_fn: Union[Callable, tf.keras.losses.Loss, str] = "mse",
        optimizer: Union[tf.keras.optimizers.Optimizer, str, Dict] = "adam",
        epochs: int = 10,
        batch_size: int = 64,
        steps_per_epoch: Optional[int] = None,
        metrics: Optional[Union[List[tf.keras.metrics.Metric], List[str]]] = None,
        callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
        run_eagerly: bool = True,
        verbose: int = 1,
        **kwargs: Dict[str, object],
    ) -> tf.keras.callbacks.History:
        """
        Builds (if needed), compiles and fits the model.

        Args:
            train_dataset: A tf.data.Dataset, or a list/tuple ``(x_train, y_train)``.
            valid_dataset: Optional validation data passed to ``fit(validation_data=...)``.
            loss_fn: A callable, Keras loss, or loss name. Default is ``"mse"``.
            optimizer: A Keras optimizer instance, or a name/config dict resolved with
                ``tf.keras.optimizers.get``. Default is ``"adam"`` (Keras' default Adam settings).
            epochs: Number of epochs to train for. Default is 10.
            batch_size: Number of samples per batch. Default is 64.
            steps_per_epoch: Number of steps per epoch. Optional.
            metrics: List of metrics for monitoring during training. Optional.
            callbacks: List of keras callbacks during training. Optional.
            run_eagerly: Whether to run eagerly. Default is True.
            verbose: Verbosity level. Default is 1.
            **kwargs: Accepted for API compatibility; currently ignored.

        Returns:
            A History object containing training logs.

        Raises:
            TypeError: If the model is neither a ``tf.keras.Model`` nor has ``build_model()``.
        """
        callbacks = callbacks or []

        with self.get_strategy_scope():
            if isinstance(optimizer, (str, dict)):
                optimizer = tf.keras.optimizers.get(optimizer)

            if not isinstance(self.model, tf.keras.Model):
                if "build_model" not in dir(self.model):
                    raise TypeError("Trainer model should either be `tf.keras.Model` or has `build_model()` method")
                # Input shapes are inferred from the first sample of the training data.
                inputs = self.get_inputs(train_dataset)
                self.model = self.model.build_model(inputs=inputs)

            self.model.compile(loss=loss_fn, optimizer=optimizer, metrics=metrics, run_eagerly=run_eagerly)
            trainable_params = np.sum([tf.keras.backend.count_params(w) for w in self.model.trainable_weights])
            tf.print(f"Trainable parameters: {trainable_params}")

            # A list/tuple is unpacked into (x, y); anything else (e.g. a Dataset
            # yielding (x, y) pairs) is passed as x with y left as None.
            if isinstance(train_dataset, (list, tuple)):
                x_train, y_train = train_dataset
            else:
                x_train, y_train = train_dataset, None

            history = self.model.fit(
                x_train,
                y_train,
                validation_data=valid_dataset,
                steps_per_epoch=steps_per_epoch,
                epochs=epochs,
                batch_size=batch_size,
                verbose=verbose,
                callbacks=callbacks,
            )
        return history

    def fit(self, **params):
        """Alias for ``train``; accepts keyword arguments only."""
        return self.train(**params)

    def predict(self, x_test: tf.Tensor) -> tf.Tensor:
        """Run the model on ``x_test`` via a direct call and return the output."""
        return self.model(x_test)

    def get_model(self) -> tf.keras.Model:
        """Return the current model (the built Keras model after ``train``)."""
        return self.model

    def save_model(self, output_dir: Optional[str] = None):
        """Save config and weights to ``output_dir`` (default ``TFTS_HOME``).

        If you use a checkpoint callback to keep the best weights, those live in
        that callback's own directory; this saves the current weights.
        """
        output_dir = TFTS_HOME if output_dir is None else output_dir
        self._save(output_dir)

    def plot(self, history, true: np.ndarray, pred: np.ndarray):
        """Plot one randomly chosen example: its history, true and predicted future.

        Args:
            history: Input window array; axis 0 indexes examples, axis 1 is time,
                and channel 0 of the last axis is plotted.
            true: Ground-truth future with the same indexing convention.
            pred: Predicted future with the same indexing convention.
        """
        import matplotlib.pyplot as plt

        train_length = history.shape[1]
        pred_length = true.shape[1]
        example = np.random.choice(range(history.shape[0]))

        plt.plot(range(train_length), history[example, :, 0], label="History")
        plt.plot(range(train_length, train_length + pred_length), true[example, :, 0], label="True")
        plt.plot(range(train_length, train_length + pred_length), pred[example, :, 0], label="Predicted")
        plt.legend()
class Seq2seqKerasTrainer(KerasTrainer):
    """Keras trainer entry point for sequence-to-sequence models.

    It currently adds nothing on top of ``KerasTrainer``. The split mirrors
    transformers' Trainer / Seq2SeqTrainer, where the seq2seq variant is mostly
    about ``predict_with_generate``
    (see https://discuss.huggingface.co/t/trainer-vs-seq2seqtrainer/3145/2).
    """

    def __init__(self, *args, **kwargs):
        # Straight pass-through: every argument goes to KerasTrainer unchanged.
        super().__init__(*args, **kwargs)
class Trainer(object):
    """Custom training loop for TensorFlow models built on ``tf.GradientTape``.

    A ``strategy`` may be passed and is stored on the instance, but the loop
    below does not enter its scope.
    """

    def __init__(
        self,
        model,
        strategy: Optional[tf.distribute.Strategy] = None,
        **kwargs: Dict[str, Any],
    ) -> None:
        """
        Args:
            model: A ``tf.keras.Model`` or a tfts model exposing ``build_model()``.
            strategy: Optional distribution strategy (stored only).
            **kwargs: Extra attributes set on the instance.
        """
        self.model = model
        self.strategy = strategy
        for key, value in kwargs.items():
            setattr(self, key, value)

    def train(
        self,
        train_loader: Union[tf.data.Dataset, Generator],
        valid_loader: Union[tf.data.Dataset, Generator, None] = None,
        loss_fn: Optional[Callable] = None,
        optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
        lr_scheduler: Optional[tf.keras.optimizers.schedules.LearningRateSchedule] = None,
        epochs: int = 10,
        learning_rate: float = 3e-4,
        verbose: int = 1,
        eval_metric: Union[Callable, List[Callable], None] = None,
        model_dir: Optional[str] = None,
        use_ema: bool = False,
        stop_no_improve_epochs: Optional[int] = None,
        max_grad_norm: float = 5.0,
        transform: Optional[Callable] = None,
    ) -> None:
        """
        Trains the model using the provided data loaders.

        Parameters
        ----------
        train_loader : Union[tf.data.Dataset, Generator]
            The training data loader yielding ``(x, y)`` batches. Must be a
            ``tf.data.Dataset`` when the model still needs ``build_model()``.
        valid_loader : Union[tf.data.Dataset, Generator, None], optional
            The validation data loader, by default None.
        loss_fn : Callable, optional
            Loss called as ``loss_fn(y_true, y_pred)``; a new
            ``tf.keras.losses.MeanSquaredError()`` by default.
        optimizer : tf.keras.optimizers.Optimizer, optional
            Optimizer; a new ``tf.keras.optimizers.Adam(0.003)`` by default. Its
            learning rate is overwritten every step (see ``learning_rate``).
        lr_scheduler : LearningRateSchedule, optional
            If given, the learning rate for each step is ``lr_scheduler(global_step)``.
        epochs : int, optional
            The number of epochs to train the model, by default 10.
        learning_rate : float, optional
            The learning rate used for every step when no ``lr_scheduler`` is given, by default 3e-4.
        verbose : int, optional
            Currently unused.
        eval_metric : Union[Callable, List[Callable], None], optional
            The evaluation metric(s) ``metric(y_true, y_pred)``, by default None.
        model_dir : Optional[str], optional
            Defaults to ``TFTS_HUB_CACHE``; currently not used for saving.
        use_ema : bool, optional
            Whether to track an exponential moving average (decay 0.9) of the
            trainable variables after every step, by default False. The averages
            are not swapped into the model.
        stop_no_improve_epochs : Optional[int], optional
            With a ``valid_loader`` and ``eval_metric``, stop once the first
            validation metric (lower is better) has not improved for this many
            epochs, by default None.
        max_grad_norm : float, optional
            Gradients are clipped element-wise to ``[-max_grad_norm, max_grad_norm]``.
        transform : Optional[Callable], optional
            Stored on the instance; not applied by this loop.
        """
        # Create defaults per call. Instances in the signature would be built once
        # at class definition and shared by every call and model.
        self.loss_fn = loss_fn if loss_fn is not None else tf.keras.losses.MeanSquaredError()
        self.optimizer = optimizer if optimizer is not None else tf.keras.optimizers.Adam(0.003)
        self.lr_scheduler = lr_scheduler
        self.learning_rate = learning_rate
        self.eval_metric = self._normalize_metrics(eval_metric)
        self.use_ema = use_ema
        self.transform = transform
        self.max_grad_norm = max_grad_norm
        self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32)

        if model_dir is None:
            model_dir = TFTS_HUB_CACHE

        self._build_model_if_needed(train_loader)

        # The EMA object is kept (not the result of .apply) and updated in
        # train_step; it is created only after the model and its variables exist.
        self.ema = tf.train.ExponentialMovingAverage(0.9) if use_ema else None

        no_improve_epochs = 0
        best_metric = float("inf")

        for epoch in range(epochs):
            train_loss, _ = self.train_loop(train_loader)
            log_str = f"Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}"

            if valid_loader is not None:
                valid_loss, valid_scores = self.valid_loop(valid_loader)
                log_str += f", Valid Loss: {valid_loss:.4f}"
                log_str += ",".join(" Valid Metrics{}: {:.4f}".format(i, me) for i, me in enumerate(valid_scores))

                if stop_no_improve_epochs is not None and valid_scores:
                    # Lower is better: best_metric starts at +inf.
                    current = float(valid_scores[0])
                    if current < best_metric:
                        best_metric = current
                        no_improve_epochs = 0
                    else:
                        no_improve_epochs += 1
                    if no_improve_epochs >= stop_no_improve_epochs:
                        logger.info(log_str)
                        logger.info("Tried the best, no improved and stop training")
                        break

            logger.info(log_str)

    @staticmethod
    def _normalize_metrics(eval_metric) -> List[Callable]:
        """Return ``eval_metric`` as a list (``None`` -> ``[]``, single callable -> ``[callable]``)."""
        if eval_metric is None:
            return []
        if isinstance(eval_metric, Iterable):
            return list(eval_metric)
        return [eval_metric]

    def _build_model_if_needed(self, train_loader) -> None:
        """Replace a non-Keras model with ``self.model.build_model(inputs)``.

        Input layers are shaped from the features of the first batch of
        ``train_loader``, which must support ``take``/``as_numpy_iterator``.

        Raises:
            TypeError: If the model is not a ``tf.keras.Model`` and lacks ``build_model``.
            ValueError: If ``train_loader`` yields no batch.
        """
        if isinstance(self.model, tf.keras.Model):
            return
        if "build_model" not in dir(self.model):
            raise TypeError("Trainer model should either be tf.keras.Model or has the build_model method")
        first_batch = list(train_loader.take(1).as_numpy_iterator())
        if not first_batch:
            raise ValueError("train_loader is empty; cannot infer the model inputs")
        x = first_batch[0][0]
        if isinstance(x, dict):
            inputs = {key: Input(item.shape[1:]) for key, item in x.items()}
        else:
            inputs = Input(x.shape[1:])
        self.model = self.model.build_model(inputs=inputs)

    def fit(self, **params):
        """Alias for ``train``; accepts keyword arguments only."""
        return self.train(**params)

    def _compute_scores(self, y_trues, y_preds) -> List[Any]:
        """Concatenate per-batch targets/predictions on axis 0 and apply every eval metric."""
        if not self.eval_metric:
            return []
        y_true_all = tf.concat(y_trues, axis=0)
        y_pred_all = tf.concat(y_preds, axis=0)
        return [metric(y_true_all, y_pred_all) for metric in self.eval_metric]

    def train_loop(self, train_loader):
        """Run one training epoch.

        Returns:
            ``(mean step loss, list of eval-metric scores)``.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        train_loss = 0.0
        num_steps = 0
        y_trues, y_preds = [], []
        for x_train, y_train in train_loader:
            y_pred, step_loss = self.train_step(x_train, y_train)
            train_loss += step_loss
            num_steps += 1
            # Only keep batches around when they will be scored.
            if self.eval_metric:
                y_preds.append(y_pred)
                y_trues.append(y_train)
        if num_steps == 0:
            raise ValueError("train_loader yielded no batches")
        return train_loss / num_steps, self._compute_scores(y_trues, y_preds)

    def train_step(self, x_train, y_train):
        """Run one optimisation step and return ``(y_pred, loss)``."""
        # Set the learning rate *before* applying gradients, so this step uses the
        # scheduled value for the current global_step (and the first step is not
        # run with the optimizer's own initial rate).
        if self.lr_scheduler is not None:
            lr = self.lr_scheduler(self.global_step)
        else:
            lr = self.learning_rate
        self.optimizer.learning_rate.assign(lr)

        with tf.GradientTape() as tape:
            y_pred = self.model(x_train, training=True)
            loss = self.loss_fn(y_train, y_pred)

        variables = self.model.trainable_variables
        gradients = tape.gradient(loss, variables)
        # Element-wise value clipping. Variables with no gradient (None) are
        # skipped, because tf.clip_by_value cannot take None.
        grads_and_vars = [
            (tf.clip_by_value(grad, -self.max_grad_norm, self.max_grad_norm), var)
            for grad, var in zip(gradients, variables)
            if grad is not None
        ]
        self.optimizer.apply_gradients(grads_and_vars)

        ema = getattr(self, "ema", None)
        if ema is not None:
            ema.apply(variables)

        self.global_step.assign_add(1)
        return y_pred, loss

    def valid_loop(self, valid_loader):
        """Evaluate over ``valid_loader``.

        Returns:
            ``(mean step loss, list of eval-metric scores)``.

        Raises:
            ValueError: If ``valid_loader`` yields no batches.
        """
        valid_loss = 0.0
        num_steps = 0
        y_valid_trues, y_valid_preds = [], []
        for x_valid, y_valid in valid_loader:
            y_valid_pred, valid_step_loss = self.valid_step(x_valid, y_valid)
            valid_loss += valid_step_loss
            num_steps += 1
            if self.eval_metric:
                y_valid_trues.append(y_valid)
                y_valid_preds.append(y_valid_pred)
        if num_steps == 0:
            raise ValueError("valid_loader yielded no batches")
        return valid_loss / num_steps, self._compute_scores(y_valid_trues, y_valid_preds)

    def valid_step(self, x_valid, y_valid):
        """Run the model in inference mode on one batch; return ``(y_pred, loss)``."""
        y_valid_pred = self.model(x_valid, training=False)
        valid_loss = self.loss_fn(y_valid, y_valid_pred)
        return y_valid_pred, valid_loss

    def predict(self, test_loader):
        """Predict over ``test_loader`` (yielding ``(x, y)`` batches).

        Returns:
            ``(targets, predictions)``, both concatenated on axis 0. The targets
            are squeezed on their last axis, which therefore must have size 1.
        """
        y_test_trues, y_test_preds = [], []
        for x_test, y_test in test_loader:
            y_test_pred = self.model(x_test, training=False)
            y_test_preds.append(y_test_pred)
            y_test_trues.append(y_test)

        y_test_trues = tf.concat(y_test_trues, axis=0)
        y_test_preds = tf.concat(y_test_preds, axis=0)
        return tf.squeeze(y_test_trues, axis=-1), y_test_preds

    def save_model(self, model_dir, only_pb=True):
        """Save the full model to ``<model_dir>.keras`` (suffix added if missing).

        When ``only_pb`` is False, weights are also written to ``<path>.ckpt``.
        NOTE(review): recent Keras versions require a ``.weights.h5`` suffix for
        ``save_weights`` — confirm against the pinned TF/Keras version.
        """
        if not model_dir.endswith(".keras"):
            model_dir = f"{model_dir}.keras"
        self.model.save(model_dir)
        logger.info(f"Model successfully saved in {model_dir}")
        if not only_pb:
            self.model.save_weights(f"{model_dir}.ckpt")
            logger.info(f"Model weights successfully saved in {model_dir}.ckpt")