paddlespeech.t2s.datasets.sampler module
- class paddlespeech.t2s.datasets.sampler.ErnieSATSampler(dataset, batch_size, num_replicas=None, rank=None, shuffle=False, drop_last=False)[source]
Bases: BatchSampler
Sampler that restricts data loading to a subset of the dataset. In distributed training, each process can pass an ErnieSATSampler instance as a DataLoader sampler and load the subset of the original dataset that is exclusive to it.
Note: the dataset is assumed to be of constant size.
- Args:
  - dataset (paddle.io.Dataset): a paddle.io.Dataset implementation, or any other Python object that implements __len__, from which BatchSampler gets the number of samples in the data source.
  - batch_size (int): number of sample indices in each mini-batch.
  - num_replicas (int, optional): number of processes in distributed training. If num_replicas is None, it will be retrieved from paddle.distributed.ParallelEnv. Default None.
  - rank (int, optional): the rank of the current process among num_replicas processes. If rank is None, it is retrieved from paddle.distributed.ParallelEnv. Default None.
  - shuffle (bool): whether to shuffle the order of indices before generating batch indices. Default False.
  - drop_last (bool): whether to drop the last incomplete batch when the dataset size is not divisible by the batch size. Default False.
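The interaction of these arguments can be illustrated with a small standalone sketch. This is a hypothetical helper, not the paddlespeech implementation, and the rank-strided split below is just one plausible partitioning scheme:

```python
# Illustrative sketch of how a distributed batch sampler could partition
# indices across replicas. NOT the paddlespeech implementation; the
# rank-strided split is only one plausible scheme.
import math
import random


def batch_indices(num_samples, batch_size, num_replicas, rank,
                  shuffle=False, drop_last=False, epoch=0):
    indices = list(range(num_samples))
    if shuffle:
        # Seeding with the epoch gives every replica the same permutation,
        # which is what set_epoch(epoch) arranges in the real sampler.
        random.Random(epoch).shuffle(indices)
    # Pad with leading indices so every replica gets an equal share.
    total = math.ceil(num_samples / num_replicas) * num_replicas
    indices += indices[:total - num_samples]
    # Each rank takes an exclusive, rank-strided subset.
    subset = indices[rank::num_replicas]
    batches = [subset[i:i + batch_size]
               for i in range(0, len(subset), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches
```

For example, with 10 samples, batch_size=4, and two replicas, rank 0 yields [[0, 2, 4, 6], [8]]; setting drop_last=True discards the short final batch.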
Methods
- set_epoch(epoch)[source]
Sets the epoch number. When shuffle=True, this number is used as the seed for shuffling. By default, users need not set it: all replicas (workers) use a different random ordering for each epoch. If the same number is set at every epoch, this sampler yields the same ordering across all epochs.
- Arguments:
  - epoch (int): Epoch number.
- Examples:
```python
import numpy as np
from paddle.io import Dataset, DistributedBatchSampler

# init with dataset
class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

dataset = RandomDataset(100)
sampler = DistributedBatchSampler(dataset, batch_size=64)

for epoch in range(10):
    sampler.set_epoch(epoch)
```