paddlespeech.t2s.training.extensions.snapshot module

class paddlespeech.t2s.training.extensions.snapshot.Snapshot(max_size: int = 5, snapshot_on_error: bool = False)

Bases: Extension

An extension that takes a snapshot of the updater object inside the trainer by calling the updater's save method.

By default, an updater saves its state_dict, which contains the updater state (i.e. epoch and iteration), all model parameters, and all optimizer states. If the updater inside the trainer subclasses StandardUpdater, no additional setup is needed.
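The save/restore round trip described above can be sketched in plain Python. `ToyUpdater` is a hypothetical stand-in, not the paddlespeech `StandardUpdater`; its dict layout (`state`, `model`, `optimizer`) and the use of `pickle` are assumptions made only to show what a state dict bundles together.

```python
import pickle
import tempfile
from pathlib import Path


class ToyUpdater:
    """Hypothetical stand-in for an updater; not the paddlespeech API."""

    def __init__(self):
        self.state = {"epoch": 0, "iteration": 0}  # updater state
        self.model_params = {"w": [0.0, 0.0]}      # model parameters
        self.optimizer_state = {"lr": 0.001}       # optimizer state

    def state_dict(self):
        # Bundle everything needed to resume training into one dict.
        return {
            "state": dict(self.state),
            "model": dict(self.model_params),
            "optimizer": dict(self.optimizer_state),
        }

    def set_state_dict(self, sd):
        self.state = dict(sd["state"])
        self.model_params = dict(sd["model"])
        self.optimizer_state = dict(sd["optimizer"])

    def save(self, path):
        # Serialize the whole state dict; pickle is used only for illustration.
        with open(path, "wb") as f:
            pickle.dump(self.state_dict(), f)

    def load(self, path):
        with open(path, "rb") as f:
            self.set_state_dict(pickle.load(f))


updater = ToyUpdater()
updater.state.update(epoch=3, iteration=1200)
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "snapshot.pkl"
    updater.save(ckpt)
    restored = ToyUpdater()
    restored.load(ckpt)
print(restored.state)  # {'epoch': 3, 'iteration': 1200}
```

Because the epoch and iteration counters travel with the parameters and optimizer state, a single file is enough to resume training where it stopped.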

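Args:

max_size (int): The maximum number of snapshots to keep. After a new snapshot is saved, the oldest one is removed if more than max_size snapshots are tracked. Defaults to 5.

snapshot_on_error (bool): Whether to save a snapshot when an error is raised during training. Defaults to False.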

Attributes

name

Methods

__call__(trainer)

Main action of the extension.

finalize(trainer)

Action that is executed when training is done.

full()

Returns whether the number of tracked snapshots is greater than max_size.

initialize(trainer)

Sets up this extension.

on_error(trainer, exc, tb)

Handles an error raised during training, before finalization.

save_checkpoint_and_update(trainer)

Saves a new snapshot and removes the oldest snapshot if needed.

default_name = 'snapshot'

full()

Returns whether the number of tracked snapshots is greater than max_size.
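The check above amounts to a strict length comparison. `is_full` is a hypothetical helper written for illustration, assuming the tracked snapshots are kept as a list of records:

```python
from typing import Any, Dict, List


def is_full(records: List[Dict[str, Any]], max_size: int) -> bool:
    # True only when strictly more than max_size snapshots are tracked.
    return len(records) > max_size


print(is_full([{}] * 5, max_size=5))  # False
print(is_full([{}] * 6, max_size=5))  # True
```

In this sketch, because the comparison is strict, the list can hold max_size + 1 entries right after a new save; dropping the oldest entry brings it back to max_size.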

initialize(trainer: Trainer)

Sets up this extension.

on_error(trainer, exc, tb)

Handles an error raised during training, before finalization.
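One plausible reading of the snapshot_on_error constructor flag is that on_error writes an emergency snapshot only when the flag is True. `ErrorSnapshotSketch` is a hypothetical class that models this assumed behavior; it is not the paddlespeech implementation, and appending to a list stands in for writing a checkpoint.

```python
class ErrorSnapshotSketch:
    """Hypothetical extension sketch; only the error-handling path is modeled."""

    def __init__(self, snapshot_on_error: bool = False):
        self.snapshot_on_error = snapshot_on_error
        self.saved = []  # records which trainers were snapshotted

    def save_checkpoint_and_update(self, trainer):
        # Stand-in for writing a real checkpoint.
        self.saved.append(trainer)

    def on_error(self, trainer, exc, tb):
        # Assumed behavior: snapshot only if the user opted in.
        if self.snapshot_on_error:
            self.save_checkpoint_and_update(trainer)


ext = ErrorSnapshotSketch(snapshot_on_error=True)
try:
    raise RuntimeError("loss became NaN")  # simulated training failure
except RuntimeError as exc:
    ext.on_error("trainer", exc, exc.__traceback__)
print(len(ext.saved))  # 1
```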

priority = -100

save_checkpoint_and_update(trainer: Trainer)

Saves a new snapshot and removes the oldest snapshot if needed.
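The save-then-prune flow described above can be sketched as follows. `SnapshotRotator`, the `snapshot_epoch_{n}.bin` file names, the `records.jsonl` file, and the record fields are all assumptions for illustration; writing raw bytes stands in for calling the updater's save method.

```python
import json
import os
import tempfile
from datetime import datetime
from pathlib import Path


class SnapshotRotator:
    """Hypothetical sketch of keep-the-newest-N checkpoint rotation."""

    def __init__(self, checkpoint_dir, max_size=5):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_size = max_size
        self.records = []

    def full(self):
        return len(self.records) > self.max_size

    def save_checkpoint_and_update(self, epoch, iteration, payload: bytes):
        path = self.checkpoint_dir / f"snapshot_epoch_{epoch}.bin"
        path.write_bytes(payload)  # stand-in for updater.save(path)
        self.records.append({
            "time": str(datetime.now()),
            "path": str(path.resolve()),
            "epoch": epoch,
            "iteration": iteration,
        })
        # Drop the oldest snapshot once more than max_size are tracked.
        if self.full():
            oldest = self.records.pop(0)
            os.remove(oldest["path"])
        # Rewrite the records file, one JSON object per line.
        with open(self.checkpoint_dir / "records.jsonl", "w", encoding="utf-8") as f:
            for record in self.records:
                f.write(json.dumps(record) + "\n")


with tempfile.TemporaryDirectory() as tmp:
    rotator = SnapshotRotator(tmp, max_size=2)
    for epoch in range(1, 5):
        rotator.save_checkpoint_and_update(epoch, epoch * 100, b"weights")
    kept = sorted(p.name for p in Path(tmp).glob("snapshot_*.bin"))
    lines = (Path(tmp) / "records.jsonl").read_text(encoding="utf-8").splitlines()
print(kept)  # ['snapshot_epoch_3.bin', 'snapshot_epoch_4.bin']
```

Saving before pruning means a crash during the save never leaves the run with fewer than max_size usable snapshots in this sketch.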

trigger = (1, 'epoch')
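Assuming trigger = (1, 'epoch') is read as an interval of one epoch, the extension fires once at the end of every epoch. The hypothetical `IntervalTriggerSketch` below models that reading with a check evaluated after each iteration; the actual trigger logic in paddlespeech may differ.

```python
class IntervalTriggerSketch:
    """Hypothetical (period, unit) interval trigger, checked after each iteration."""

    def __init__(self, period: int, unit: str):
        if unit not in ("epoch", "iteration"):
            raise ValueError(f"unknown unit: {unit!r}")
        self.period = period
        self.unit = unit

    def __call__(self, epoch: int, iteration: int, is_new_epoch: bool) -> bool:
        if self.unit == "epoch":
            # Fire only at an epoch boundary whose index is a multiple of period.
            return is_new_epoch and epoch % self.period == 0
        return iteration % self.period == 0


trigger = IntervalTriggerSketch(1, "epoch")
fired_at = []
iters_per_epoch = 4
iteration = 0
for epoch in range(1, 4):
    for step in range(iters_per_epoch):
        iteration += 1
        is_new_epoch = step == iters_per_epoch - 1  # last step closes the epoch
        if trigger(epoch, iteration, is_new_epoch):
            fired_at.append((epoch, iteration))
print(fired_at)  # [(1, 4), (2, 8), (3, 12)]
```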

paddlespeech.t2s.training.extensions.snapshot.load_records(records_fp)

Loads a record file in JSON Lines format.
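A JSON Lines file stores one JSON object per line, so loading it means parsing each non-empty line separately. The sketch below uses only the standard-library `json` module; `load_records_sketch` and the record fields in the example are hypothetical, and the real load_records may rely on a different library.

```python
import json
import tempfile
from pathlib import Path


def load_records_sketch(records_fp):
    """Read a JSON Lines file into a list of dicts, one per non-empty line."""
    records = []
    with open(records_fp, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines, e.g. a trailing newline
                records.append(json.loads(line))
    return records


with tempfile.TemporaryDirectory() as tmp:
    fp = Path(tmp) / "records.jsonl"
    fp.write_text('{"epoch": 1, "path": "a"}\n{"epoch": 2, "path": "b"}\n', encoding="utf-8")
    records = load_records_sketch(fp)
print(records[-1]["epoch"])  # 2
```

The last record is the most recent snapshot, which is the natural candidate when resuming training.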