paddlespeech.s2t.training.extensions.plot module
- class paddlespeech.s2t.training.extensions.plot.PlotAttentionReport(att_vis_fn, data, outdir, converter, transform, device, reverse=False, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=1)[source]
Bases: Extension
Plot attention reporter.
- Args:
- att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions): Function of attention visualization.
- data (list[tuple(str, dict[str, list[Any]])]): List of JSON utterance key items.
- outdir (str): Directory to save figures.
- converter (espnet.asr.*_backend.asr.CustomConverter): Function to convert data.
- device (int | torch.device): Device.
- reverse (bool): If True, input and output length are reversed.
- ikey (str): Key to access input (for ASR/ST ikey="input", for MT ikey="output".)
- iaxis (int): Dimension to access input (for ASR/ST iaxis=0, for MT iaxis=1.)
- okey (str): Key to access output (for ASR/ST okey="input", for MT okey="output".)
- oaxis (int): Dimension to access output (for ASR/ST oaxis=0, for MT oaxis=0.)
- subsampling_factor (int): Subsampling factor in encoder.
- Attributes:
default_name
Default name of the extension, class name by default.
- name
Methods
__call__(trainer): Plot and save image file of att_ws matrix.
draw_attention_plot(att_w): Plot the att_w matrix.
draw_han_plot(att_w): Plot the att_w matrix for hierarchical attention.
finalize(trainer): Action that is executed when training is done.
Return attention weights.
initialize(trainer): Action that is executed once to get the correct trainer state.
log_attentions(logger, step): Add image files of att_ws matrix to TensorBoard.
on_error(trainer, exc, tb): Handles the error raised during training before finalization.
trim_attention_weight(uttid, att_w): Transform attention matrix with regard to self.reverse.
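The trimming step can be illustrated with a small standalone sketch. It is not the PaddleSpeech implementation: `trim_attention` is a hypothetical helper, and it assumes the padded matrix is laid out as (decoder steps, encoder frames), optionally with a leading head axis, and that the encoder length shrinks by `subsampling_factor`.

```python
import numpy as np


def trim_attention(att_w, in_len, out_len, reverse=False, subsampling_factor=1):
    """Cut a padded attention matrix down to the true utterance lengths.

    Hypothetical helper. Assumes att_w has shape (decoder steps, encoder frames)
    or (heads, decoder steps, encoder frames).
    """
    # With reverse=True the roles of the input and output lengths are swapped.
    enc_len, dec_len = (out_len, in_len) if reverse else (in_len, out_len)
    # The encoder may subsample frames, so shrink the encoder length to match.
    if subsampling_factor > 1:
        enc_len //= subsampling_factor
    if att_w.ndim == 3:
        return att_w[:, :dec_len, :enc_len]
    return att_w[:dec_len, :enc_len]


padded = np.random.rand(4, 12, 50)  # 4 heads, padded to 12 steps x 50 frames
trimmed = trim_attention(padded, in_len=160, out_len=9, subsampling_factor=4)
print(trimmed.shape)  # (4, 9, 40)
```

Slicing rather than masking keeps the saved figure free of padded rows and columns, which would otherwise show up as empty bands.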
- draw_attention_plot(att_w)[source]
Plot the att_w matrix.
- Returns:
matplotlib.pyplot: pyplot object with attention matrix image.
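For orientation, here is one plausible way to render an attention matrix as a heatmap with Matplotlib. It is a sketch under assumptions, not the method's actual body: `plot_attention` is a hypothetical helper, the axis labels are illustrative, and it returns a `Figure` rather than the `matplotlib.pyplot` module that the documented method returns.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_attention(att_w):
    """Draw each head of att_w as a heatmap (hypothetical helper)."""
    att_w = np.asarray(att_w, dtype=np.float32)
    if att_w.ndim == 2:
        att_w = att_w[None]  # treat a single matrix as one head
    n_heads = len(att_w)
    fig, axes = plt.subplots(1, n_heads, squeeze=False, figsize=(4 * n_heads, 3))
    for ax, head in zip(axes[0], att_w):
        ax.imshow(head, aspect="auto")
        ax.set_xlabel("Encoder index")
        ax.set_ylabel("Decoder index")
    fig.tight_layout()
    return fig


fig = plot_attention(np.random.rand(2, 9, 40))
buf = io.BytesIO()
fig.savefig(buf, format="png")  # write the image to memory instead of disk
plt.close(fig)
```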
- draw_han_plot(att_w)[source]
Plot the att_w matrix for hierarchical attention.
- Returns:
matplotlib.pyplot: pyplot object with attention matrix image.
- class paddlespeech.s2t.training.extensions.plot.PlotCTCReport(ctc_vis_fn, data, outdir, converter, transform, device, reverse=False, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=1)[source]
Bases: Extension
Plot CTC reporter.
- Args:
- ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs): Function of CTC visualization.
- data (list[tuple(str, dict[str, list[Any]])]): List of JSON utterance key items.
- outdir (str): Directory to save figures.
- converter (espnet.asr.*_backend.asr.CustomConverter): Function to convert data.
- device (int | torch.device): Device.
- reverse (bool): If True, input and output length are reversed.
- ikey (str): Key to access input (for ASR/ST ikey="input", for MT ikey="output".)
- iaxis (int): Dimension to access input (for ASR/ST iaxis=0, for MT iaxis=1.)
- okey (str): Key to access output (for ASR/ST okey="input", for MT okey="output".)
- oaxis (int): Dimension to access output (for ASR/ST oaxis=0, for MT oaxis=0.)
- subsampling_factor (int): Subsampling factor in encoder.
- Attributes:
default_name
Default name of the extension, class name by default.
- name
Methods
__call__(trainer): Plot and save image file of ctc prob.
draw_ctc_plot(ctc_prob): Plot the ctc_prob matrix.
finalize(trainer): Action that is executed when training is done.
Return CTC probs.
initialize(trainer): Action that is executed once to get the correct trainer state.
log_ctc_probs(logger, step): Add image files of ctc probs to TensorBoard.
on_error(trainer, exc, tb): Handles the error raised during training before finalization.
trim_ctc_prob(uttid, prob): Trim CTC posteriors according to input lengths.
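A rough sketch of how trimming by input length might work, using `ikey` and `iaxis` to look up the utterance in `data`. Everything here is an assumption for illustration: the `"shape"` field inside each entry follows ESPnet-style JSON and is not confirmed by this page, and the function is a standalone stand-in rather than the class method.

```python
import numpy as np

# Hypothetical utterance list shaped like the `data` argument:
# list[tuple(str, dict[str, list[Any]])]. The "shape" field is an assumption.
data = [
    ("utt1", {"input": [{"shape": [160, 80]}], "output": [{"shape": [9, 5000]}]}),
]
data_dict = dict(data)


def trim_ctc_prob(uttid, prob, ikey="input", iaxis=0, subsampling_factor=1):
    """Keep only the frames that correspond to real (unpadded) input."""
    # Look up the number of input frames for this utterance.
    enc_len = int(data_dict[uttid][ikey][iaxis]["shape"][0])
    # Account for frame subsampling in the encoder.
    if subsampling_factor > 1:
        enc_len //= subsampling_factor
    return prob[:enc_len]


padded_prob = np.zeros((50, 30))  # 50 padded frames x 30 output tokens
trimmed_prob = trim_ctc_prob("utt1", padded_prob, subsampling_factor=4)
print(trimmed_prob.shape)  # (40, 30)
```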
- draw_ctc_plot(ctc_prob)[source]
Plot the ctc_prob matrix.
- Returns:
matplotlib.pyplot: pyplot object with CTC prob matrix image.
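As an illustration of what a CTC probability plot can look like, the sketch below draws the posterior of the most active tokens over encoder frames. It is an assumption-laden example, not the method's real body: `plot_ctc` and its `top_k` parameter are hypothetical, and it returns a `Figure` rather than the `matplotlib.pyplot` module.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_ctc(ctc_prob, top_k=3):
    """Plot per-frame posteriors of the top_k tokens (hypothetical helper).

    Assumes ctc_prob has shape (frames, tokens).
    """
    ctc_prob = np.asarray(ctc_prob, dtype=np.float32)
    # Rank tokens by their peak posterior over time and keep the top_k.
    top = np.argsort(ctc_prob.max(axis=0))[::-1][:top_k]
    fig, ax = plt.subplots(figsize=(6, 3))
    for idx in top:
        ax.plot(ctc_prob[:, idx], label=f"token {idx}")
    ax.set_xlabel("Encoder frame")
    ax.set_ylabel("CTC probability")
    ax.set_ylim(-0.05, 1.05)
    ax.legend(loc="upper right")
    fig.tight_layout()
    return fig


# Random softmax outputs stand in for real CTC posteriors.
rng = np.random.default_rng(0)
logits = rng.normal(size=(40, 6))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
ctc_fig = plot_ctc(probs)
```

Limiting the plot to a few tokens keeps it readable when the vocabulary is large; the cutoff is a display choice, not something this page specifies.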