Trainer

class tfts.trainer.Trainer(model, strategy: Strategy | None = None, **kwargs: Dict[str, Any])

Bases: object

Custom trainer for TensorFlow models, with support for CPU, GPU, and multi-GPU training.

Methods

fit(**params)

predict(test_loader)

save_model(model_dir[, only_pb])

train(train_loader[, valid_loader, loss_fn, ...])

Trains the model using the provided data loaders.

train_loop(train_loader)

train_step(x_train, y_train)

valid_loop(valid_loader)

valid_step(x_valid, y_valid)
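The method names above suggest a nested layout: train runs the epoch loop, train_loop and valid_loop each make one pass over a loader, and train_step and valid_step each handle a single (x, y) batch. The plain-Python sketch below illustrates that layering under stated assumptions. Only the method names come from the summary; ToyTrainer, make_batches, the linear model y = w * x + b, and the SGD update are hypothetical and are not tfts code.

```python
from typing import List, Optional, Sequence, Tuple

# One batch is an (x_batch, y_batch) pair, mirroring train_step(x_train, y_train).
Batch = Tuple[List[float], List[float]]


def make_batches(xs: List[float], ys: List[float], batch_size: int) -> List[Batch]:
    """Hypothetical helper: split the data into (x, y) batches."""
    return [(xs[i:i + batch_size], ys[i:i + batch_size]) for i in range(0, len(xs), batch_size)]


class ToyTrainer:
    """Hypothetical pure-Python trainer fitting y = w * x + b by SGD on the MSE loss."""

    def __init__(self, learning_rate: float = 0.05) -> None:
        self.w = 0.0
        self.b = 0.0
        self.learning_rate = learning_rate

    def train_step(self, x_train: List[float], y_train: List[float]) -> float:
        # Forward pass: residuals of the current model on this batch.
        n = len(x_train)
        errors = [self.w * x + self.b - y for x, y in zip(x_train, y_train)]
        # Gradients of the batch MSE with respect to w and b, then one SGD update.
        grad_w = 2.0 / n * sum(e * x for e, x in zip(errors, x_train))
        grad_b = 2.0 / n * sum(errors)
        self.w -= self.learning_rate * grad_w
        self.b -= self.learning_rate * grad_b
        return sum(e * e for e in errors) / n

    def valid_step(self, x_valid: List[float], y_valid: List[float]) -> float:
        # Forward pass only: validation never updates the parameters.
        n = len(x_valid)
        return sum((self.w * x + self.b - y) ** 2 for x, y in zip(x_valid, y_valid)) / n

    def train_loop(self, train_loader: Sequence[Batch]) -> float:
        # One pass over the training batches; returns the mean batch loss.
        losses = [self.train_step(x, y) for x, y in train_loader]
        return sum(losses) / len(losses)

    def valid_loop(self, valid_loader: Sequence[Batch]) -> float:
        # One pass over the validation batches; returns the mean batch loss.
        losses = [self.valid_step(x, y) for x, y in valid_loader]
        return sum(losses) / len(losses)

    def train(self, train_loader: Sequence[Batch],
              valid_loader: Optional[Sequence[Batch]] = None,
              epochs: int = 10) -> List[float]:
        # Outer epoch loop: train_loop, then valid_loop when validation data is given.
        history: List[float] = []
        for _ in range(epochs):
            self.train_loop(train_loader)
            if valid_loader is not None:
                history.append(self.valid_loop(valid_loader))
        return history


xs = [i / 10 for i in range(20)]
ys = [3.0 * x + 1.0 for x in xs]
trainer = ToyTrainer(learning_rate=0.05)
history = trainer.train(make_batches(xs, ys, 4), make_batches(xs, ys, 10), epochs=300)
print(round(trainer.w, 3), round(trainer.b, 3))
```

This sketch passes lists of batches rather than generators because a Python generator is exhausted after a single pass, so a generator-based loader would need a fresh iterator for every epoch.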

train(train_loader: tf.data.Dataset | Generator, valid_loader: tf.data.Dataset | Generator | None = None, loss_fn: Callable = <LossFunctionWrapper(mean_squared_error)>, optimizer: keras.optimizers.Optimizer = <keras.optimizers.Adam object>, lr_scheduler: keras.optimizers.schedules.LearningRateSchedule | None = None, epochs: int = 10, learning_rate: float = 0.0003, verbose: int = 1, eval_metric: Callable | List[Callable] | None = None, model_dir: str | None = None, use_ema: bool = False, stop_no_improve_epochs: int | None = None, max_grad_norm: float = 5.0, transform: Callable | None = None) → None

Trains the model using the provided data loaders.

Parameters:
  • train_loader (Union[tf.data.Dataset, Generator]) – The training data loader, which can be a tf.data.Dataset or a Python generator.

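  • valid_loader (Union[tf.data.Dataset, Generator, None], optional) – The validation data loader, which can be a tf.data.Dataset or a Python generator, by default None.

  • loss_fn (Callable, optional) – The loss function to minimize, by default mean squared error.

  • optimizer (keras.optimizers.Optimizer, optional) – The optimizer used to update the model weights, by default an Adam instance.

  • lr_scheduler (Optional[keras.optimizers.schedules.LearningRateSchedule], optional) – A learning-rate schedule, by default None.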

  • epochs (int, optional) – The number of epochs to train the model, by default 10.

  • learning_rate (float, optional) – The initial learning rate for the optimizer, by default 3e-4.

  • verbose (int, optional) – The verbosity level (0 = silent, 1 = progress bar, 2 = one line per epoch), by default 1.

  • eval_metric (Union[Callable, List[Callable], None], optional) – The evaluation metric(s) to use for validation, by default None.

  • model_dir (Optional[str], optional) – The directory in which to save the model weights. The signature default is None; “../weights” is the documented default location.

  • use_ema (bool, optional) – Whether to maintain an exponential moving average (EMA) of the model weights, by default False.

  • stop_no_improve_epochs (Optional[int], optional) – If set, training stops early when the validation metric has not improved for the specified number of epochs, by default None.

  • max_grad_norm (float, optional) – The maximum gradient norm used to clip gradients during backpropagation, by default 5.0.

  • transform (Optional[Callable], optional) – A function to transform the data before feeding it to the model, by default None.
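The max_grad_norm entry does not say which clipping rule is applied. A common choice is clipping by global norm: all gradients are rescaled by one shared factor whenever their combined L2 norm exceeds the threshold, which preserves the update direction. The sketch below assumes that rule; the function clip_grads_by_global_norm is hypothetical. TensorFlow provides tf.clip_by_global_norm for the same computation, but whether tfts uses it, or clips each gradient separately, is an assumption.

```python
import math
from typing import List, Tuple


def clip_grads_by_global_norm(grads: List[List[float]],
                              max_norm: float) -> Tuple[List[List[float]], float]:
    """Hypothetical helper: rescale all gradients so their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm measured before clipping.
    """
    global_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if global_norm <= max_norm:
        return grads, global_norm  # already within the threshold: leave unchanged
    scale = max_norm / global_norm  # one shared factor keeps the gradient direction
    return [[g * scale for g in tensor] for tensor in grads], global_norm


# Two "tensors" whose joint norm is sqrt(3^2 + 4^2 + 0^2 + 12^2) = 13.
grads = [[3.0, 4.0], [0.0, 12.0]]
clipped, norm = clip_grads_by_global_norm(grads, max_norm=5.0)
print(norm, clipped)
```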
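With use_ema enabled, the trainer keeps an exponential moving average of the model weights. The usual update is shadow = decay * shadow + (1 - decay) * weights, typically applied after each optimizer step. The pure-Python sketch below shows only that update rule; the WeightEMA class and the decay of 0.9 are hypothetical, and this section does not document the decay tfts uses or when the averaged weights replace the raw ones. Keras optimizers also accept use_ema and ema_momentum arguments, but whether tfts forwards use_ema to them is not stated here.

```python
from typing import List


class WeightEMA:
    """Hypothetical helper: keep an exponential moving average (shadow copy) of weights."""

    def __init__(self, weights: List[float], decay: float = 0.9) -> None:
        self.decay = decay
        self.shadow = list(weights)  # start the average at the initial weights

    def update(self, weights: List[float]) -> None:
        # shadow = decay * shadow + (1 - decay) * current weights
        self.shadow = [self.decay * s + (1.0 - self.decay) * w
                       for s, w in zip(self.shadow, weights)]


ema = WeightEMA([0.0, 0.0], decay=0.9)
ema.update([1.0, 2.0])  # shadow moves 10% of the way toward the new weights
ema.update([1.0, 2.0])  # and another 10% of the remaining gap
print(ema.shadow)
```

A larger decay makes the average smoother but slower to follow the raw weights.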
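The stop_no_improve_epochs entry describes patience-based early stopping. The sketch below replays a list of per-epoch validation scores and reports the epoch at which training would stop. It assumes that lower scores are better (as for a validation loss) and that every non-improving epoch counts toward the patience; the function early_stopping_epoch and the example scores are made up for illustration.

```python
from typing import List, Optional


def early_stopping_epoch(val_scores: List[float], patience: Optional[int]) -> Optional[int]:
    """Hypothetical helper: return the 0-based epoch at which training stops, or None.

    Assumes lower scores are better (e.g. a validation loss).
    """
    if patience is None:
        return None  # early stopping disabled: train for all epochs
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores):
        if score < best:
            best = score  # new best score resets the counter
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


scores = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.68]  # made-up validation losses
print(early_stopping_epoch(scores, 2), early_stopping_epoch(scores, 3))
```

With patience 2, the scores stop improving after epoch 2 (0.65), so training would stop at epoch 4; with patience 3, the new best at epoch 5 resets the counter and training runs to the end.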