Seq2seqKerasTrainer

class tfts.trainer.Seq2seqKerasTrainer(*args, **kwargs)

Bases: KerasTrainer

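As noted on the Hugging Face transformers forum (https://discuss.huggingface.co/t/trainer-vs-seq2seqtrainer/3145/2), Seq2SeqTrainer differs from Trainer mainly in its support for predict_with_generate, that is, producing predictions by autoregressive generation instead of a single teacher-forced forward pass.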
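To illustrate the distinction the forum thread draws, here is a minimal, framework-free sketch contrasting teacher-forced prediction with generate-style prediction. It assumes a hypothetical one-step function `step` that maps a history of values to the next value; the names `predict_teacher_forced` and `predict_with_generate` are illustrative only and do not describe how tfts implements `Seq2seqKerasTrainer.predict`.

```python
from typing import Callable, List

# Hypothetical one-step model: maps the history seen so far to the next value.
StepFn = Callable[[List[float]], float]


def predict_teacher_forced(step: StepFn, history: List[float], targets: List[float]) -> List[float]:
    """Training-style prediction: each step is conditioned on the TRUE previous values."""
    preds = []
    context = list(history)
    for y in targets:
        preds.append(step(context))
        context.append(y)  # feed the ground truth, not the prediction
    return preds


def predict_with_generate(step: StepFn, history: List[float], horizon: int) -> List[float]:
    """Inference-style prediction: each output is fed back as input for the next step."""
    preds = []
    context = list(history)
    for _ in range(horizon):
        y_hat = step(context)
        preds.append(y_hat)
        context.append(y_hat)  # feed the model's own output back in
    return preds


if __name__ == "__main__":
    halve = lambda ctx: 0.5 * ctx[-1]  # toy "model": halve the last value
    print(predict_teacher_forced(halve, [8.0], [4.0, 4.0, 4.0]))  # [4.0, 2.0, 2.0]
    print(predict_with_generate(halve, [8.0], 3))                  # [4.0, 2.0, 1.0]
```

The two outputs diverge after the first step because generation compounds the model's own outputs, which is why generate-style evaluation is the more realistic test of a sequence-to-sequence forecaster.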

Initializes the trainer with the model, loss function, optimizer, and other optional parameters.

Parameters:
  • model – A Keras Model or Sequential instance to train.

  • strategy – Optional distribution strategy for multi-GPU or multi-node training.

  • **kwargs – Additional arguments that are passed to the instance as attributes.
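The `**kwargs` behavior described above (extra keyword arguments become instance attributes) is a common Python pattern. The sketch below shows that pattern with a hypothetical `TrainerSketch` class and a hypothetical `note` keyword; it is an assumption about the mechanism, not the tfts source.

```python
class TrainerSketch:
    """Hypothetical stand-in showing extra keyword arguments stored as attributes."""

    def __init__(self, model, strategy=None, **kwargs):
        self.model = model
        self.strategy = strategy  # optional distribution strategy
        # Every additional keyword argument becomes an attribute of the instance.
        for name, value in kwargs.items():
            setattr(self, name, value)


trainer = TrainerSketch(model="my_model", note="baseline run")
print(trainer.note)  # baseline run
```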


Methods

create_accelerator_and_postprocess()

evaluate()

fit(**params)

get_eval_dataloader()

get_inputs(train_dataset)

get_learning_rates()

get_model()

get_strategy_scope()

get_test_dataloader()

get_train_dataloader()

plot(history, true, pred)

predict(x_test)

save_model([output_dir])

train(train_dataset[, valid_dataset, ...])

Trains the model on the provided dataset.