Source code for paddlespeech.t2s.exps.transformer_tts.synthesize_e2e

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path

import numpy as np
import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.frontend.en_frontend import English
from paddlespeech.t2s.models.transformer_tts import TransformerTTS
from paddlespeech.t2s.models.transformer_tts import TransformerTTSInference
from paddlespeech.t2s.models.waveflow import ConditionalWaveFlow
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.t2s.utils import layer_tools


def _read_sentences(text_path):
    """Read ``(utt_id, sentence)`` pairs from a 'utt_id sentence' text file."""
    sentences = []
    with open(text_path, 'rt') as f:
        for line in f:
            fields = line.strip().split()
            # A blank line has no utterance id; skip it instead of failing
            # with IndexError on fields[0].
            if not fields:
                continue
            sentences.append((fields[0], " ".join(fields[1:])))
    return sentences


def _read_phone_vocab(phones_dict_path):
    """Read the phone vocabulary file; return ``(vocab_size, phone_id_map)``."""
    with open(phones_dict_path, "r") as f:
        # Blank lines are dropped so they neither break the (phone, id)
        # unpacking below nor inflate the vocabulary size.
        entries = [fields for fields in (line.split() for line in f) if fields]
    phone_id_map = {phn: int(phn_id) for phn, phn_id in entries}
    # vocab_size counts the entries of the file, as the model input dim.
    return len(entries), phone_id_map


def evaluate(args, acoustic_model_config, vocoder_config):
    """Synthesize one wav file per sentence listed in ``args.text``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``text``, ``phones_dict``, ``transformer_tts_checkpoint``,
            ``waveflow_checkpoint``, ``transformer_tts_stat`` and
            ``output_dir``.
        acoustic_model_config: TransformerTTS config; ``n_mels``, ``fs`` and
            the ``"model"`` section are read from it.
        vocoder_config: Config handed to
            ``ConditionalWaveFlow.from_pretrained``.

    Each utterance is written to ``<output_dir>/<utt_id>.wav``. Blank lines
    in the text file and in the phone vocabulary file are ignored.
    """
    # dataloader has been too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for evaluation
    sentences = _read_sentences(args.text)
    vocab_size, phone_id_map = _read_phone_vocab(args.phones_dict)
    print("vocab_size:", vocab_size)

    odim = acoustic_model_config.n_mels
    model = TransformerTTS(
        idim=vocab_size, odim=odim, **acoustic_model_config["model"])
    model.set_state_dict(
        paddle.load(args.transformer_tts_checkpoint)["main_params"])
    model.eval()

    # remove ".pdparams" in waveflow_checkpoint
    suffix = ".pdparams"
    vocoder_checkpoint_path = args.waveflow_checkpoint
    if vocoder_checkpoint_path.endswith(suffix):
        vocoder_checkpoint_path = vocoder_checkpoint_path[:-len(suffix)]
    vocoder = ConditionalWaveFlow.from_pretrained(vocoder_config,
                                                  vocoder_checkpoint_path)
    layer_tools.recursively_remove_weight_norm(vocoder)
    vocoder.eval()
    print("model done!")

    frontend = English()
    print("frontend done!")

    # The stat file unpacks into the mean and standard deviation used by
    # the z-score normalizer.
    stat = np.load(args.transformer_tts_stat)
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)
    transformer_tts_normalizer = ZScore(mu, std)

    transformer_tts_inference = TransformerTTSInference(
        transformer_tts_normalizer, model)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for utt_id, sentence in sentences:
        phones = frontend.phoneticize(sentence)
        # remove start_symbol and end_symbol
        phones = phones[1:-1]
        phones = [phn for phn in phones if not phn.isspace()]
        # Phones missing from the vocabulary fall back to ","; the lookup
        # below therefore requires "," itself to be in the vocabulary.
        phones = [phn if phn in phone_id_map else "," for phn in phones]
        phone_ids = [phone_id_map[phn] for phn in phones]

        with paddle.no_grad():
            mel = transformer_tts_inference(paddle.to_tensor(phone_ids))
            # mel shape is (T, feats) and waveflow's input shape is
            # (batch, feats, T)
            mel = mel.unsqueeze(0).transpose([0, 2, 1])
            # wavflow's output shape is (B, T)
            wav = vocoder.infer(mel)[0]

        sf.write(
            str(output_dir / (utt_id + ".wav")),
            wav.numpy(),
            samplerate=acoustic_model_config.fs)
        print(f"{utt_id} done!")
def main():
    """Command-line entry point for TransformerTTS + WaveFlow synthesis.

    Parses the command-line options, selects the paddle device from
    ``--ngpu``, loads the two YAML config files into ``CfgNode`` objects,
    prints the arguments and configs, and then calls ``evaluate``.

    A negative ``--ngpu`` only prints a message; no device is set explicitly
    in that case and execution continues.
    """
    parser = argparse.ArgumentParser(
        description="Synthesize with transformer tts & waveflow.")

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--transformer-tts-config", {
            "type": str,
            "help": "transformer tts config file.",
        }),
        ("--transformer-tts-checkpoint", {
            "type": str,
            "help": "transformer tts checkpoint to load.",
        }),
        ("--transformer-tts-stat", {
            "type": str,
            "help": "mean and standard deviation used to normalize spectrogram when training transformer tts.",
        }),
        ("--waveflow-config", {
            "type": str,
            "help": "waveflow config file.",
        }),
        # not normalize when training waveflow
        ("--waveflow-checkpoint", {
            "type": str,
            "help": "waveflow checkpoint to load.",
        }),
        ("--phones-dict", {
            "type": str,
            "default": "phone_id_map.txt",
            "help": "phone vocabulary file.",
        }),
        ("--text", {
            "type": str,
            "help": "text to synthesize, a 'utt_id sentence' pair per line.",
        }),
        ("--output-dir", {
            "type": str,
            "help": "output dir.",
        }),
        ("--ngpu", {
            "type": int,
            "default": 1,
            "help": "if ngpu == 0, use cpu.",
        }),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Device selection: 0 -> cpu, positive -> gpu, negative -> message only.
    if args.ngpu < 0:
        print("ngpu should >= 0 !")
    elif args.ngpu == 0:
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")

    # Load the acoustic-model config first, then the vocoder config.
    loaded_configs = []
    for config_path in (args.transformer_tts_config, args.waveflow_config):
        with open(config_path) as config_file:
            loaded_configs.append(CfgNode(yaml.safe_load(config_file)))
    transformer_tts_config, waveflow_config = loaded_configs

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(transformer_tts_config)
    print(waveflow_config)

    evaluate(args, transformer_tts_config, waveflow_config)
# Run the synthesis command-line interface only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()