# Copyright 2022 Sean Robertson
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import warnings
from collections import Counter
from itertools import islice
from typing import (
Collection,
Container,
Dict,
Hashable,
Iterable,
Iterator,
List,
Optional,
overload,
Sequence,
Set,
Sized,
Tuple,
TypeVar,
Union,
)
from typing_extensions import Literal, get_args
import param
import numpy as np
import torch
from . import config, argcheck
from ._datasets import (
ContextWindowDataParams,
ContextWindowDataSet,
LangDataParams,
LangDataSet,
SpectDataParams,
SpectDataSet,
)
try:
_BaseSampler = torch.utils.data.sampler.Sampler[int]
except TypeError:
_BaseSampler = torch.utils.data.sampler.Sampler
OnUnevenDistributed = Literal["raise", "drop", "uneven", "ignore"]
[docs]
class AbstractEpochSampler(_BaseSampler, metaclass=abc.ABCMeta):
"""ABC for sampling based on epoch"""
epoch: int #:
_rank: int
_world_size: int
total: int
effective_total: int
def __init__(
self,
data_source: Sized,
init_epoch: int = 0,
on_uneven_distributed: OnUnevenDistributed = "raise",
):
self.effective_total = self.total = len(data_source)
self.epoch = argcheck.is_int(init_epoch, name="init_epoch")
on_uneven_distributed = argcheck.is_in(
on_uneven_distributed,
get_args(OnUnevenDistributed),
"on_uneven_distributed",
)
if (
on_uneven_distributed != "ignore"
and torch.distributed.is_available()
and torch.distributed.is_initialized()
and torch.distributed.get_rank() >= 0
):
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
if self.total % self._world_size:
if on_uneven_distributed == "raise":
raise ValueError(
f"dataset length ({self.total}) must be divisible by "
f"the distributed world size ({self._world_size}). Consult the "
"documentation for on_uneven_distributed"
)
elif on_uneven_distributed == "drop":
self.effective_total = self.total - (self.total % self._world_size)
else:
assert on_uneven_distributed == "uneven"
else:
self._rank = 0
self._world_size = 1
def __len__(self) -> int:
return (
self.effective_total - self._rank + self._world_size - 1
) // self._world_size
[docs]
@abc.abstractmethod
def get_samples_for_epoch_ignoring_distributed(self, epoch: int) -> Iterable[int]:
"""Get all samples for the provided epoch, ignoring the distruted environment
Ignores the distributed environment. All replicas should return the same value.
See Also
--------
get_samples_for_epoch
"""
...
[docs]
def get_samples_for_epoch(self, epoch: int) -> Iterable[int]:
"""Get all samples for the provided epoch"""
ret = self.get_samples_for_epoch_ignoring_distributed(epoch)
return islice(ret, self._rank, self.effective_total, self._world_size)
def __iter__(self) -> Iterator[int]:
ret = self.get_samples_for_epoch(self.epoch)
self.epoch += 1
return ret
[docs]
class EpochRandomSampler(AbstractEpochSampler):
"""A deterministic RandomSampler which handles :mod:`torch.distributed`
Parameters
----------
data_source
The dataset to draw the sample from.
init_epoch
The initial epoch.
base_seed
Determines the starting seed of the sampler. Sampling is seeded with
``(base_seed, epoch)```. If unset, a seed is randomly generated from the default
pytorch generator.
on_uneven_distributed
What to do if the sampler detects that it's in a distributed environment and
the number of processes does not evenly divide the number of samples:
- :obj:`'raise'` raise a :class:`ValueError`.
- :obj:`'drop'` drop the remainder. The dropped samples will be randomized each
epoch.
- :obj:`'uneven'` allow some processes to yield fewer samples.
- :obj:`'ignore'` ignore the distributed context. Each process will yield all
samples.
Warnings
--------
The default means of seeding the shuffler changed from version 0.3. Previously the
shuffler was seeded on each epoch with the value ``base_seed + epoch``. The change
means training a network in this version will yield different results from that
trained in version 0.3 even if `base_seed` is the same.
The change was made because, if repeated experiments were seeded sequentially, then
the ``n``-th epoch of the ``m``-th run would see samples in the same order as the
``m``-th epoch of the ``n``-th run. Thus, repeated trials were unintentionally
correlated.
Examples
--------
>>> sampler = EpochRandomSampler(
... torch.utils.data.TensorDataset(torch.arange(100)))
>>> samples_ep0 = tuple(sampler) # random
>>> samples_ep1 = tuple(sampler) # random, probably not same as first
>>> assert tuple(sampler.get_samples_for_epoch_ignoring_distributed(0)) == samples_ep0
>>> assert tuple(sampler.get_samples_for_epoch_ignoring_distributed(1)) == samples_ep1
"""
base_seed: int #:
def __init__(
self,
data_source: Sized,
init_epoch: int = 0,
base_seed: Optional[int] = None,
on_uneven_distributed: Literal["raise", "drop", "uneven", "ignore"] = "raise",
):
super().__init__(data_source, init_epoch, on_uneven_distributed)
max_ = np.iinfo(np.int32).max
if base_seed is None:
# we use numpy RandomState so that we can run in parallel with
# torch's RandomState, but we acquire the initial random seed from
# torch so that we'll be deterministic with a prior call to
# torch.manual_seed(...)
base_seed = torch.randint(max_, (1,)).long().item()
else:
base_seed = argcheck.is_int(base_seed, "base_seed")
base_seed = argcheck.is_lte(base_seed, max_, "base_seed")
self.base_seed = base_seed
def get_samples_for_epoch_ignoring_distributed(self, epoch: int) -> Iterable[int]:
rs = np.random.RandomState((self.base_seed, epoch))
shuffled = rs.permutation(self.total)
return iter(shuffled)
[docs]
class EpochSequentialSampler(AbstractEpochSampler):
"""A SequentialSampler which handles :mod:`torch.distributed`
Yields samples ``[1, 2, ...]``
Parameters
----------
data_source
The dataset to draw the sample from.
init_epoch
The initial epoch.
on_uneven_distributed
What to do if the sampler detects that it's in a distributed environment and
the number of processes does not evenly divide the number of samples:
- :obj:`'raise'` raise a :class:`ValueError`.
- :obj:`'drop'` drop the last few samples.
- :obj:`'uneven'` allow some processes to yield fewer samples.
- :obj:`'ignore'` ignore the distributed context. Each process will yield all
samples.
See the below note for more information.
Notes
-----
The following note regards how the sampler handles :mod:`torch.distributed`.
Sequential sampling in a distributed, parallel environment is not well defined. When
`on_uneven_distributed` is :obj:`'ignore'`, each process sees all data sequentially.
As such, every process repeats the same work and returns the same value. Though
wasteful, results are likely correct, and hence easiest to adapt to from a
non-distributed codebase (e.g. with
:class:`pydrobert.torch.training.TraningStateController`). Distributed sequential
sampling may still be appropriate otherwise when ordering does not matter, such as
when an evaluation metric is computed in aggregate.
When in a distributed environment and `on_uneven_distributed` is not :obj:`'ignore'`
process ``r`` of ``W`` processes will be responsible for samples ``[r, r + W, r +
2W, ...]`` (assuming `shifting` is :obj:`False`). When the total number of samples
``N`` is divisble by ``W``, each process sees the same number of samples and all
samples are yielded by exactly one process. Assuming the quantity of interest is an
average over all samples, computing the average per process and then that averaged
over processes should yield the same results.
When ``W`` does not divide ``N`` and `on_uneven_distributed` is :obj:`'uneven'`, all
samples will be yielded by exactly one process but not all processes will yield the
same number of samples. Averaging must be performed with specialized logic; see
:class:`torch.distributed.algorithms.Join` for one option.
Finally, when ``W`` does not divide ``N`` and `on_uneven_distributed` is
:obj:`'drop'`, the last ``N % W`` samples are dropped to ensure divisibility. Each
process will see the same number of samples, but the last few samples will never
be yielded. While averaging will almost always yield a different result from the
distributed case, it may nonetheless be close when ``N % W`` is small.
"""
def __init__(
self,
data_source: Sized,
init_epoch: int = 0,
on_uneven_distributed: Literal["raise", "drop", "uneven", "ignore"] = "raise",
):
super().__init__(data_source, init_epoch, on_uneven_distributed)
def get_samples_for_epoch_ignoring_distributed(self, epoch: int) -> Iterable[int]:
return range(self.total)
H = TypeVar("H", bound=Hashable)
[docs]
class BucketBatchSampler(_BaseSampler):
"""Batch samples into buckets, yielding as soon as the bucket is full
Parameters
----------
sampler
Determines the order in which samples are put into buckets.
idx2bucket
A map specifying which bucket each sample belongs to. The keys are the indices
yielded by `sampler`; the values are the ids of the corresponding buckets.
bucket2size
A map from the bucket ids (the values in `idx2bucket`) to the corresponding
batch size. Values must be positive.
drop_incomplete
If :obj:`True`, any batches which are incomplete (smaller than the bucket's
batch size) at the end of an epoch will be discarded. Otherwise, the incomplete
batches will be yielded in the order of their bucket ids' hashes.
Yields
------
batch : list of int
A list of indices from `sampler` all belonging to the same bucket. The batch is
yielded as soon as it is full (or the epoch has ended with `drop_incomplete` set
to :obj:`False`).
Warnings
--------
:class:`BucketBatchSampler` has no :func:`__len__` method. Correctly determining the
length of the batched sampler requires knowledge of which indices of `sampler` are
being iterated over which can only be determined by iterating over the `sampler`.
Examples
--------
>>> N = 14
>>> dataset = torch.utils.data.TensorDataset(torch.rand(N))
>>> ssampler = torch.utils.data.SequentialSampler(dataset)
>>> idx2bucket = dict((n, int(n % 3 == 0)) for n in range(N))
>>> bucket2size = {0: 2, 1: 2}
>>> bsampler = BucketBatchSampler(ssampler, idx2bucket, bucket2size, True)
>>> print(list(bsampler))
[[1, 2], [0, 3], [4, 5], [7, 8], [6, 9], [10, 11]]
>>> bsampler = BucketBatchSampler(ssampler, idx2bucket, bucket2size, False)
>>> print(list(bsampler))
[[1, 2], [0, 3], [4, 5], [7, 8], [6, 9], [10, 11], [13], [12]]
"""
sampler: Collection[int]
idx2bucket: Dict[int, H]
bucket2size: Dict[H, int]
drop_incomplete: bool
def __init__(
self,
sampler: Collection[int],
idx2bucket: Dict[int, H],
bucket2size: Dict[H, int],
drop_incomplete: bool = False,
):
self.sampler = sampler
self.idx2bucket = idx2bucket
self.bucket2size = bucket2size
self.drop_incomplete = argcheck.is_bool(drop_incomplete, "drop_incomplete")
def __iter__(self) -> Iterator[List[int]]:
batches: Dict[H, List[int]] = dict()
for idx in self.sampler:
hash_ = self.idx2bucket[idx]
batch_size = self.bucket2size[hash_]
batch = batches.setdefault(hash_, [])
batch.append(idx)
if batch_size == len(batch):
yield batch
del batches[hash_]
elif batch_size < len(batch):
raise RuntimeError(f"batch '{hash_}' has invalid size '{batch_size}'")
if not self.drop_incomplete:
for _, batch in sorted(batches.items(), key=lambda x: x[0]):
yield batch
[docs]
class DataLoaderParams(param.Parameterized):
"""General parameters for a DataSet from pydrobert.torch.data
This implements the :class:`pydrobert.param.optuna.TunableParameterized`
interface.
"""
batch_size = param.Integer(
10,
bounds=(1, None),
softbounds=(5, 10),
doc="Number of elements in a batch.",
)
drop_last = param.Boolean(
False,
doc="Whether to drop a batch when there are too few samples to match its size.",
)
[docs]
@classmethod
def get_tunable(cls) -> Set[str]:
"""Returns a set of tunable parameters"""
return {"batch_size"}
[docs]
@classmethod
def suggest_params(
cls, trial, base=None, only: Container[str] = None, prefix: str = ""
):
"""Populate a parameterized instance with values from trial"""
params = cls() if base is None else base
if only is None:
only = cls.get_tunable()
if "batch_size" in only:
bounds = params.param.params()["batch_size"].get_soft_bounds()
val = trial.suggest_int(prefix + "batch_size", *bounds)
params.batch_size = val
return params
[docs]
class DynamicLengthDataLoaderParams(DataLoaderParams):
"""Parameters for a data loader whose elements have dynamic lengths"""
num_length_buckets = param.Integer(
1,
bounds=(1, None),
doc="If greater than 1, elements will be batched with other elements of "
"similar length. For SpectDataSet, length is along the feature time dimension. "
"For LangDataSet, length is the reference sequence length. Elements will be "
"partioned roughly evenly into num_length_buckets. Increasing "
"num_length_buckets will usually decrease the total amount of padding "
"per batch at the cost of fewer candidates to choose from within batches.",
)
size_batch_by_length = param.Boolean(
False,
doc="Only matters when num_length_buckets > 1. If false, all buckets have the "
"same batch size of batch_size. If true, buckets with shorter-length "
"utterances will contain greater than batch_size elements per batch. Letting "
"x be the batch size of a bucket, y be the length of the largest element in "
"the bucket, and Y be the length of the largest element in the corpus, x is "
"the greatest value such that x * y <= Y * batch_size",
)
[docs]
class LangDataLoaderParams(LangDataParams, DynamicLengthDataLoaderParams):
"""Parameters for a :class:`LangDataLoader`
This implements the :class:`pydrobert.param.optuna.TunableParameterized` interface.
"""
pass
@overload
def lang_seq_to_batch(
seq: Sequence[torch.Tensor],
batch_first: bool = True,
sort: bool = True,
has_uttids: Literal[False] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
...
@overload
def lang_seq_to_batch(
seq: Sequence[Tuple[torch.Tensor, str]],
batch_first: bool = True,
sort: bool = True,
has_uttids: Literal[True] = False,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[str, ...]]:
...
[docs]
def lang_seq_to_batch(
seq: Sequence[Union[torch.Tensor, Tuple[torch.Tensor, str]]],
batch_first: bool = True,
sort: bool = True,
has_uttids: bool = False,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, Tuple[str, ...]],
]:
"""Convert a sequence of reference sequences to a batch
This function is used to collate sequences of elements from a :class:`LangDataSet`
into batches.
Parameters
----------
seq
A finite-length (``N``) sequence of either just `ref_n` or tuples ``ref_n,
utt_n``, where
- `ref_n` is a tensor of size ``(R_n[, 3])`` representing reference token
sequences and optionally their frame shifts. Either all `ref_n` must contain
the frame shift info (the ``3`` dimension) or none of them.
- `utt_n` (if `has_uttids` is :obj:`True`) is the utterance id.
batch_first
If :obj:`True`, the batch dimension ``N`` comes before the sequence dimension
``R`` in `refs`.
sort
If :obj:`True`, the elements of `seq` are ordered in descending order of ``R_n``
before being batched.
has_uttids
Whether `utt_n` is part of the input values and `uttids` is part of the output
values.
Returns
-------
batch : tuple
A tuple of ``refs, ref_sizes[, uttids]``, where: `refs` is a tensor of shape
``(max_n R_n, N[, 3])`` containing the right-padded sequences ``[ref_1, ref_2,
..., ref_N]`` and padded with :const:`pydrobert.torch.config.INDEX_PAD_VALUE`;
`ref_sizes` is a tensor of shape ``(N,)`` containing the sequence ``[R_1, R_2,
..., R_N]``; and `uttids` (if `has_uttids` is :obj:`True`), is an ``N``-tuple of
strings matching the utterance ids.
"""
if sort and has_uttids:
seq = sorted(seq, key=lambda x: x[0].size(0), reverse=True)
elif sort:
seq = sorted(seq, key=lambda x: x.size(0), reverse=True)
if has_uttids:
refs, uttids = zip(*seq)
else:
refs = seq
ref_sizes = torch.tensor([len(x) for x in refs])
refs = torch.nn.utils.rnn.pad_sequence(
refs, padding_value=config.INDEX_PAD_VALUE, batch_first=batch_first
)
if has_uttids:
return refs, ref_sizes, tuple(uttids)
else:
return refs, ref_sizes
def _get_batch_sampler_len(batch_sampler) -> int:
if isinstance(batch_sampler, BucketBatchSampler):
bucket2count = Counter(
batch_sampler.idx2bucket[i]
for i in batch_sampler.sampler.get_samples_for_epoch(
batch_sampler.sampler.epoch
)
)
len_ = 0
for bucket, count in bucket2count.items():
size = batch_sampler.bucket2size[bucket]
if batch_sampler.drop_incomplete:
len_ += count // size
else:
len_ += (count + size - 1) // size
else:
len_ = len(batch_sampler)
return len_
[docs]
class LangDataLoader(torch.utils.data.DataLoader):
"""DataLoader for a :class:`LangDataSet`
Parameters
----------
data
Either a :class:`LangDataSet` or a path to the data directory.
params
Contains at least the parameters specific to the loader. May also contain
data set params -- see `data_params`.
data_params
Data set parameters. Relevant only when `data` is a path. Used to initialize the
underlying :class:`LangDataSet`. If :obj:`None`, `params` is assumed to also
contain the data set parameters.
shuffle
Whether utterances are shuffled at every epoch or presented sequentially.
batch_first
Whether the batch dimension comes before the sequence dimension in `refs`.
sort_batch
Whether utterances in a batch are sorted by feature length.
init_epoch
The epoch to resume from. When combined with a fixed `seed`, ensures the same
batches are always delivered for a given epoch.
seed
The initial seed used for shuffling data. If unset, a random one will be
generated.
on_uneven_distributed
What to do if the sampler detects that it's in a distributed environment and the
number of processes does not evenly divide the number of samples:
- :obj:`'raise'` raise a :class:`ValueError`.
- :obj:`'uneven'` allow some processes to yield fewer samples.
- :obj:`'ignore'` ignore the distributed context. Each process will yield all
samples.
**kwargs
Additional keyword arguments to initialize :class:`LangDataSet` and
:class:`torch.utils.data.DataLoader`. The former is only relevant when
`data` is a path.
Yields
------
batch : Union[tuple, torch.Tensor]
A tuple ``refs, ref_lens[, utt_ids]``, with the presence of `utt_ids` dependent
on `suppress_uttids` in the underlying :class:`LangDataSet` is :obj:`True`). See
:func:`lang_seq_to_batch` for more information on the elements.
"""
dataset: LangDataSet
batch_first: bool
batch_sampler: Union[BucketBatchSampler, torch.utils.data.BatchSampler]
sort_batch: bool
_len: int
def __init__(
self,
data: Union[str, LangDataSet],
params: Union[LangDataLoaderParams, DynamicLengthDataLoaderParams],
data_params: Optional[LangDataParams] = None,
shuffle: bool = True,
batch_first: bool = True,
sort_batch: bool = False,
init_epoch: int = 0,
on_uneven_distributed: Literal["raise", "unordered", "ignore"] = "raise",
seed: Optional[int] = None,
**kwargs,
):
for bad_kwarg in {
"batch_sampler",
"batch_size",
"collate_fn",
"drop_last",
"sampler",
}:
if bad_kwarg in kwargs:
raise TypeError(
f"keyword argument '{bad_kwarg}' invalid for {type(self)} types"
)
ds_kwargs, dl_kwargs = dict(), dict()
for key, val in kwargs.items():
if key in {
"file_prefix",
"file_suffix",
"suppress_uttids",
"tokens_only",
}:
ds_kwargs[key] = val
else:
dl_kwargs[key] = val
if data_params is None:
data_params = params
batch_first = argcheck.is_bool(batch_first, "batch_first")
sort_batch = argcheck.is_bool(sort_batch, "sort_batch")
if isinstance(data, LangDataSet):
dataset = data
else:
dataset = LangDataSet(
data,
params=data_params,
**ds_kwargs,
)
utt_sampler_kwargs = {"init_epoch": argcheck.is_int(init_epoch, "init_epoch")}
if params.drop_last:
utt_sampler_kwargs["on_uneven_distributed"] = "drop"
else:
utt_sampler_kwargs["on_uneven_distributed"] = on_uneven_distributed
if shuffle:
utt_sampler = EpochRandomSampler(
dataset, base_seed=seed, **utt_sampler_kwargs
)
else:
utt_sampler = EpochSequentialSampler(dataset, **utt_sampler_kwargs)
if params.num_length_buckets > 1:
idx2bucket, bucket2size = _get_bucket_batch_sampler_params(
dataset,
params.num_length_buckets,
params.batch_size,
params.size_batch_by_length,
)
batch_sampler = BucketBatchSampler(
utt_sampler,
idx2bucket,
bucket2size,
params.drop_last,
)
else:
batch_sampler = torch.utils.data.BatchSampler(
utt_sampler, params.batch_size, drop_last=params.drop_last
)
super().__init__(
dataset,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn,
**dl_kwargs,
)
self._len, self.batch_first, self.sort_batch = None, batch_first, sort_batch
def collate_fn(self, seq):
return lang_seq_to_batch(
seq,
self.batch_first,
self.sort_batch,
not self.dataset.suppress_uttids,
)
def __len__(self) -> int:
if self._len is None:
self._len = _get_batch_sampler_len(self.batch_sampler)
return self._len
@property
def epoch(self) -> int:
"""int : the current epoch"""
return self.batch_sampler.sampler.epoch
@epoch.setter
def epoch(self, val: int):
self.batch_sampler.sampler.epoch = val
[docs]
class SpectDataLoaderParams(SpectDataParams, LangDataLoaderParams):
"""Parameters for a :class:`SpectDataLoader`
This implements the :class:`pydrobert.param.optuna.TunableParameterized` interface.
"""
@classmethod
def get_tunable(cls) -> Set[str]:
return (
SpectDataParams.get_tunable() | DynamicLengthDataLoaderParams.get_tunable()
)
@classmethod
def suggest_params(
cls, trial, base=None, only: Container[str] = None, prefix: str = ""
):
params = cls() if base is None else base
SpectDataParams.suggest_params(trial, params, only, prefix)
DynamicLengthDataLoaderParams.suggest_params(trial, params, only, prefix)
return params
@overload
def spect_seq_to_batch(
seq: Sequence[Tuple[torch.Tensor, Optional[torch.Tensor]]],
batch_first: bool = True,
sort: bool = True,
has_alis: Literal[False] = True,
has_uttids: Literal[False] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]:
...
@overload
def spect_seq_to_batch(
seq: Sequence[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
batch_first: bool = True,
sort: bool = True,
has_alis: Literal[False] = True,
has_uttids: Literal[True] = False,
) -> Tuple[
torch.Tensor,
Optional[torch.Tensor],
torch.Tensor,
Optional[torch.Tensor],
Tuple[str, ...],
]:
...
@overload
def spect_seq_to_batch(
seq: Sequence[Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]],
batch_first: bool = True,
sort: bool = True,
has_alis: Literal[True] = True,
has_uttids: Literal[False] = False,
) -> Tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
torch.Tensor,
Optional[torch.Tensor],
]:
...
@overload
def spect_seq_to_batch(
seq: Sequence[
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], str]
],
batch_first: bool = True,
sort: bool = True,
has_alis: Literal[True] = True,
has_uttids: Literal[True] = False,
) -> Tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
torch.Tensor,
Optional[torch.Tensor],
Tuple[str, ...],
]:
...
[docs]
def spect_seq_to_batch(
seq: Sequence[Tuple[Union[torch.Tensor, str, None], ...]],
batch_first: bool = True,
sort: bool = True,
has_alis: bool = True,
has_uttids: bool = False,
) -> Tuple[Union[torch.Tensor, Tuple[str, ...], None], ...]:
"""Convert a sequence of spectral data to a batch
This function is used to collate sequences of elements from a :class:`SpectDataSet`
into batches.
Parameters
----------
seq
A finite-length (``N``) sequence of tuples, each tuple corresponding to an
utterance and containing, in order:
1. `feat_n`, a tensor of size ``(T_n, F)`` representing per-frame spectral
features.
2. `ali_n` (if `has_alis` is :obj:`True)`, either :obj:`None` or a tensor
of shape ``(T_n)`` representing per-frame alignment ids.
3. `ref_n`, either :obj:`None` or a tensor of size ``(R_n[, 3])`` representing
reference token sequences and optionally their frame shifts. Either all
`ref_n` must contain the frame shift info (the ``3`` dimension) or none of
them.
4. `utt_n` (if `has_uttids` is :obj:`True`), the utterance id.
batch_first
If :obj:`True`, the batch dimension ``N`` comes before the sequence dimension
``T`` or ``R`` in the return values.
sort
If :obj:`True`, the tuples in `seq` are first sorted in descending order of
``T_n`` before being batched.
has_alis
Whether `ali_n` is part of the input values and `alis` is part of the output
values. Note that `has_alis` should still be :obj:`True` if `ali_n` is present
in `seq` but is :obj:`None`.
has_uttids
Whether `utt_n` is part of the input values and `uttids` is part of the output
values.
Returns
-------
batch
A tuple containing the following elements:
1. `feats`, a tensor of shape ``(max_n T_n, N, F)`` containing the right-padded
sequences ``[feat_1, feat_2, ..., feat_N]``. Padded with zeros.
2. `alis` (if `has_alis` is :obj:`True`), either :obj:`None` or a tensor of
shape ``(max_n T_n, N)`` containing the right-padded sequence ``[ali_1,
ali_2, ... ali_N]``. Padded with
:const:`pydrobert.torch.config.INDEX_PAD_VALUE`.
3. `refs`, either :obj:`None` or a tensor of shape ``(max_n R_n, N[, 3])``
containing the right-padded sequences ``[ref_1, ref_2, ..., ref_N]``.
Padded with :const:`pydrobert.torch.config.INDEX_PAD_VALUE`.
4. `feat_sizes`, a tensor of shape ``(N,)`` containing the sequence ``[T_1, T_2,
..., T_N]``.
5. `ref_sizes`, a tensor of shape ``(N,)`` containing the sequence ``[R_1, R_2,
..., R_N]``.
6. `uttids` (if `has_uttids` is :obj:`True`), an ``N``-tuple of the utterance
ids.
"""
if sort:
seq = sorted(seq, key=lambda x: x[0].size(0), reverse=True)
seq = list(zip(*seq))
if has_alis:
if has_uttids:
feats, alis, refs, uttids = seq
else:
feats, alis, refs = seq
ali_not_none = all(x is not None for x in alis)
elif has_uttids:
feats, refs, uttids = seq
ali_not_none = False
else:
feats, refs = seq
ali_not_none = False
ref_not_none = all(x is not None for x in refs)
feat_sizes = torch.tensor([x.size(0) for x in feats])
feats = torch.nn.utils.rnn.pad_sequence(
feats, padding_value=0, batch_first=batch_first
)
if ali_not_none:
alis = torch.nn.utils.rnn.pad_sequence(
alis, padding_value=config.INDEX_PAD_VALUE, batch_first=batch_first
)
else:
alis = None
if ref_not_none:
ref_sizes = torch.tensor([len(x) for x in refs])
refs = torch.nn.utils.rnn.pad_sequence(
refs, padding_value=config.INDEX_PAD_VALUE, batch_first=batch_first
)
else:
ref_sizes = refs = None
if has_alis:
if has_uttids:
return feats, alis, refs, feat_sizes, ref_sizes, tuple(uttids)
else:
return feats, alis, refs, feat_sizes, ref_sizes
elif has_uttids:
return feats, refs, feat_sizes, ref_sizes, tuple(uttids)
else:
return feats, refs, feat_sizes, ref_sizes
def _get_bucket_batch_sampler_params(dataset, num_buckets, batch_size, dynamic):
elem_per_bucket = len(dataset) // num_buckets
if elem_per_bucket < batch_size:
warnings.warn(
f"The number of elements per bucket of the dataset ({elem_per_bucket}) "
f"is less than batch_size ({batch_size}). Consider decreasing "
"num_length_buckets"
)
len_idx = sorted((x[0].size(0), i) for (i, x) in enumerate(dataset))
len_bounds = [len_idx[(n + 1) * elem_per_bucket - 1][0] for n in range(num_buckets)]
len_bounds[-1] = len_idx[-1][0]
len_bounds_ = sorted(set(len_bounds))
if len_bounds_ != len_bounds:
warnings.warn(
f"Cannot evenly split dataset into {num_buckets} buckets. Decreasing to "
f"{len(len_bounds_)}"
)
len_bounds = len_bounds_
num_buckets = len(len_bounds)
idx2bucket = dict((i, sum(int(l > b) for b in len_bounds)) for (l, i) in len_idx)
if dynamic:
m = len_bounds[-1] * batch_size
bucket2size = dict((j, m // len_bounds[j]) for j in range(num_buckets))
else:
bucket2size = dict((j, batch_size) for j in range(num_buckets))
return idx2bucket, bucket2size
[docs]
class SpectDataLoader(torch.utils.data.DataLoader):
"""DataLoader for a :class:`SpectDataSet`
Parameters
----------
data
Either a :class:`SpectDataSet` or a path to the data directory.
params
Contains at least the parameters specific to the loader. May also contain
data set params -- see `data_params`.
data_params
Data set parameters. Relevant only when `data` is a path. Used to initialize
the underlying :class:`SpectDataSet`. If :obj:`None`, `params` is assumed to
also contain the data set parameters.
shuffle
Whether utterances are shuffled at every epoch or presented sequentially.
batch_first
Whether the batch dimension comes before the sequence dimension in `feats`
and `refs`.
sort_batch
Whether utterances in a batch are sorted by feature length.
init_epoch
The epoch to resume from. When combined with a fixed `seed`, ensures the same
batches are always delivered for a given epoch.
seed
The initial seed used for shuffling data. If unset, a random one will be
generated.
on_uneven_distributed
What to do if the sampler detects that it's in a distributed environment and the
number of processes does not evenly divide the number of samples:
- :obj:`'raise'` raise a :class:`ValueError`.
- :obj:`'uneven'` allow some processes to yield fewer samples.
- :obj:`'ignore'` ignore the distributed context. Each process will yield all
samples.
**kwargs
Additional keyword arguments to initialize :class:`SpectDataSet` and
:class:`torch.utils.data.DataLoader`. The former is only relevant when
`data` is a path.
Warnings
--------
:class:`SpectDataLoader` uses the default :obj:`True` for `suppress_alis` and
`tokens_only` while the current, deprecated default used by :class:`SpectDataSet`
is :obj:`False`.
Yields
------
batch
A tuple ``feats[, alis,] refs, feat_sizes, ref_sizes[, uttids]``, with `alis`
included if `suppress_alis` is :obj:`False` and `uttids` included if
`suppress_uttids` is :obj:`False`. See :func:`spect_seq_to_batch` for more
information on the elements.
"""
dataset: SpectDataSet
batch_first: bool
batch_sampler: Union[BucketBatchSampler, torch.utils.data.BatchSampler]
sort_batch: bool
_len: int
def __init__(
self,
data: Union[str, SpectDataSet],
params: Union[SpectDataLoaderParams, DynamicLengthDataLoaderParams],
data_params: Optional[SpectDataParams] = None,
shuffle: bool = True,
batch_first: bool = True,
sort_batch: bool = False,
init_epoch: int = 0,
on_uneven_distributed: Literal["raise", "unordered", "ignore"] = "raise",
seed: Optional[int] = None,
**kwargs,
):
for bad_kwarg in {
"batch_sampler",
"batch_size",
"collate_fn",
"drop_last",
"sampler",
}:
if bad_kwarg in kwargs:
raise TypeError(
f"keyword argument '{bad_kwarg}' invalid for {type(self)} types"
)
ds_kwargs, dl_kwargs = dict(), dict()
for key, val in kwargs.items():
if key in {
"file_prefix",
"file_suffix",
"warn_on_missing",
"subset_ids",
"sos",
"eos",
"feat_subdir",
"ali_subdir",
"ref_subdir",
"feat_mean",
"feat_std",
"suppress_alis",
"suppress_uttids",
"tokens_only",
}:
ds_kwargs[key] = val
else:
dl_kwargs[key] = val
if not isinstance(
params, (DynamicLengthDataLoaderParams, SpectDataLoaderParams)
) and isinstance(params, DataLoaderParams):
warnings.warn(
"Passing a DataLoaderParams instance as params is deprecated. "
"Switch to DynamicLengthDataLoaderParams.",
DeprecationWarning,
)
num_length_buckets = 1
else:
num_length_buckets = params.num_length_buckets
if data_params is None:
data_params = params
elif hasattr(params, "subset_ids"):
subset_ids = params.subset_ids
if subset_ids:
warnings.warn(
"setting subset_ids in data loader parameters is deprecated. "
"Use data_params.subset_ids instead.",
DeprecationWarning,
)
data_params.subset_ids = subset_ids
batch_first = argcheck.is_bool(batch_first, "batch_first")
sort_batch = argcheck.is_bool(sort_batch, "sort_batch")
if isinstance(data, SpectDataSet):
dataset = data
else:
suppress_alis = ds_kwargs.pop("suppress_alis", True)
tokens_only = ds_kwargs.pop("tokens_only", True)
dataset = SpectDataSet(
data,
params=data_params,
suppress_alis=suppress_alis,
tokens_only=tokens_only,
**ds_kwargs,
)
utt_sampler_kwargs = {"init_epoch": init_epoch}
if params.drop_last:
utt_sampler_kwargs["on_uneven_distributed"] = "drop"
else:
utt_sampler_kwargs["on_uneven_distributed"] = on_uneven_distributed
if shuffle:
utt_sampler = EpochRandomSampler(
dataset, base_seed=seed, **utt_sampler_kwargs
)
else:
utt_sampler = EpochSequentialSampler(dataset, **utt_sampler_kwargs)
if num_length_buckets > 1:
idx2bucket, bucket2size = _get_bucket_batch_sampler_params(
dataset,
num_length_buckets,
params.batch_size,
params.size_batch_by_length,
)
batch_sampler = BucketBatchSampler(
utt_sampler,
idx2bucket,
bucket2size,
params.drop_last,
)
else:
batch_sampler = torch.utils.data.BatchSampler(
utt_sampler, params.batch_size, drop_last=params.drop_last
)
super().__init__(
dataset,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn,
**dl_kwargs,
)
self._len, self.batch_first, self.sort_batch = None, batch_first, sort_batch
def collate_fn(self, seq):
return spect_seq_to_batch(
seq,
self.batch_first,
self.sort_batch,
not self.dataset.suppress_alis,
not self.dataset.suppress_uttids,
)
def __len__(self) -> int:
if self._len is None:
self._len = _get_batch_sampler_len(self.batch_sampler)
return self._len
@property
def epoch(self) -> int:
"""int : the current epoch"""
return self.batch_sampler.sampler.epoch
@epoch.setter
def epoch(self, val: int):
self.batch_sampler.sampler.epoch = val
class SpectTrainingDataLoader(SpectDataLoader):
"""Serves batches of spectral data over random orders of utterances
Deprecated. Use :class:`SpectDataLoader`.
"""
def __init__(
self,
data: Union[str, SpectDataSet],
params: Union[SpectDataLoaderParams, DynamicLengthDataLoaderParams],
file_prefix: str = config.DEFT_FILE_PREFIX,
file_suffix: str = config.DEFT_FILE_SUFFIX,
warn_on_missing: bool = True,
feat_subdir: str = config.DEFT_FEAT_SUBDIR,
ali_subdir: str = config.DEFT_ALI_SUBDIR,
ref_subdir: str = config.DEFT_REF_SUBDIR,
init_epoch: int = 0,
batch_first: bool = True,
data_params: Optional[SpectDataParams] = None,
seed: Optional[int] = None,
**kwargs,
):
warnings.warn(
"SpectTrainingDataLoader is deprecated. Use SpectDataLoader instead",
DeprecationWarning,
)
shuffle = kwargs.pop("shuffle", True)
suppress_alis = kwargs.pop("suppress_alis", False)
suppress_uttids = kwargs.pop("suppress_uttids", True)
tokens_only = kwargs.pop("tokens_only", False)
sort_batch = kwargs.pop("sort_batch", True)
super().__init__(
data,
params,
data_params,
shuffle,
batch_first,
sort_batch,
init_epoch,
seed,
file_prefix=file_prefix,
file_suffix=file_suffix,
warn_on_missing=warn_on_missing,
feat_subdir=feat_subdir,
ali_subdir=ali_subdir,
ref_subdir=ref_subdir,
suppress_alis=suppress_alis,
suppress_uttids=suppress_uttids,
tokens_only=tokens_only,
**kwargs,
)
class SpectEvaluationDataLoader(SpectDataLoader):
"""Serves batches of spectral data over random orders of utterances
Deprecated. Use :class:`SpectDataLoader`.
"""
def __init__(
self,
data: Union[str, SpectDataSet],
params: Union[SpectDataLoaderParams, DynamicLengthDataLoaderParams],
file_prefix: str = config.DEFT_FILE_SUFFIX,
file_suffix: str = config.DEFT_FILE_SUFFIX,
warn_on_missing: bool = True,
feat_subdir: str = config.DEFT_FEAT_SUBDIR,
ali_subdir: str = config.DEFT_ALI_SUBDIR,
ref_subdir: str = config.DEFT_REF_SUBDIR,
batch_first: bool = True,
data_params: Optional[SpectDataParams] = None,
**kwargs,
):
warnings.warn(
"SpectEvaluationDataLoader is deprecated. Use SpectDataLoader instead",
DeprecationWarning,
)
shuffle = kwargs.pop("shuffle", False)
suppress_alis = kwargs.pop("suppress_alis", False)
suppress_uttids = kwargs.pop("suppress_uttids", False)
tokens_only = kwargs.pop("tokens_only", False)
init_epoch = kwargs.pop("init_epoch", 0)
seed = kwargs.pop("seed", None)
sort_batch = kwargs.pop("sort_batch", True)
super().__init__(
data,
params,
data_params,
shuffle,
batch_first,
sort_batch,
init_epoch,
seed,
file_prefix=file_prefix,
file_suffix=file_suffix,
warn_on_missing=warn_on_missing,
feat_subdir=feat_subdir,
ali_subdir=ali_subdir,
ref_subdir=ref_subdir,
suppress_alis=suppress_alis,
suppress_uttids=suppress_uttids,
tokens_only=tokens_only,
**kwargs,
)
[docs]
def context_window_seq_to_batch(
seq: Sequence[Tuple[Union[torch.Tensor, str, None], ...]], has_uttids: bool = False
) -> Tuple[Union[torch.Tensor, Sequence[str], None], ...]:
r"""Convert a sequence of context window elements to a batch
This function is used to collate sequences of elements from a
:class:`ContextWindowDataSet` into batches.
Assume `seq` is a finite length sequence of pairs of ``window, ali``, where
``window`` is of size ``(T, C, F)``, where ``T`` is some number of windows (which
can vary across elements in the sequence), ``C`` is the window size, and ``F`` is
some number filters, and ``ali`` is of size ``(T,)``. This method batches all the
elements of the sequence into a pair of ``windows, alis``, where `windows` and
`alis` will have shapes ``(N, C, F)`` and ``(N,)`` resp., where :math:`N = \sum T`
is the total number of context windows over the utterances.
If ``ali`` is :obj:`None` in any element, `alis` will also be :obj:`None`
Parameters
----------
seq
A finite-length (``N``) sequence of tuples, each tuple corresponding to an
utterance and containing, in order:
1. `window_n`, a tensor of size ``(T_n, C, F)`` representing windowed spectral
features.
2. `ali_n`, either :obj:`None` or a tensor of shape ``(T_n,)`` representing
per-window alignment ids.
3. `uttid_n` (if :obj:`has_refs` is :obj:`True`), the utterance id.
has_uttids
Whether `utt_n` is part of the input values and both `window_sizes` and `uttids`
are part of the output values.
Returns
-------
batch
A tuple containing the following elements:
1. `windows`, a tensor of shape ``(sum_n T_n, C, F)`` containing the
concatenated set of windows ``[window_1, window_2, ..., window_N]``
2. `alis`, either :obj:`None` or a tensor of shape ``(sum_n T_n,)`` containing
the concatenated alignment ids ``[ali_1, ali_2, ..., ali_N]``.
3. `window_sizes` (if `has_uttids` is :obj:`True`), a tensor of shape ``(N,)``
containing the sequence ``[T_1, T_2, ..., T_N]``.
4. `uttids` (if `has_uttids` is :obj:`True`), an ``N``-tuple of utterance ids.
"""
seq = list(zip(*seq))
if has_uttids:
windows, alis, uttids = seq
window_sizes = torch.tensor([w.size(0) for w in windows])
else:
windows, alis = seq
windows = torch.cat(windows)
if all(a is not None for a in alis):
alis = torch.cat(alis)
else:
alis = None
if has_uttids:
return windows, alis, window_sizes, tuple(uttids)
else:
return windows, alis
[docs]
class ContextWindowDataLoaderParams(ContextWindowDataParams, DataLoaderParams):
"""Parameters for a :class:`ContextWindowDataLoader`
This implements the :class:`pydrobert.param.optuna.TunableParameterized` interface.
"""
@classmethod
def get_tunable(cls):
"""Returns a set of tunable parameters"""
return DataLoaderParams.get_tunable() | ContextWindowDataParams.get_tunable()
@classmethod
def suggest_params(cls, trial, base=None, only=None, prefix=""):
"""Populate a parameterized instance with values from trial"""
params = DataLoaderParams.suggest_params(
trial, base=base, only=only, prefix=prefix
)
params = ContextWindowDataParams.suggest_params(
trial, base=params, only=only, prefix=prefix
)
return params
@classmethod
def get_tunable(cls) -> Set[str]:
return ContextWindowDataParams.get_tunable() | DataLoaderParams.get_tunable()
@classmethod
def suggest_params(
cls, trial, base=None, only: Container[str] = None, prefix: str = ""
):
params = cls() if base is None else base
ContextWindowDataParams.suggest_params(trial, params, only, prefix)
DataLoaderParams.suggest_params(trial, params, only, prefix)
return params
[docs]
class ContextWindowDataLoader(torch.utils.data.DataLoader):
"""DataLoader for :class:`ContextWindowDataSet`
Parameters
----------
data
Either a :class:`ContextWindowDataSet` or a path to the data directory.
params
Contains at least the parameters specific to the loader. May also contain
data set params --- see `data_params`.
data_params
Data set parameters. Relevant only when `data` is a path. Used to initialize
the underlying :class:`ContextWindowDataset`. If :obj:`None`, `params` is
assumed to also contain the data set parameters.
shuffle
Whether utterances are shuffled at every epoch or presented sequentially.
sort_batch
Whether utterances in a batch are sorted by feature length.
init_epoch
The epoch to resume from. When combined with a fixed `seed`, ensures the same
batches are always delivered for a given epoch.
seed
The initial seed used for shuffling data. If unset, a random one will be
generated.
**kwargs
Additional keyword arguments to initialize :class:`ContextWindowDataSet` and
:class:`torch.utils.data.DataLoader`. The former is only relevant when `data` is
a path.
Yields
------
batch
A tuple ``windows, alis[, window_sizes, uttids]``, with `window_sizes` and
`uttids` included if `suppress_uttids` is :obj:`False`. See
:func:`context_window_seq_to_batch` for more information on the elements.
Warnings
--------
This class does not currently support :mod:`torch.distributed`. Each process will
return the same batches.
"""
dataset: ContextWindowDataSet
batch_sampler: torch.utils.data.BatchSampler
batch_first: bool
def collate_fn(self, seq):
return context_window_seq_to_batch(seq, not self.dataset.suppress_uttids)
def __init__(
self,
data: Union[str, ContextWindowDataSet],
params: Union[ContextWindowDataLoaderParams, DataLoaderParams],
data_params: Optional[ContextWindowDataParams] = None,
shuffle: bool = True,
init_epoch: int = 0,
seed: Optional[int] = None,
**kwargs,
):
for bad_kwarg in (
"batch_size",
"sampler",
"batch_sampler",
"collate_fn",
"drop_last",
):
if bad_kwarg in kwargs:
raise TypeError(
'keyword argument "{}" invalid for {} types'.format(
bad_kwarg, type(self)
)
)
ds_kwargs, dl_kwargs = dict(), dict()
for key, val in kwargs.items():
if key in {
"left",
"right",
"file_prefix",
"file_suffix",
"warn_on_missing",
"subset_ids",
"feat_subdir",
"ali_subdir",
"reverse",
"feat_mean",
"feat_std",
"suppress_uttids",
}:
ds_kwargs[key] = val
else:
dl_kwargs[key] = val
if seed is None and hasattr(params, "seed"):
seed = params.seed
if data_params is None:
data_params = params
else:
if hasattr(params, "subset_ids"):
subset_ids = params.subset_ids
if subset_ids:
warnings.warn(
"setting subset_ids in data loader parameters is deprecated. "
"Use data_params.subset_ids instead.",
DeprecationWarning,
2,
)
data_params.subset_ids = subset_ids
if isinstance(data, ContextWindowDataSet):
dataset = data
data_dir = data.data_dir
else:
data_dir = data
dataset = ContextWindowDataSet(data_dir, params=data_params, **ds_kwargs)
if shuffle:
utt_sampler = EpochRandomSampler(dataset, init_epoch, seed, "ignore")
else:
utt_sampler = EpochSequentialSampler(dataset, init_epoch, "ignore")
batch_sampler = torch.utils.data.BatchSampler(
utt_sampler, params.batch_size, drop_last=params.drop_last
)
super().__init__(
dataset,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn,
**dl_kwargs,
)
def __len__(self) -> int:
return len(self.batch_sampler)
@property
def epoch(self) -> int:
"""int : the current epoch"""
return self.batch_sampler.sampler.epoch
@epoch.setter
def epoch(self, val: int):
self.batch_sampler.sampler.epoch = val
class ContextWindowTrainingDataLoader(ContextWindowDataLoader):
"""Serve batches of context windows over a random order of utterances
Deprecated. Use :class:`ContextWindowDataLoader`.
"""
def __init__(
self,
data: Union[str, ContextWindowDataSet],
params: Union[ContextWindowDataLoaderParams, DataLoaderParams],
file_prefix: str = config.DEFT_FILE_PREFIX,
file_suffix: str = config.DEFT_FILE_SUFFIX,
warn_on_missing: bool = True,
feat_subdir: str = config.DEFT_FEAT_SUBDIR,
ali_subdir: str = config.DEFT_ALI_SUBDIR,
init_epoch: int = 0,
data_params: Optional[ContextWindowDataParams] = None,
seed: Optional[int] = None,
**kwargs,
):
warnings.warn(
"ContextWindowTrainingDataLoader is deprecated. Use "
"ContextWindowDataLoader instead",
DeprecationWarning,
)
shuffle = kwargs.pop("shuffle", True)
suppress_uttids = kwargs.pop("suppress_uttids", True)
super().__init__(
data,
params,
data_params,
shuffle,
init_epoch,
seed,
file_prefix=file_prefix,
file_suffix=file_suffix,
warn_on_missing=warn_on_missing,
feat_subdir=feat_subdir,
ali_subdir=ali_subdir,
suppress_uttids=suppress_uttids,
**kwargs,
)
class ContextWindowEvaluationDataLoader(ContextWindowDataLoader):
"""Serves batches of context windows over sequential utterances
Deprecated. Use :class:`ContextWindowDataLoader`.
"""
def __init__(
self,
data: Union[str, ContextWindowDataSet],
params: Union[ContextWindowDataLoaderParams, DataLoaderParams],
file_prefix: str = config.DEFT_FILE_PREFIX,
file_suffix: str = config.DEFT_FILE_SUFFIX,
warn_on_missing: bool = True,
feat_subdir: str = config.DEFT_FEAT_SUBDIR,
ali_subdir: str = config.DEFT_ALI_SUBDIR,
data_params: Optional[ContextWindowDataParams] = None,
**kwargs,
):
warnings.warn(
"ContextWindowEvaluationDataLoader is deprecated. Use "
"ContextWindowDataLoader instead",
DeprecationWarning,
)
shuffle = kwargs.pop("shuffle", False)
suppress_uttids = kwargs.pop("suppress_uttids", False)
init_epoch = kwargs.pop("init_epoch", 0)
seed = kwargs.pop("seed", None)
super().__init__(
data,
params,
data_params,
shuffle,
init_epoch,
seed,
file_prefix=file_prefix,
file_suffix=file_suffix,
warn_on_missing=warn_on_missing,
feat_subdir=feat_subdir,
ali_subdir=ali_subdir,
suppress_uttids=suppress_uttids,
**kwargs,
)