Source code for paddlespeech.t2s.datasets.data_table

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from multiprocessing import Manager
from typing import Any
from typing import Callable
from typing import Dict
from typing import List

import numpy as np
from paddle.io import Dataset


class DataTable(Dataset):
    """Dataset to load and convert data for general purpose.

    Args:
        data (List[Dict[str, Any]]): Metadata, a list of meta datum, each of
            which is composed of several fields.
        fields (List[str], optional): Fields to use. If not specified, all the
            fields in the data are used. By default None.
        converters (Dict[str, Callable], optional): Converters used to process
            each field. By default None.
        use_cache (bool, optional): Whether to use cache. By default False.

    Raises:
        ValueError: If there is some field that does not exist in data.
        ValueError: If there is some field in converters that does not exist
            in fields.
    """

    def __init__(self,
                 data: List[Dict[str, Any]],
                 fields: List[str]=None,
                 converters: Dict[str, Callable]=None,
                 use_cache: bool=False):
        # metadata
        self.data = data
        assert len(data) > 0, "This dataset has no examples"

        # peek at an example to get the existing fields.
        first_example = self.data[0]
        fields_in_data = first_example.keys()

        # check that all the requested fields exist
        if fields is None:
            self.fields = fields_in_data
        else:
            for field in fields:
                if field not in fields_in_data:
                    raise ValueError(
                        f"The requested field ({field}) is not found "
                        f"in the data. Fields in the data are {fields_in_data}")
            self.fields = fields

        # check converters
        if converters is None:
            self.converters = {}
        else:
            for field in converters.keys():
                if field not in self.fields:
                    raise ValueError(
                        f"The converter has a non-existing field ({field})")
            self.converters = converters

        self.use_cache = use_cache
        if use_cache:
            self._initialize_cache()

    def _initialize_cache(self):
        # a multiprocessing-safe list, shared across DataLoader workers
        self.manager = Manager()
        self.caches = self.manager.list()
        self.caches += [None for _ in range(len(self))]

    def _get_metadata(self, idx: int) -> Dict[str, Any]:
        """Return a meta-datum given an index."""
        return self.data[idx]

    def _convert(self, meta_datum: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a meta datum to an example by applying the corresponding
        converter to each field requested.

        Args:
            meta_datum (Dict[str, Any]): Meta datum.

        Returns:
            Dict[str, Any]: Converted example.
        """
        example = {}
        for field in self.fields:
            converter = self.converters.get(field, None)
            meta_datum_field = meta_datum[field]
            if converter is not None:
                converted_field = converter(meta_datum_field)
            else:
                converted_field = meta_datum_field
            example[field] = converted_field
        return example

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get an example given an index.

        Args:
            idx (int): Index of the example to get.

        Returns:
            Dict[str, Any]: A converted example.
        """
        if self.use_cache and self.caches[idx] is not None:
            return self.caches[idx]

        meta_datum = self._get_metadata(idx)
        example = self._convert(meta_datum)

        if self.use_cache:
            self.caches[idx] = example

        return example

    def __len__(self) -> int:
        """Returns the size of the dataset.

        Returns:
            int: The length of the dataset.
        """
        return len(self.data)
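A minimal usage sketch for context (illustrative only, not part of the module: the metadata entries and .npy paths below are hypothetical and would need to exist on disk for np.load to succeed). Each entry of data is a dict of fields, and converters maps a field name to a callable that is applied lazily on access:

# usage sketch: select two fields and decode the "speech" field in __getitem__
metadata = [
    {"utt_id": "p225_001", "speech": "dump/p225_001.npy"},
    {"utt_id": "p225_002", "speech": "dump/p225_002.npy"},
]
table = DataTable(
    data=metadata,
    fields=["utt_id", "speech"],
    converters={"speech": np.load},
    use_cache=True)
example = table[0]  # {"utt_id": "p225_001", "speech": <np.ndarray>}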
class StarGANv2VCDataTable(DataTable):
    def __init__(self, data: List[Dict[str, Any]]):
        super().__init__(data)
        raw_data = data
        # group the metadata by speaker id, so that a second reference
        # utterance can be drawn from the same speaker as the first
        spk_id_set = list(set([item['spk_id'] for item in raw_data]))
        data_list_per_class = {spk_id: [] for spk_id in spk_id_set}
        for item in raw_data:
            data_list_per_class[item['spk_id']].append(item)
        self.data_list_per_class = data_list_per_class

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get an example given an index.

        Args:
            idx (int): Index of the example to get.

        Returns:
            Dict[str, Any]: A converted example.
        """
        if self.use_cache and self.caches[idx] is not None:
            return self.caches[idx]
        data = self._get_metadata(idx)
        # cropping is done in the batch_fn
        # each meta datum is a dict like:
        # {'utt_id': 'p225_111', 'spk_id': '1', 'speech': 'path of *.npy'}
        ref_data = random.choice(self.data)
        ref_label = ref_data['spk_id']
        ref_data_2 = random.choice(self.data_list_per_class[ref_label])
        # mel_tensor, label, ref_mel_tensor, ref2_mel_tensor, ref_label
        new_example = {
            'utt_id': data['utt_id'],
            'mel': np.load(data['speech']),
            'label': int(data['spk_id']),
            'ref_mel': np.load(ref_data['speech']),
            'ref_mel_2': np.load(ref_data_2['speech']),
            'ref_label': int(ref_label)
        }
        if self.use_cache:
            self.caches[idx] = new_example
        return new_example
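And a corresponding sketch for the StarGANv2-VC variant (again with hypothetical metadata; the speech paths would have to point to real mel-spectrogram .npy files). Each access returns the indexed utterance's mel plus a randomly chosen reference mel and a second reference mel drawn from the same speaker as the first reference:

# usage sketch: labels are integer speaker ids, references are sampled randomly
metadata = [
    {"utt_id": "p225_001", "spk_id": "0", "speech": "dump/p225_001.npy"},
    {"utt_id": "p226_001", "spk_id": "1", "speech": "dump/p226_001.npy"},
    {"utt_id": "p226_002", "spk_id": "1", "speech": "dump/p226_002.npy"},
]
dataset = StarGANv2VCDataTable(data=metadata)
example = dataset[0]
# example keys: utt_id, mel, label, ref_mel, ref_mel_2, ref_label;
# ref_mel and ref_mel_2 come from the same (randomly chosen) speaker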