Trainer
- class tfts.trainer.Trainer(model, strategy: Strategy | None = None, **kwargs: Dict[str, Any])
Bases: object
Custom trainer for TensorFlow with support for CPU, GPU, and multi-GPU.
Methods
fit(**params)
predict(test_loader)
save_model(model_dir[, only_pb])
train(train_loader[, valid_loader, loss_fn, ...]) – Trains the model using the provided data loaders.
train_loop(train_loader)
train_step(x_train, y_train)
valid_loop(valid_loader)
valid_step(x_valid, y_valid)
- train(train_loader: tensorflow.python.data.ops.dataset_ops.DatasetV2 | typing.Generator, valid_loader: tensorflow.python.data.ops.dataset_ops.DatasetV2 | typing.Generator | None = None, loss_fn: typing.Callable = <LossFunctionWrapper(<function mean_squared_error>, kwargs={})>, optimizer: keras.src.optimizers.optimizer.Optimizer = <keras.src.optimizers.adam.Adam object>, lr_scheduler: keras.src.optimizers.schedules.learning_rate_schedule.LearningRateSchedule | None = None, epochs: int = 10, learning_rate: float = 0.0003, verbose: int = 1, eval_metric: typing.Callable | typing.List[typing.Callable] | None = None, model_dir: str | None = None, use_ema: bool = False, stop_no_improve_epochs: int | None = None, max_grad_norm: float = 5.0, transform: typing.Callable | None = None) -> None
Trains the model using the provided data loaders.
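As an illustration of the generator form of a data loader, the sketch below yields (x, y) batches built from sliding windows over a series. It uses plain Python lists to stay dependency-free; in practice the batches would more likely be tensors or NumPy arrays. The helper name `window_batches`, the window lengths, and the (x, y) batch layout are assumptions made for illustration (the layout is suggested by the `train_step(x_train, y_train)` signature); none of them is part of tfts.

```python
from typing import Iterator, List, Sequence, Tuple

Batch = Tuple[List[List[float]], List[List[float]]]


def window_batches(
    series: Sequence[float],
    input_len: int,
    predict_len: int,
    batch_size: int,
) -> Iterator[Batch]:
    """Yield (x, y) batches of sliding windows over `series` (hypothetical helper)."""
    xs: List[List[float]] = []
    ys: List[List[float]] = []
    last_start = len(series) - input_len - predict_len
    for start in range(last_start + 1):
        split = start + input_len
        xs.append(list(series[start:split]))                 # model input window
        ys.append(list(series[split:split + predict_len]))   # target window
        if len(xs) == batch_size:
            yield xs, ys
            xs, ys = [], []
    if xs:  # emit the final, possibly smaller, batch
        yield xs, ys


series = [float(i) for i in range(10)]
batches = list(window_batches(series, input_len=4, predict_len=2, batch_size=3))
```

A Python generator object can be iterated only once, so a multi-epoch run generally needs a fresh generator per epoch or a `tf.data.Dataset`; how the trainer handles this is not stated here.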
- Parameters:
train_loader (Union[tf.data.Dataset, Generator]) – The training data loader, which can be a tf.data.Dataset or a Python generator.
valid_loader (Union[tf.data.Dataset, Generator, None], optional) – The validation data loader, by default None.
loss_fn (Callable, optional) – The loss function, by default mean squared error.
optimizer (Optimizer, optional) – The Keras optimizer, by default Adam.
lr_scheduler (Optional[LearningRateSchedule], optional) – The learning rate schedule, by default None.
epochs (int, optional) – The number of epochs to train the model, by default 10.
learning_rate (float, optional) – The initial learning rate for the optimizer, by default 3e-4.
verbose (int, optional) – The verbosity level (0 = silent, 1 = progress bar, 2 = one line per epoch), by default 1.
eval_metric (Union[Callable, List[Callable], None], optional) – The evaluation metric(s) to use for validation, by default None.
model_dir (Optional[str], optional) – The directory in which to save the model weights. The signature default is None; the docstring gives “../weights”.
use_ema (bool, optional) – Whether to use exponential moving average (EMA) for the model weights, by default False.
stop_no_improve_epochs (Optional[int], optional) – If provided, training will stop if the validation metric does not improve for the specified number of epochs, by default None.
max_grad_norm (float, optional) – The maximum gradient value allowed during backpropagation, by default 5.0.
transform (Optional[Callable], optional) – A function to transform the data before feeding it to the model, by default None.
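The behaviour described for `stop_no_improve_epochs` and `use_ema` can be sketched in plain Python. This is a minimal illustration of the general techniques (patience-based early stopping and an exponential moving average of weights), not the tfts implementation. It assumes that a lower validation value is better and uses a hypothetical `decay` parameter; the function names `run_epochs` and `ema_update` are likewise invented for this example.

```python
from typing import List, Optional


def run_epochs(
    valid_losses: List[float],
    epochs: int = 10,
    stop_no_improve_epochs: Optional[int] = None,
) -> int:
    """Return how many epochs run before stopping (illustrative only).

    `valid_losses` stands in for the per-epoch validation metric;
    lower is assumed to be better.
    """
    best = float("inf")
    no_improve = 0
    epochs_run = 0
    for epoch in range(min(epochs, len(valid_losses))):
        epochs_run = epoch + 1
        loss = valid_losses[epoch]
        if loss < best:
            best = loss
            no_improve = 0          # improvement resets the patience counter
        else:
            no_improve += 1
        if stop_no_improve_epochs is not None and no_improve >= stop_no_improve_epochs:
            break                   # patience exhausted
    return epochs_run


def ema_update(shadow: List[float], weights: List[float], decay: float = 0.9) -> List[float]:
    """One EMA step: shadow = decay * shadow + (1 - decay) * weights."""
    return [decay * s + (1.0 - decay) * w for s, w in zip(shadow, weights)]


losses = [1.0, 0.8, 0.9, 0.85, 0.7, 0.6]
stopped_at = run_epochs(losses, epochs=10, stop_no_improve_epochs=2)
full_run = run_epochs(losses, epochs=10, stop_no_improve_epochs=None)
shadow = ema_update([1.0, 2.0], [3.0, 4.0], decay=0.5)
```

With a patience of 2, the run above stops after the fourth epoch because epochs three and four fail to beat the best value of 0.8; without a patience value, all six recorded epochs run.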