KerasTrainer
- class tfts.trainer.KerasTrainer(model: Model | BaseModel, strategy: Strategy | None = None, args: TrainingArguments | None = None, **kwargs: Dict[str, object])
Bases: BaseTrainer
Keras trainer built on tf.keras.
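Initializes the trainer with the model, an optional distribution strategy, optional training arguments, and any additional attributes. The loss function and optimizer are not constructor arguments; they are passed to train().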
- Parameters:
model – A Keras Model (including Sequential) or a BaseModel instance to train.
strategy – Optional distribution strategy for multi-GPU or multi-node training.
args – Optional TrainingArguments instance.
**kwargs – Additional arguments that are passed to the instance as attributes.
Methods
- create_accelerator_and_postprocess()
- evaluate()
- fit(**params)
- get_eval_dataloader()
- get_inputs(train_dataset)
- get_learning_rates()
- get_model()
- get_strategy_scope()
- get_test_dataloader()
- get_train_dataloader()
- plot(history, true, pred)
- predict(x_test)
- save_model([output_dir])
- train(train_dataset[, valid_dataset, ...]) – Trains the model on the provided dataset.
- train(train_dataset: DatasetV2 | List[Tensor] | Tuple[Tensor, Tensor], valid_dataset: DatasetV2 | List[Tensor] | Tuple[Tensor, Tensor] | None = None, loss_fn: Callable | Loss | str = 'mse', optimizer: Optimizer | str | Dict = 'adam', epochs: int = 10, batch_size: int = 64, steps_per_epoch: int | None = None, metrics: List[Metric] | List[str] | None = None, callbacks: List[Callback] | None = None, run_eagerly: bool = True, verbose: int = 1, **kwargs: Dict[str, object]) → History
Trains the model on the provided dataset.
- Parameters:
train_dataset – A tf.data.Dataset or list/tuple of tensors (x_train, y_train).
valid_dataset – A tf.data.Dataset or list/tuple of tensors (x_valid, y_valid), optional.
loss_fn – A callable, a Keras Loss instance, or a loss name string. Default is 'mse' (mean squared error).
optimizer – A Keras Optimizer instance, an optimizer name string, or a config dict. Default is 'adam'.
epochs – Number of epochs to train for. Default is 10.
batch_size – Number of samples per batch. Default is 64.
steps_per_epoch – Number of steps per epoch. Optional.
metrics – List of metrics for monitoring during training. Optional.
callbacks – List of Keras callbacks to apply during training. Optional.
run_eagerly – Whether to run eagerly. Default is True.
verbose – Verbosity level. Default is 1.
**kwargs – Additional keyword arguments for callbacks.
- Returns:
A History object containing training logs.