Seq2seqKerasTrainer
- class tfts.trainer.Seq2seqKerasTrainer(*args, **kwargs)
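Bases: KerasTrainer
As discussed on the Hugging Face transformers forum (https://discuss.huggingface.co/t/trainer-vs-seq2seqtrainer/3145/2), the transformers Seq2SeqTrainer differs from the plain Trainer mostly in its predict_with_generate option. That option produces predictions by autoregressive generation instead of a single forward pass.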
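The predict_with_generate distinction can be illustrated with a toy sketch in plain Python. It is not tfts or transformers code: `next_value` is a hypothetical one-step model that predicts the next element from the sequence so far. Teacher-forced prediction conditions each step on the true previous target, while generation (what predict_with_generate enables in transformers) feeds the model's own previous output back in.

```python
from typing import Callable, List


# Hypothetical one-step model: predicts the next value from the sequence so far.
# A stand-in for a decoder step; not part of tfts or transformers.
def next_value(history: List[float]) -> float:
    return history[-1] + 1.0


def predict_teacher_forced(
    step: Callable[[List[float]], float], context: List[float], targets: List[float]
) -> List[float]:
    """Each step conditions on the ground-truth previous target."""
    preds: List[float] = []
    history = list(context)
    for y in targets:
        preds.append(step(history))
        history.append(y)  # feed the ground truth, not the prediction
    return preds


def predict_with_generate(
    step: Callable[[List[float]], float], context: List[float], horizon: int
) -> List[float]:
    """Each step conditions on the model's own previous outputs."""
    preds: List[float] = []
    history = list(context)
    for _ in range(horizon):
        y_hat = step(history)
        preds.append(y_hat)
        history.append(y_hat)  # feed the prediction back in
    return preds


context = [0.0, 1.0, 2.0]
targets = [10.0, 20.0, 30.0]
print(predict_teacher_forced(next_value, context, targets))  # [3.0, 11.0, 21.0]
print(predict_with_generate(next_value, context, 3))  # [3.0, 4.0, 5.0]
```

Only the first step agrees. After that, the teacher-forced path sees the true targets and the generated path compounds its own outputs, which is why generation-based metrics can differ from teacher-forced ones.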
Initializes the trainer with the model, loss function, optimizer, and other optional parameters.
- Parameters:
  - model – A Keras Model or Sequential instance to train.
  - strategy – Optional distribution strategy for multi-GPU or multi-node training.
  - **kwargs – Additional keyword arguments, each set as an attribute on the instance.
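The `**kwargs` behavior can be pictured with the following sketch. It is an assumption about the general pattern, not the tfts source: a hypothetical `TrainerSketch` class stores `model` and `strategy`, then attaches every extra keyword argument to the instance with `setattr`. The keyword names `loss_fn` and `optimizer` in the usage lines are illustrative only.

```python
from typing import Any, Optional


class TrainerSketch:
    """Hypothetical stand-in for a trainer constructor; not the tfts implementation."""

    def __init__(self, model: Any, strategy: Optional[Any] = None, **kwargs: Any) -> None:
        self.model = model
        self.strategy = strategy  # None means no distribution strategy was given
        # Every extra keyword argument becomes an instance attribute.
        for name, value in kwargs.items():
            setattr(self, name, value)


# Illustrative keyword names only; check the tfts source for the ones it actually reads.
trainer = TrainerSketch(model="my_model", loss_fn="mse", optimizer="adam")
print(trainer.loss_fn, trainer.optimizer)  # mse adam
```

This pattern keeps the constructor signature short, at the cost of typos in keyword names silently creating unused attributes instead of raising an error.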
Methods (inherited members included):
- create_accelerator_and_postprocess()
- evaluate()
- fit(**params)
- get_eval_dataloader()
- get_inputs(train_dataset)
- get_learning_rates()
- get_model()
- get_strategy_scope()
- get_test_dataloader()
- get_train_dataloader()
- plot(history, true, pred)
- predict(x_test)
- save_model([output_dir])
- train(train_dataset[, valid_dataset, ...]) – Trains the model on the provided dataset.