# Source code for paddlespeech.cli.base_commands

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List

import numpy
from prettytable import PrettyTable

from ..resource import CommonTaskResource
from .entry import commands
from .utils import cli_register
from .utils import explicit_command_register
from .utils import get_command

__all__ = ['BaseCommand', 'HelpCommand', 'StatsCommand']


@cli_register(name='paddlespeech')
class BaseCommand:
    """Root ``paddlespeech`` command; it delegates to ``paddlespeech.help``."""

    def execute(self, argv: List[str]) -> bool:
        """Instantiate the registered ``paddlespeech.help`` command and run it.

        Args:
            argv: command-line arguments, forwarded unchanged.

        Returns:
            Whatever the help command's ``execute`` returns.
        """
        # Named so it does not shadow the builtin ``help``.
        help_command_cls = get_command('paddlespeech.help')
        help_command = help_command_cls()
        return help_command.execute(argv)


@cli_register(name='paddlespeech.help', description='Show help for commands.')
class HelpCommand:
    """Print usage plus a one-line description of each registered sub-command."""

    def execute(self, argv: List[str]) -> bool:
        """Print the help text built from ``commands['paddlespeech']``.

        Entries whose name starts with ``'_'``, or that carry no
        ``'_description'`` key, are left out of the listing.

        Args:
            argv: accepted for interface compatibility; not used.

        Returns:
            Always ``True``.
        """
        parts = [
            'Usage:\n',
            ' paddlespeech <command> <options>\n\n',
            'Commands:\n',
        ]
        for name, entry in commands['paddlespeech'].items():
            # NOTE(review): '_'-prefixed names presumably hold registry
            # metadata rather than sub-commands; confirm in .utils.
            if name.startswith('_') or '_description' not in entry:
                continue
            parts.append(' {:<15} {}\n'.format(name, entry['_description']))

        print(''.join(parts))
        return True


@cli_register(
    name='paddlespeech.version',
    description='Show version and commit id of current package.')
class VersionCommand:
    """Print the package version and commit id, with fallbacks if unavailable."""

    def execute(self, argv: List[str]) -> bool:
        """Print the version and commit id of the enclosing package.

        ``__version__`` / ``__commit__`` are imported from the parent package.
        If either import fails, a placeholder string is printed instead.

        Args:
            argv: accepted for interface compatibility; not used.

        Returns:
            Always ``True``.
        """
        try:
            from .. import __version__ as version
        except ImportError:
            version = 'Not an official release'

        try:
            from .. import __commit__ as commit_id
        except ImportError:
            commit_id = 'Not found'

        # Adjacent literals are concatenated at compile time, so ``format``
        # applies to the whole message.
        print(('Package Version:\n'
               ' {}\n\n'
               'Commit ID:\n'
               ' {}\n\n').format(version, commit_id))
        return True


# Column headers used by ``paddlespeech.stats``. Each pretrained-model key is
# split on '-' and its pieces are shown under these '-'-separated headers.
# Every task accepted by ``StatsCommand`` needs an entry here.
model_name_format = {
    'asr': 'Model-Size-Code Switch-Multilingual-Language-Sample Rate',
    'cls': 'Model-Sample Rate',
    'st': 'Model-Source language-Target language',
    'text': 'Model-Task-Language',
    'tts': 'Model-Language',
    'vector': 'Model-Sample Rate',
    # NOTE(review): assumes KWS model keys look like '<model>-<sample rate>';
    # confirm against the KWS entries in paddlespeech.resource.
    'kws': 'Model-Sample Rate',
    'ssl': 'Model-Language-Sample Rate',
    'whisper': 'Model-Language-Sample Rate'
}


@cli_register(
    name='paddlespeech.stats',
    description='Get speech tasks support models list.')
