paddlespeech.s2t.training.extensions.plot module

class paddlespeech.s2t.training.extensions.plot.PlotAttentionReport(att_vis_fn, data, outdir, converter, transform, device, reverse=False, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=1)[source]

Bases: Extension

Plot attention reporter.

Args:
att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions): Function of attention visualization.

data (list[tuple(str, dict[str, list[Any]])]): List of JSON utterance key items.

outdir (str): Directory to save figures.

converter (espnet.asr.*_backend.asr.CustomConverter): Function to convert data.

device (int | torch.device): Device.

reverse (bool): If True, input and output lengths are reversed.

ikey (str): Key to access input (for ASR/ST ikey="input", for MT ikey="output").

iaxis (int): Dimension to access input (for ASR/ST iaxis=0, for MT iaxis=1).

okey (str): Key to access output (for ASR/ST okey="input", for MT okey="output").

oaxis (int): Dimension to access output (for ASR/ST oaxis=0, for MT oaxis=0).

subsampling_factor (int): Subsampling factor in the encoder.
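The ikey/iaxis and okey/oaxis pairs can be read as a two-step lookup into one utterance's entry. The sketch below assumes the ESPnet-style JSON layout, in which each item of data is (uttid, info) and info[key][axis] is a dict carrying a "shape" field. That layout, the sample values, and the helper name sequence_length are illustrative assumptions and are not stated on this page.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical utterance list in the assumed ESPnet-style layout.
data: List[Tuple[str, Dict[str, List[Any]]]] = [
    ("utt1", {
        "input": [{"name": "input1", "shape": [120, 83]}],
        "output": [{"name": "target1", "shape": [14, 52]}],
    }),
]


def sequence_length(info: Dict[str, List[Any]], key: str, axis: int) -> int:
    """Return the first dimension of the shape stored at info[key][axis]."""
    return int(info[key][axis]["shape"][0])


uttid, info = data[0]
# ASR/ST-style lookup: input under "input", output under "output".
input_len = sequence_length(info, "input", 0)
output_len = sequence_length(info, "output", 0)
```

Under these assumptions, changing ikey or iaxis only changes which entry of the utterance dict supplies the length.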

Attributes:
default_name

Default name of the extension, class name by default.

name

Methods

__call__(trainer)

Plot and save image file of att_ws matrix.

draw_attention_plot(att_w)

Plot the att_w matrix.

draw_han_plot(att_w)

Plot the att_w matrix for hierarchical attention.

finalize(trainer)

Action that is executed when training is done.

get_attention_weights()

Return attention weights.

initialize(trainer)

Action that is executed once to get the correct trainer state.

log_attentions(logger, step)

Add image files of att_ws matrix to the tensorboard.

on_error(trainer, exc, tb)

Handles the error raised during training before finalization.

trim_attention_weight(uttid, att_w)

Transform attention matrix with regard to self.reverse.

draw_attention_plot(att_w)[source]

Plot the att_w matrix.

Returns:

matplotlib.pyplot: pyplot object with attention matrix image.

draw_han_plot(att_w)[source]

Plot the att_w matrix for hierarchical attention.

Returns:

matplotlib.pyplot: pyplot object with attention matrix image.

get_attention_weights()[source]

Return attention weights.

Returns:
numpy.ndarray: Attention weights (float). The shape differs by backend:

  • pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2) other case => (B, Lmax, Tmax).

  • chainer-> (B, Lmax, Tmax)
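Because the returned array may have three or four dimensions, downstream code has to handle both cases. The following is a minimal sketch of one way to do that with NumPy, by inserting a head axis of size 1 when it is missing. The helper name as_multi_head is hypothetical and is not part of this module.

```python
import numpy as np


def as_multi_head(att_ws):
    """Return attention weights with shape (B, H, Lmax, Tmax).

    A 3-D input (B, Lmax, Tmax) gets a head axis of size 1;
    a 4-D input is returned unchanged.
    """
    att_ws = np.asarray(att_ws, dtype=float)
    if att_ws.ndim == 3:
        return att_ws[:, np.newaxis, :, :]
    if att_ws.ndim == 4:
        return att_ws
    raise ValueError(f"expected 3 or 4 dimensions, got {att_ws.ndim}")


single_head = np.zeros((2, 5, 7))     # (B, Lmax, Tmax)
multi_head = np.zeros((2, 4, 5, 7))   # (B, H, Lmax, Tmax)
shape_single = as_multi_head(single_head).shape
shape_multi = as_multi_head(multi_head).shape
```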

log_attentions(logger, step)[source]

Add image files of att_ws matrix to the tensorboard.

trim_attention_weight(uttid, att_w)[source]

Transform attention matrix with regard to self.reverse.
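As an illustration of what trimming with regard to reverse could look like, the sketch below crops the padded region of one attention matrix. It assumes the encoder length is divided by subsampling_factor and that reverse swaps which length belongs to the encoder side. The function trim_attention and its explicit in_len/out_len arguments are hypothetical; the method itself takes uttid and att_w only.

```python
import numpy as np


def trim_attention(att_w, in_len, out_len, reverse=False, subsampling_factor=1):
    """Crop padding from one attention matrix.

    att_w has shape (Lmax, Tmax) or (H, Lmax, Tmax), where the last
    axis is the encoder axis and the one before it the decoder axis.
    """
    if reverse:
        enc_len, dec_len = out_len, in_len
    else:
        enc_len, dec_len = in_len, out_len
    if subsampling_factor > 1:
        # Assumed: the encoder shortens its sequence by this factor.
        enc_len //= subsampling_factor
    if att_w.ndim == 3:
        return att_w[:, :dec_len, :enc_len]
    return att_w[:dec_len, :enc_len]


padded = np.zeros((10, 40))  # (Lmax, Tmax)
trimmed = trim_attention(padded, in_len=120, out_len=6, subsampling_factor=4)
```

With 120 input frames and a subsampling factor of 4, the encoder axis is cut to 30 columns and the decoder axis to 6 rows.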

class paddlespeech.s2t.training.extensions.plot.PlotCTCReport(ctc_vis_fn, data, outdir, converter, transform, device, reverse=False, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=1)[source]

Bases: Extension

Plot CTC reporter.

Args:
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs): Function of CTC visualization.

data (list[tuple(str, dict[str, list[Any]])]): List of JSON utterance key items.

outdir (str): Directory to save figures.

converter (espnet.asr.*_backend.asr.CustomConverter): Function to convert data.

device (int | torch.device): Device.

reverse (bool): If True, input and output lengths are reversed.

ikey (str): Key to access input (for ASR/ST ikey="input", for MT ikey="output").

iaxis (int): Dimension to access input (for ASR/ST iaxis=0, for MT iaxis=1).

okey (str): Key to access output (for ASR/ST okey="input", for MT okey="output").

oaxis (int): Dimension to access output (for ASR/ST oaxis=0, for MT oaxis=0).

subsampling_factor (int): Subsampling factor in the encoder.

Attributes:
default_name

Default name of the extension, class name by default.

name

Methods

__call__(trainer)

Plot and save image file of ctc prob.

draw_ctc_plot(ctc_prob)

Plot the ctc_prob matrix.

finalize(trainer)

Action that is executed when training is done.

get_ctc_probs()

Return CTC probs.

initialize(trainer)

Action that is executed once to get the correct trainer state.

log_ctc_probs(logger, step)

Add image files of ctc probs to the tensorboard.

on_error(trainer, exc, tb)

Handles the error raised during training before finalization.

trim_ctc_prob(uttid, prob)

Trim CTC posteriors according to input lengths.

draw_ctc_plot(ctc_prob)[source]

Plot the ctc_prob matrix.

Returns:

matplotlib.pyplot: pyplot object with CTC prob matrix image.

get_ctc_probs()[source]

Return CTC probs.

Returns:
numpy.ndarray: CTC probs (float), of shape (B, Tmax, vocab). The shape may differ by backend.

log_ctc_probs(logger, step)[source]

Add image files of ctc probs to the tensorboard.

trim_ctc_prob(uttid, prob)[source]

Trim CTC posteriors according to input lengths.
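Trimming by input length can be sketched as cutting the time axis of one utterance's posterior matrix at the subsampled input length. The function trim_ctc and its in_len argument are hypothetical stand-ins for illustration, and the division by subsampling_factor is an assumption about how the encoder shortens the sequence.

```python
import numpy as np


def trim_ctc(prob, in_len, subsampling_factor=1):
    """Keep only the frames of prob (Tmax, vocab) that are not padding."""
    enc_len = in_len
    if subsampling_factor > 1:
        # Assumed: the encoder output is shorter than the input by this factor.
        enc_len //= subsampling_factor
    return prob[:enc_len]


padded = np.zeros((50, 8))  # (Tmax, vocab)
trimmed = trim_ctc(padded, in_len=120, subsampling_factor=4)
```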