KerasTrainer

class tfts.trainer.KerasTrainer(model: Model | BaseModel, strategy: Strategy | None = None, args: TrainingArguments | None = None, **kwargs: Dict[str, object])

Bases: BaseTrainer

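Trainer that trains models through the tf.keras API.

Initializes the trainer with the model, an optional distribution strategy, optional training arguments, and any additional keyword arguments. The loss function, optimizer, and other training settings are passed to train().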

Parameters:
  • model – A Keras Model (including Sequential) or a tfts BaseModel instance to train.

  • strategy – Optional tf.distribute strategy for multi-GPU or multi-node training.

  • args – Optional TrainingArguments instance.

  • **kwargs – Additional keyword arguments, set on the instance as attributes.
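A minimal sketch of the two constructor behaviors described above, written against plain TensorFlow rather than tfts. `TrainerSketch` and the `experiment_name` keyword are hypothetical names used only for illustration. The sketch assumes, as is usual with tf.distribute, that the model's variables are created inside the strategy's scope before the model is handed to the trainer.

```python
import tensorflow as tf


class TrainerSketch:
    """Hypothetical stand-in for the documented constructor; not the tfts implementation."""

    def __init__(self, model, strategy=None, args=None, **kwargs):
        self.model = model
        self.strategy = strategy
        self.args = args
        # Any extra keyword arguments become plain instance attributes.
        for name, value in kwargs.items():
            setattr(self, name, value)


# MirroredStrategy replicates across all visible GPUs and falls back to the CPU if there are none.
strategy = tf.distribute.MirroredStrategy()

# Create the model's variables under the strategy so they are mirrored across replicas.
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])

# "experiment_name" is a hypothetical extra keyword; it ends up as trainer.experiment_name.
trainer = TrainerSketch(model, strategy=strategy, experiment_name="demo")
print(trainer.experiment_name)  # demo
```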


Methods

create_accelerator_and_postprocess()

evaluate()

fit(**params)

get_eval_dataloader()

get_inputs(train_dataset)

get_learning_rates()

get_model()

get_strategy_scope()

get_test_dataloader()

get_train_dataloader()

plot(history, true, pred)

predict(x_test)

save_model([output_dir])

train(train_dataset[, valid_dataset, ...])

Trains the model on the provided dataset.

train(train_dataset: DatasetV2 | List[Tensor] | Tuple[Tensor, Tensor], valid_dataset: DatasetV2 | List[Tensor] | Tuple[Tensor, Tensor] | None = None, loss_fn: Callable | Loss | str = 'mse', optimizer: Optimizer | str | Dict = 'adam', epochs: int = 10, batch_size: int = 64, steps_per_epoch: int | None = None, metrics: List[Metric] | List[str] | None = None, callbacks: List[Callback] | None = None, run_eagerly: bool = True, verbose: int = 1, **kwargs: Dict[str, object]) → History

Trains the model on the provided dataset.

Parameters:
  • train_dataset – A tf.data.Dataset or list/tuple of tensors (x_train, y_train).

  • valid_dataset – A tf.data.Dataset or list/tuple of tensors (x_valid, y_valid), optional.

  • loss_fn – A callable, a Keras Loss instance, or the name of a Keras loss. Default is 'mse' (mean squared error).

  • optimizer – A Keras Optimizer instance, an optimizer name, or a dict. Default is 'adam'.

  • epochs – Number of epochs to train for. Default is 10.

  • batch_size – Number of samples per batch. Default is 64.

  • steps_per_epoch – Number of steps per epoch. Optional.

  • metrics – List of metrics for monitoring during training. Optional.

  • callbacks – List of Keras callbacks to apply during training. Optional.

  • run_eagerly – Whether Keras runs the model eagerly instead of compiling it into a tf.function. Default is True.

  • verbose – Verbosity level. Default is 1.

  • **kwargs – Additional keyword arguments for callbacks.
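These parameters correspond closely to the arguments of Keras's own `compile()` and `fit()`. The sketch below assumes that `train` forwards them that way; `train_sketch` is a hypothetical helper, not the tfts implementation, and it only handles the `(x, y)` tuple form of the datasets. String defaults such as `'mse'` and `'adam'` are resolved by Keras itself.

```python
import numpy as np
import tensorflow as tf


def train_sketch(model, train_dataset, valid_dataset=None, loss_fn="mse",
                 optimizer="adam", epochs=10, batch_size=64, steps_per_epoch=None,
                 metrics=None, callbacks=None, run_eagerly=True, verbose=1):
    """Hypothetical plain-Keras stand-in for KerasTrainer.train (tuple inputs only)."""
    x_train, y_train = train_dataset
    # Keras resolves string identifiers like "mse" and "adam" to loss/optimizer objects.
    model.compile(loss=loss_fn, optimizer=optimizer, metrics=metrics,
                  run_eagerly=run_eagerly)
    # fit() returns a tf.keras.callbacks.History object.
    return model.fit(
        x_train,
        y_train,
        validation_data=valid_dataset,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        verbose=verbose,
    )


# Toy data: 128 windows of 24 time steps with 1 feature, each predicting 8 future values.
rng = np.random.default_rng(0)
x = rng.normal(size=(128, 24, 1)).astype("float32")
y = rng.normal(size=(128, 8)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(24, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(8),
])
history = train_sketch(model, (x[:96], y[:96]), (x[96:], y[96:]), epochs=2, verbose=0)
```

Passing `valid_dataset` is what adds the `val_`-prefixed entries to the returned history.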

Returns:

A tf.keras.callbacks.History object containing the training logs.
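The History object keeps one value per epoch for the loss, each metric, and (when validation data is supplied) their `val_` counterparts in its `history` dict, with the matching epoch indices in `epoch`. A short sketch of picking the best epoch from it; `best_epoch` is a hypothetical helper, and the toy model and random data are for illustration only.

```python
import numpy as np
import tensorflow as tf


def best_epoch(history, monitor="val_loss"):
    """Hypothetical helper: return (epoch index, value) where `monitor` was lowest."""
    values = history.history[monitor]
    position = int(np.argmin(values))
    return history.epoch[position], values[position]


# Toy regression data (random, for illustration only).
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = rng.normal(size=(64, 1)).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(loss="mse", optimizer="adam")
# Hold out the last 25% of samples for validation so "val_loss" is recorded.
history = model.fit(x, y, validation_split=0.25, epochs=3, verbose=0)

epoch, value = best_epoch(history)
print(f"lowest val_loss {value:.4f} at epoch {epoch}")
```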