class StatsCommand:
    """List the pretrained models PaddleSpeech provides for one speech task."""

    def __init__(self):
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech.stats', add_help=True)
        self.task_choices = [
            'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'ssl', 'whisper'
        ]
        # ``--task`` is mandatory, so a ``default`` would never be used.
        self.parser.add_argument(
            '--task',
            type=str,
            choices=self.task_choices,
            help='Choose speech task.',
            required=True)

    def show_support_models(self, pretrained_models: dict):
        """Print the keys of ``pretrained_models`` as a table, one row per key.

        Each key is split on '-' and laid out under the headers that
        ``model_name_format`` defines for ``self.task``. A task without a
        layout gets a single 'Model' column holding the raw keys, instead of
        raising ``KeyError``.

        Args:
            pretrained_models: mapping of model names; only the keys are used.
        """
        header = model_name_format.get(self.task)
        if header is None:
            table = PrettyTable(['Model'])
            for key in pretrained_models:
                table.add_row([key])
            print(table)
            return

        fields = header.split("-")
        table = PrettyTable(fields)
        for key in pretrained_models:
            line = key.split("-")
            if self.task == "asr" and len(line) < len(fields):
                line = self._asr_row(key, line, len(fields))
            table.add_row(line)
        print(table)

    @staticmethod
    def _asr_row(key: str, line: list, width: int) -> list:
        """Pad an ASR key's '-' pieces to ``width`` and reorder them to the ASR header."""
        line = line + ["-"] * (width - len(line))
        if "codeswitch" in key:
            line[3], line[1] = line[1].split("_")[0], line[1].split("_")[1:]
        elif "multilingual" in key:
            line[4], line[1] = line[1].split("_")[0], line[1].split("_")[1:]
        # Plain list indexing instead of numpy fancy indexing. In the branches
        # above line[1] becomes a list, and numpy >= 1.24 raises ValueError
        # when building an array from such ragged input.
        idx = [0, 5, 3, 4, 1, 2]
        return [line[i] for i in idx]

    def execute(self, argv: List[str]) -> bool:
        """Parse ``--task`` from ``argv`` and print that task's model table.

        Args:
            argv: command-line arguments for this sub-command.

        Returns:
            ``True`` if the table was printed, ``False`` otherwise.
        """
        parser_args = self.parser.parse_args(argv)
        self.task = parser_args.task
        if self.task not in self.task_choices:
            # argparse ``choices`` already rejects unknown tasks; defensive guard.
            print("Please input correct speech task, choices = " + str(
                self.task_choices))
            return False
        pretrained_models = CommonTaskResource(task=self.task).pretrained_models

        try:
            print(
                "Here is the list of {} pretrained models released by PaddleSpeech that can be used by command line and python API"
                .format(self.task.upper()))
            self.show_support_models(pretrained_models)
            return True
        except Exception:
            # ``Exception`` rather than ``BaseException`` so that Ctrl-C and
            # ``SystemExit`` still propagate.
            print("Failed to get the list of {} pretrained models.".format(
                self.task.upper()))
            return False


# Dynamic import when running specific command
# Maps each sub-command name to [description, executor class name]. The
# executor is registered by its dotted path
# ('paddlespeech.cli.<name>.<Executor>') rather than by the class itself.
_commands = {
    'asr': ['Speech to text infer command.', 'ASRExecutor'],
    'cls': ['Audio classification infer command.', 'CLSExecutor'],
    'st': ['Speech translation infer command.', 'STExecutor'],
    'text': ['Text command.', 'TextExecutor'],
    'tts': ['Text to Speech infer command.', 'TTSExecutor'],
    'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'],
    'kws': ['Keyword Spotting infer command.', 'KWSExecutor'],
    'ssl': ['Self-Supervised Learning Pretrained model infer command.',
            'SSLExecutor'],
    'whisper': [
        'Whisper model for speech to text or translate speech to English.',
        'WhisperExecutor'
    ]
}

for com, (description, executor_name) in _commands.items():
    explicit_command_register(
        name='paddlespeech.' + com,
        description=description,
        cls='paddlespeech.cli.' + com + '.' + executor_name)