Source code for paddlespeech.s2t.models.st_interface

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""ST Interface module."""
from .asr_interface import ASRInterface
from paddlespeech.s2t.utils.dynamic_import import dynamic_import


class STInterface(ASRInterface):
    """ST Interface model implementation.

    NOTE: This class is inherited from ASRInterface to enable joint translation
    and recognition when performing multi-task learning with the ASR task.

    Both methods defined here are placeholders that raise
    ``NotImplementedError``; concrete models are expected to override them.
    """

    def translate(self,
                  x,
                  trans_args,
                  char_list=None,
                  rnnlm=None,
                  ensemble_models=None):
        """Translate x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :param list ensemble_models: optional extra models; defaults to
            ``None`` rather than a shared mutable ``[]``. Presumably these are
            used for ensemble decoding -- confirm against implementing classes.
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always, in this base interface
        """
        raise NotImplementedError("translate method is not implemented")

    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always, in this base interface
        """
        raise NotImplementedError("Batch decoding is not supported yet.")
# Short aliases accepted by `dynamic_import_st`, each mapped to a
# "module.path:ClassName" import target.
predefined_st = dict(
    transformer="paddlespeech.s2t.models.u2_st:U2STModel", )
def dynamic_import_st(module):
    """Resolve an ST model class from an import path or a known alias.

    Args:
        module (str): either ``module_name:class_name`` or one of the alias
            keys of `predefined_st`.

    Returns:
        type: the resolved ST class, which is a subclass of `STInterface`.

    Raises:
        AssertionError: if the resolved class does not derive from
            `STInterface`.
    """
    st_class = dynamic_import(module, predefined_st)
    # NOTE: this check is an `assert`, so it is skipped when Python runs
    # with optimizations enabled (-O).
    assert issubclass(
        st_class, STInterface), f"{module} does not implement STInterface"
    return st_class