Source code for paddlespeech.vector.io.dataset

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import fields

from paddle.io import Dataset
from paddleaudio.backends import soundfile_load as load_audio
from paddleaudio.compliance.librosa import melspectrogram

from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()

# the audio meta info in the vector CSVDataset
# utt_id: the utterance segment name
# duration: utterance segment time
# wav: utterance file path
# start: start point in the original wav file
# stop: stop point in the original wav file
# label: the utterance segment's label id


[docs]@dataclass class meta_info: """the audio meta info in the vector CSVDataset Args: utt_id (str): the utterance segment name duration (float): utterance segment time wav (str): utterance file path start (int): start point in the original wav file stop (int): stop point in the original wav file lab_id (str): the utterance segment's label id """ utt_id: str duration: float wav: str start: int stop: int label: str
# csv dataset support feature type # raw: return the pcm data sample point # melspectrogram: fbank feature feat_funcs = { 'raw': None, 'melspectrogram': melspectrogram, }
[docs]class CSVDataset(Dataset): def __init__(self, csv_path, label2id_path=None, config=None, random_chunk=True, feat_type: str="raw", n_train_snts: int=-1, **kwargs): """Implement the CSV Dataset Args: csv_path (str): csv dataset file path label2id_path (str): the utterance label to integer id map file path config (CfgNode): yaml config feat_type (str): dataset feature type. if it is raw, it return pcm data. n_train_snts (int): select the n_train_snts sample from the dataset. if n_train_snts = -1, dataset will load all the sample. Default value is -1. kwargs : feature type args """ super().__init__() self.csv_path = csv_path self.label2id_path = label2id_path self.config = config self.random_chunk = random_chunk self.feat_type = feat_type self.n_train_snts = n_train_snts self.feat_config = kwargs self.id2label = {} self.label2id = {} self.data = self.load_data_csv() self.load_speaker_to_label()
[docs] def load_data_csv(self): """Load the csv dataset content and store them in the data property the csv dataset's format has six fields, that is audio_id or utt_id, audio duration, segment start point, segment stop point and utterance label. Note in training period, the utterance label must has a map to integer id in label2id_path Returns: list: the csv data with meta_info type """ data = [] with open(self.csv_path, 'r') as rf: for line in rf.readlines()[1:]: audio_id, duration, wav, start, stop, spk_id = line.strip( ).split(',') data.append( meta_info(audio_id, float(duration), wav, int(start), int(stop), spk_id)) if self.n_train_snts > 0: sample_num = min(self.n_train_snts, len(data)) data = data[0:sample_num] return data
[docs] def load_speaker_to_label(self): """Load the utterance label map content. In vector domain, we call the utterance label as speaker label. The speaker label is real speaker label in speaker verification domain, and in language identification is language label. """ if not self.label2id_path: logger.warning("No speaker id to label file") return with open(self.label2id_path, 'r') as f: for line in f.readlines(): label_name, label_id = line.strip().split(' ') self.label2id[label_name] = int(label_id) self.id2label[int(label_id)] = label_name
[docs] def convert_to_record(self, idx: int): """convert the dataset sample to training record the CSV Dataset Args: idx (int) : the request index in all the dataset """ sample = self.data[idx] record = {} # To show all fields in a namedtuple: `type(sample)._fields` for field in fields(sample): record[field.name] = getattr(sample, field.name) waveform, sr = load_audio(record['wav']) # random select a chunk audio samples from the audio if self.config and self.config.random_chunk: num_wav_samples = waveform.shape[0] num_chunk_samples = int(self.config.chunk_duration * sr) start = random.randint(0, num_wav_samples - num_chunk_samples - 1) stop = start + num_chunk_samples else: start = record['start'] stop = record['stop'] # we only return the waveform as feat waveform = waveform[start:stop] # all availabel feature type is in feat_funcs assert self.feat_type in feat_funcs.keys(), \ f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}" feat_func = feat_funcs[self.feat_type] feat = feat_func( waveform, sr=sr, **self.feat_config) if feat_func else waveform record.update({'feat': feat}) if self.label2id: record.update({'label': self.label2id[record['label']]}) return record
def __getitem__(self, idx): """Return the specific index sample Args: idx (int) : the request index in all the dataset """ return self.convert_to_record(idx) def __len__(self): """Return the dataset length Returns: int: the length num of the dataset """ return len(self.data)