# Copyright 2023 Sean Robertson
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, overload
from typing_extensions import Literal, get_args
import torch
from . import config, argcheck
from ._compat import script, movedim
from ._wrappers import functional_wrapper, proxy
# FIXME(sdrobert): this should be traceable through the module, but this version of
# pytorch (1.8.1) isn't getting it
[docs]
@script
@functional_wrapper("MeanVarianceNormalization")
def mean_var_norm(
x: torch.Tensor,
dim: int = -1,
mean: Optional[torch.Tensor] = None,
std: Optional[torch.Tensor] = None,
eps: float = config.TINY,
):
D = x.ndim
if dim < -D or dim > D - 1:
raise IndexError(
f"Dimension out of range (expected to be in the range of [{-D},{D - 1}], "
f"got {dim})"
)
dim = (dim + D) % D
X = x.size(dim)
shape = [1] * D
shape[dim] = X
dtype = x.dtype
if mean is None:
mean = x.transpose(0, dim).unsqueeze(-1).flatten(1).double().mean(1)
x = x - mean.view(shape).to(dtype)
if std is None:
std = x.transpose(0, dim).unsqueeze(-1).flatten(1).double().std(1, False)
return (x / std.view(shape).to(x).clamp_min(eps)).to(dtype)
[docs]
class MeanVarianceNormalization(torch.nn.Module):
"""Normalize features according to mean and variance statistics
Given input `x`, population mean `mean`, population standard deviation `std`, and
some small value `eps`, mean-variance normalization for the ``i``-th element of the
`dim`-th dimension of `x` is defined as
::
y[..., i, ...] = (x[..., i, ...] - mean[i]) / max(std[i], eps).
The `mean` and `std` vectors can be acquired in three ways.
First, they may be passed directly to this module on initialization.
Second, if `mean` and `std` were not specified, they can be estimated from the
(biased) sample statistics of `x`. This is the same as unit normalization.
Third, they may be estimated from multiple instances of `x` by accumulating
sufficient statistics with the :func:`accumulate` method, then writing the biased
estimates with the :func:`store` method.
Parameters
----------
dim
The dimension to be normalized. All other dimensions are considered
mean
If set, a vector representing the population mean. The same size as `std`, if
specified.
std
If set, a vector representing the population standard deviation. The same size
as `mean`, if specified.
eps
A small non-negative floating-point value which ensures nonzero division if
positive.
Call Parameters
---------------
x : torch.Tensor
A tensor whose `dim`-th dimension is the same size as `mean` and `std`. To be
normalized.
Returns
-------
y : torch.Tensor
The normalized tensor of the same shape as `x`.
Examples
--------
>>> x = torch.arange(1000, dtype=torch.float).view(10, 10, 10)
>>> mean = x.flatten(0, 1).double().mean(0)
>>> std = x.flatten(0, 1).double().std(0, unbiased=False)
>>> y = MeanVarianceNormalization(-1, mean, std)(x)
>>> assert torch.allclose(y.flatten(0, 1).mean(0), torch.zeros(1))
>>> assert torch.allclose(y.flatten(0, 1).std(0, unbiased=False), torch.ones(1))
>>> mvn = MeanVarianceNormalization()
>>> y2 = mvn(x)
>>> assert torch.allclose(y, y2)
>>> for x_n in x:
... mvn.accumulate(x_n)
>>> mvn.store()
>>> assert torch.allclose(mvn.mean, mean)
>>> assert torch.allclose(mvn.std, std)
>>> y2 = mvn(x)
>>> assert torch.allclose(y, y2)
"""
__constants__ = ["dim", "eps"]
dim: int
eps: float
mean: Optional[torch.Tensor]
std: Optional[torch.Tensor]
count: Optional[torch.Tensor]
sum: Optional[torch.Tensor]
sumsq: Optional[torch.Tensor]
def __init__(
self,
dim: int = -1,
mean: Optional[torch.Tensor] = None,
std: Optional[torch.Tensor] = None,
eps: float = config.TINY,
):
dim = argcheck.is_int(dim, "dim")
if mean is not None:
mean = argcheck.is_tensor(mean, "mean")
mean = argcheck.has_ndim(mean, 1, "mean")
mean = argcheck.is_nonempty(mean, "mean")
if std is not None:
std = argcheck.is_tensor(std, "std")
std = argcheck.has_ndim(std, 1, "std")
std = argcheck.is_nonempty(std, "std")
if mean is not None and mean.size(0) != std.size(0):
raise ValueError(
"mean and std must be of the same length if both specified, got "
f"{mean.size(0)} and {std.size(0)}, respectively"
)
eps = argcheck.is_nonneg(eps, "eps")
super().__init__()
self.dim, self.eps = dim, eps
self.register_buffer("mean", mean)
self.register_buffer("std", std)
self.register_buffer("sum", None)
self.register_buffer("sumsq", None)
self.register_buffer("count", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return mean_var_norm(x, self.dim, self.mean, self.std, self.eps)
__call__ = proxy(forward)
@torch.jit.export
def accumulate(self, x: torch.Tensor) -> None:
"""Accumulate statistics about mean and variance of input"""
if self.count is None:
assert self.sum is None and self.sumsq is None
X = x.size(self.dim)
self.count = torch.zeros(1, dtype=torch.double, device=x.device)
self.sum = torch.zeros(X, dtype=torch.double, device=x.device)
self.sumsq = torch.zeros(X, dtype=torch.double, device=x.device)
# XXX(sdrobert): this is so that torchscript can figure out the type refinement
count, sum_, sumsq = self.count, self.sum, self.sumsq
assert (
isinstance(count, torch.Tensor)
and isinstance(sum_, torch.Tensor)
and isinstance(sumsq, torch.Tensor)
)
x = x.transpose(0, self.dim).unsqueeze(-1).flatten(1)
count += x.size(1)
sum_ += x.sum(1)
sumsq += x.square().sum(1)
def extra_repr(self) -> str:
return f"dim={self.dim}, eps={self.eps:e}"
@torch.jit.export
def store(self, delete_stats: bool = True, bessel: bool = False) -> None:
"""Store mean and variance in internal buffers using accumulated statistics
Overwrites whatever mean and variance were previously stored in internal buffers
with those based off calls to :func:`accumulate`.
Parameters
----------
delete_stats
Whether to delete the accumulated statistics from internal buffers after the
mean and variance are stored.
bessel
Whether to perform `Bessel's correction
<https://en.wikipedia.org/wiki/Bessel's_correction>`__ on the variance.
Raises
------
RuntimeError
If the count of samples is too small to make the estimate. At least one
accumulated sample is necessary with `bessel` :obj:`False`; two if
:obj:`True`.
"""
if self.count is None:
raise RuntimeError("Too few accumulated statistics")
count, sum_, sumsq = self.count, self.sum, self.sumsq
assert (
isinstance(count, torch.Tensor)
and isinstance(sum_, torch.Tensor)
and isinstance(sumsq, torch.Tensor)
)
if count < 2:
raise RuntimeError("Too few accumulated statistics")
self.mean = mean = sum_ / count
var = sumsq / count - mean.square()
if bessel:
var *= count / (count - 1)
self.std = var.sqrt_()
if delete_stats:
self.sum = self.sumsq = self.count = None
def _feat_delta_filters(order: int, width: int) -> torch.Tensor:
if order < 0:
raise RuntimeError(f"order must be non-negative, got {order}")
if width < 1:
raise RuntimeError(f"width must be positive, got {width}")
last_filt = torch.zeros(1 + (2 * width) * order)
last_filt[width * order] = 1
filts = [last_filt]
if order == 0:
return last_filt.unsqueeze(0)
kernel = torch.arange(width, -width - 1, -1, dtype=torch.float)
kernel /= kernel.square().sum()
for _ in range(order):
last_filt = torch.nn.functional.conv1d(
last_filt.view(1, 1, -1), kernel.view(1, 1, -1), padding=width
).flatten()
filts.append(last_filt)
return torch.stack(filts)
[docs]
@script
@functional_wrapper("FeatureDeltas")
def feat_deltas(
x: torch.Tensor,
dim: int = -1,
time_dim: int = -2,
concatenate: bool = True,
order: int = 2,
width: int = 2,
pad_mode: str = "replicate",
value: float = config.DEFT_PAD_VALUE,
_filters: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if _filters is None:
_filters = _feat_delta_filters(order, width).to(x)
else:
assert _filters.shape == (order + 1, 1 + (2 * width) * order)
D = x.ndim
if time_dim < -D or time_dim >= D:
raise RuntimeError(
f"Expected dimension 'time_dim' to be in [{-D}, {D-1}], got " f"{time_dim}"
)
time_dim = (time_dim + D) % D
if not concatenate:
D += 1
if dim < -D or dim >= D:
raise RuntimeError(
f"Expected dimension 'dim' to be in [{-D}, {D-1}], got {dim}"
)
dim = (dim + D) % D
x = x.transpose(time_dim, -1)
shape = x.shape
x = x.unsqueeze(0).flatten(0, -2).unsqueeze(1)
if width:
x = torch.nn.functional.pad(x, (width * order, width * order), pad_mode, value)
x = torch.nn.functional.conv1d(x, _filters.unsqueeze(1))
x = x.view(shape[:-1] + (order + 1,) + shape[-1:])
x = x.transpose(-2, -1).transpose(time_dim, -2)
# the order dimension is moving right-to-left, so the original inhabitant of that
# dimension (if any) should be to its right after the move.
x = movedim(x, -1, dim)
if concatenate:
x = x.flatten(dim, dim + 1)
return x
PadMode = Literal["replicate", "constant", "reflect", "circular"]
[docs]
class FeatureDeltas(torch.nn.Module):
r"""Compute deltas of features
Letting :math:`x` be some input tensor with the `time_dim`-th dimension representing
the evolution of features over time. Denote that dimension with indices :math:`t`
and the dimension of the `order` of the deltas with :math:`u`. The :math:`0`-th
order deltas are just :math:`x` itself; higher order deltas are calculated
recursively as
.. math::
x[t, u] = \sum_{w=-width}^{width} x[t + w, u - 1] \frac{w}{\sum_{w'} w'^2}.
Deltas can be seen as a rolling averages: first-order deltas akin to first-order
derivatives; second-order to second-order, and so on.
Parameters
----------
dim
The dimension along which resulting deltas will be stored.
time_dim
The dimension along which deltas are calculated.
concatenate
If :obj:`True`, delta orders are merged into a single axis with the previous
occupants of the dimension `dim` via concatenation. Otherwise, a new dimension
is stacked into the location `dim`.
order
The non-negative maximum order of deltas.
width
Controls the width of the averaging window.
pad_mode
How to pad edges to ensure the same size output. See
:func:`torch.nn.functional.pad` for more details.
value
The value used in constant padding.
Call Parameters
---------------
x : torch.Tensor
Returns
-------
deltas : torch.Tensor
Has the same shape as `x` except for one dimension. If `concatenate` is false,
a new dimension is inserted into `x` at position `dim` of size ``order + 1``. If
`concatenate` is true, then the `dim`-th dimension of `deltas` is ``order + 1``
times the length of that of `x`.
"""
__constants__ = [
"dim",
"time_dim",
"concatenate",
"order",
"width",
"pad_mode",
"value",
]
dim: int
time_dim: int
order: int
width: int
pad_mode: str
value: float
filters: torch.Tensor
def __init__(
self,
dim: int = -1,
time_dim: int = -2,
concatenate: bool = True,
order: int = 2,
width: int = 2,
pad_mode: PadMode = "replicate",
value: float = config.DEFT_PAD_VALUE,
):
dim = argcheck.is_int(dim, "dim")
time_dim = argcheck.is_int(time_dim, "time_dim")
concatenate = argcheck.is_bool(concatenate, "concatenate")
order = argcheck.is_nonnegi(order, "order")
pad_mode = argcheck.is_in(pad_mode, get_args(PadMode), pad_mode)
super().__init__()
self.register_buffer("filters", _feat_delta_filters(order, width))
self.dim, self.time_dim, self.value = dim, time_dim, value
self.order, self.width, self.pad_mode = order, width, pad_mode
self.concatenate = concatenate
def forward(self, x: torch.Tensor) -> torch.Tensor:
return feat_deltas(
x,
self.dim,
self.time_dim,
self.concatenate,
self.order,
self.width,
self.pad_mode,
self.value,
self.filters,
)
def extra_repr(self) -> str:
return (
f"dim={self.dim}, time_dim={self.time_dim}, "
f"concatenate={self.concatenate}, order={self.order}, width={self.width}, "
f"pad_mode={self.pad_mode}, value={self.value}"
)
__call__ = proxy(forward)
Policy = Literal["fixed", "ali", "ref"]
WindowType = Literal["symmetric", "causal", "future"]
@overload
def slice_spect_data(
input: torch.Tensor,
in_lens: Optional[torch.Tensor] = None,
other_lens: Optional[torch.Tensor] = None,
policy: Policy = "fixed",
window_type: WindowType = "symmetric",
valid_only: bool = True,
lobe_size: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
...
[docs]
@script
@functional_wrapper("SliceSpectData")
def slice_spect_data(
input: torch.Tensor,
in_lens: Optional[torch.Tensor] = None,
other_lens: Optional[torch.Tensor] = None,
policy: str = "fixed",
window_type: str = "symmetric",
valid_only: bool = True,
lobe_size: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
if input.ndim < 2:
raise RuntimeError(
f"Expected input to be at least 2-dimensional; got {input.ndim}"
)
N, T = input.shape[:2]
device = input.device
if not T:
return (
torch.empty(0, 2, dtype=torch.long, device=device),
torch.empty(0, dtype=torch.long, device=device),
)
if lobe_size < 0:
raise RuntimeError(f"Expected non-negative lobe_size, got {lobe_size}")
if window_type not in ("symmetric", "causal", "future"):
raise RuntimeError(
"expected window_type to be one of 'symmetric', 'casual', or 'future'"
f"got '{window_type}'"
)
if policy == "fixed":
shift = lobe_size + 1
if valid_only and window_type == "symmetric":
window_size = 2 * lobe_size + 1
starts = torch.arange(0, max(T - window_size + 1, 0), shift, device=device)
ends = starts + window_size
mids = ends - 1
elif window_type == "symmetric":
window_size = 2 * lobe_size + 1
half_shift = shift // 2
TT = (T + half_shift) // shift
mids = torch.arange(TT, device=device) * shift + half_shift
starts = mids - window_size // 2
ends = starts + window_size
elif valid_only:
# the behaviour doesn't change with "causal" or "future" when valid_only
starts = torch.arange(0, max(T - lobe_size, 0), shift, device=device)
ends = starts + shift
mids = ends - 1
elif window_type == "causal":
starts = torch.arange(-lobe_size, T - lobe_size, shift, device=device)
ends = starts + shift
mids = ends - 1
else: # future
starts = mids = torch.arange(0, T, shift, device=device)
ends = starts + shift
starts, ends = starts.expand(N, -1), ends.expand(N, -1)
# starts = starts.clamp_min_(0).expand(N, -1)
# if in_lens is None:
# ends = ends.clamp_max_(T).expand(N, -1)
# else:
# ends = torch.min(ends.unsqueeze(0), in_lens.unsqueeze(1))
TT = starts.size(1)
slices = torch.stack([starts, ends], 2).flatten(end_dim=1)
sources = torch.arange(N, device=device).view(N, 1).expand(N, TT).flatten()
if in_lens is not None:
if in_lens.shape != (N,):
raise RuntimeError(
f"Expected in_lens to be of shape ({N},); got {in_lens.shape}"
)
mask = (in_lens.unsqueeze(1) > mids).flatten()
slices = slices[mask]
sources = sources[mask]
elif policy == "ali":
if input.ndim != 2:
raise RuntimeError(f"expected tensor of dimension 2 with policy 'ali'")
mask = input[:, :-1] != input[:, 1:]
arange = torch.arange(T, device=device)
if in_lens is not None:
if in_lens.shape != (N,):
raise RuntimeError(
f"Expected in_lens to be of shape ({N},); got {in_lens.shape}"
)
mask = mask & (in_lens.view(N, 1) > arange[1:])
else:
in_lens = torch.full((N,), T, device=device)
nonempty = (in_lens > 0).view(N, 1)
starts = torch.cat([nonempty, mask], 1).nonzero()
mask = torch.cat([torch.zeros_like(nonempty), mask], 1)
mask = mask | (nonempty & (in_lens.view(N, 1) == arange))
ends = mask.nonzero()
sources = starts[:, 0]
starts, ends = starts[:, 1], ends[:, 1]
if lobe_size:
NN = starts.size(0)
do_left = window_type in ("symmetric", "causal")
do_right = window_type in ("symmetric", "future")
if valid_only:
offs = (int(do_left) + int(do_right)) * lobe_size
is_same = sources[: NN - offs] == sources[offs:]
starts = starts[: NN - offs][is_same]
ends = ends[offs:][is_same]
sources = sources[: NN - offs][is_same]
else:
start_idx = torch.arange(NN, device=device)
end_idx = start_idx.clone()
for n in range(1, lobe_size + 1):
offs = (sources[n:] == sources[: NN - n]).long()
if do_left:
start_idx[n:] -= offs
if do_right:
end_idx[: NN - n] += offs
starts = starts[start_idx]
ends = ends[end_idx]
slices = torch.stack([starts, ends], 1)
elif policy == "ref":
if input.ndim != 3:
raise RuntimeError(f"Expected input to be 3-dimensional, got {input.ndim}")
if input.size(2) != 3:
raise RuntimeError(
f"Expected 3rd dimension of input to be of size 3, got {input.size(2)}"
)
starts = input[..., 1]
ends = input[..., 2]
if in_lens is None:
in_lens = torch.full((N,), T, device=device)
if other_lens is None:
# the final segment's end time
other_lens = (
ends[..., 1]
.gather(1, (in_lens - 1).clamp_min_(0).view(N, 1))
.squeeze(1)
.masked_fill_(in_lens == 0, 0)
)
elif other_lens.shape != (N,):
raise RuntimeError(
f"Expected other_lens to have shape ({N},); got {other_lens.shape}"
)
mask = in_lens.view(N, 1) > torch.arange(T, device=device)
mask = mask & (input[..., 1:] >= 0).all(2)
if window_type in ("symmetric", "causal"):
starts = starts - lobe_size
if window_type in ("symmetric", "future"):
ends = ends + lobe_size
if valid_only:
mask = mask & (starts >= 0) & (ends <= other_lens.view(N, 1))
else:
mask = mask & (ends > 0) & (starts < other_lens.view(N, 1))
mask = mask & (starts < ends)
starts, ends, mask = starts.flatten(), ends.flatten(), mask.flatten()
sources = torch.arange(N, device=device).view(N, 1).expand(N, T).flatten()
starts = starts[mask]
ends = ends[mask]
sources = sources[mask]
slices = torch.stack([starts, ends], 1)
else:
raise RuntimeError(
f"Expected policy to be one of 'fixed', 'ali', or 'ref'; got '{policy}'"
)
return slices, sources
[docs]
class SliceSpectData(torch.nn.Module):
"""Determine slices of feature chunks according to a variety of policies
This module helps to chunk :class:`pydrobert.data.SpectDataLoader` data (or other
similarly-structured tensors) into smaller units by returning slices of that data.
The input to this module and the means of determining those slices varies according
to the `policy` specified (see the notes below for more details). The return values
can then be used to slice the data.
Parameters
----------
policy
Specifies how to slice the data. If :obj:`'fixed'`, extract windows of fixed
length at fixed intervals. If :obj:`'ali'`, use changes in frame-level
alignments to determine segment boundaries and slice along those. If
:obj:`'ref'`, use token segmentations as slices. See below for more info.
window_type
How the window will be constructed around the "middle unit" in the policy. In
general :obj:`'symmetric'` adds lobes to either side of the middle unit,
:obj:`'causal'` to the left (towards :obj:`0`), :obj:`'future'` to the right.
valid_only
What to do when a would-be slice passes over the length of the data. If
:obj:`True`, any such slices are thrown out. If :obj:`False`, do something
dictated by the policy which may preserve the invalid boundaries.
lobe_size
Specifies the size of a lobe in the slice's window. When the `policy` is
:obj:`'fixed'` or :obj:`'ref'`, the unit of `lobe_size` is a single frame. When
`policy` is :obj:`'ali'`, the unit of `lobe_size` is a whole segment.
Call Parameters
---------------
input : torch.Tensor
A tensor of shape ``(N, T, *)``, where ``N`` is the batch dimension and ``T`` is
the (maximum) sequence dimension. When `policy` is :obj:`'fixed'`, `input`
should be the batch-first feature tensor `feats` from a
:class:`pydrobert.data.SpectDataLoader`. When :obj:`'ali'`, `input` should be
the batch-first `alis` tensor. When :obj:`'ref'`, `input` should be the
batch-first `refs` tensor with segment info.
in_lens : torch.Tensor, optional
A long tensor of shape ``(N,)`` specifying the lengths of sequences in `input`.
For the ``n``-th batch element, only the elements ``input[n, :inputlens[n]]``
are considered. If unspecified, all sequences are assumed to be of length ``T``.
For the :obj:`'fixed'` and :obj:`'ali'` policies, this is the `feat_lens` tensor
from a :class:`pydrobert.data.SpectDataLoader`. When :obj:`'ref'`, it is the
`ref_lens` tensor.
other_lens : torch.Tensor, optional
An additional long tensor of shape ``(N,)`` specifying some other lengths,
depending on the policy. It is currently only used in the :obj:`'ref'` policy
and takes the value `feat_lens` from a :class:`pydrobert.data.SpectDataLoader`.
Returns
-------
slices : torch.Tensor
A long tensor of shape ``(M, 2)`` storing the slices of all batch elements.
``M`` is the total number of slices. ``slices[m, 0]`` is the ``m``-th slice's
start index (inclusive), while ``slices[m, 1]`` is the ``m``-th slice's end
index (exclusive).
sources : torch.Tensor
A long tensor of shape ``(M,)`` where ``sources[m]`` is the batch index of the
``m``-th slice.
See Also
--------
ChunkBySlices
Can be used to chunk input using the returned `slices` (after reordering that
input with `sources`)
Notes
-----
If `policy` is :obj:`'fixed'`, slices are extracted at fixed intervals (``lobe_size
+ 1``) along the length of the data. `input` is assumed to be the data in question,
e.g. the `feats` tensor in a :class:`pydrobert.data.SpectDataLoader`, in batch-first
order (although any tensor which matches its first two dimensions will do).
`in_lens` may be used to specify the actual lengths of the input sequences if they
were padded to fit in the same batch element. If `window_type` is
:obj:`'symmetric'`, windows are of size ``1 + 2 * lobe_size``; otherwise, windows
are of size ``1 + lobe_size``. When `valid_only` is :obj:`True`, slices start at
index :obj:`0` and as many slices as can be fit fully within the sequences are
returned. When `valid_only` is :obj:`False` slices are kept if their "middle" index
lies before the end of the sequence with lobes clamped within the sequence. The
"middle" index for the symmetric window is at ``slice[0] + window_size // 2``; for
the causal window it's the last index of the window, ``slice[1] - 1``; for the
future window it's the first, ``slice[0]``. When `valid_only` is :obj:`False`, the
initial slice's offsets differ as well: for the symmetric case, it's ``(lobe_size +
1) // 2 - window_size // 2``; for the causal case, it's :obj:`-lobe_size`; and the
future case it's still :obj:`0`. As an example, given a sequence of length :obj:`8`,
the following are the slices under different configurations of the :obj:`'fixed'`
policy with a `lobe_size` of :obj:`2`::
[[0, 5], [3, 8]] # symmetric, valid_only
[[0, 3], [3, 6]] # not symmetric, valid_only
[[-1, 4], [2, 6], [5, 9]] # symmetric, not valid_only
[[-2, 1], [1, 4], [4, 7]] # causal, not valid_only
[[0, 3], [3, 6], [6, 9]] # future, not valid_only
If `policy` is :obj:`'ali'`, slices are extracted from the partition of the sequence
induced by per-frame alignments. `input` is assumed to be the alignments in
question, i.e. the batch-first `alis` tensor in a
:class:`pydrobert.data.SpectDataLoader`. `in_lens` may be used to specify the actual
lengths of the input sequences if they were padded to fit in the same batch element.
The segments are induced by `ali` as follows: a segment starts at index `t` whenever
``t == 0`` or ``alis[n, t - 1] != alis[n, t]``. Slice ``m`` is built from segment
``m`` by starting with the segment boundaries and possibly extending the start to
the left (towards :obj:`0`) or the end to the right (away from :obj:`0`). If
`window_type` is :obj:`'symmetric'` or :obj:`'causal'`, the ``m``-th segment's start
is set to the start of the ``(m - lobe_size)``-th. If `window_type` is
:obj:`'symmetric'` or :obj:`'future'`, the segment's end is set to the end of the
``(m + lobe_size)``-th. Since there are a finite number of segments, sometimes
either ``(m - lobe_size)`` or ``(m + lobe_size)`` will not exist. In that case and
if `only_valid` is :obj:`True`, the slice is thrown out. If `only_valid` is
:obj:`False`, the furthest segment from ``m`` in the same direction which also
exists will be used. For example, with ``input[n] = [1] * 4 + [2] * 3 + [1] + [5] *
2``, the following are the slices under different configurations of the :obj:`'ali'`
policy with a `lobe_size` of :obj:`1`::
[[0, 8], [4, 10]] # symmetric, valid_only
[[0, 7], [4, 8], [7, 10]] # not symmetric, valid_only
[[0, 7], [0, 8], [4, 10], [7, 10]] # symmetric, not valid_only
[[0, 4], [0, 7], [4, 8], [7, 10]] # causal, not valid_only
[[0, 7], [4, 8], [7, 10], [8, 10]] # future, not valid_only
Finally, if `policy` is :obj:`'ref'`, slices are extracted from a transcription's
segment boundaries. `input` is assumed to be the token sequences in question, i.e.
the batch-first `refs` tensor in a :class:`pydrobert.data.SpectDataLoader`. `input`
should be 3-dimensional with the third dimension of size 3: ``input[..., 0]`` the
token sequence (ignored), ``input[..., 1]`` the segment starts (in frames), and
``input[..., 2]`` their ends. `in_lens` may be specified to give the length of the
token sequences (i.e. `ref_lens`). In addition, the lengths of the sequences `input`
is segmenting (in frames) may be passed via `other_lens` (i.e. `feat_lens`). The
slices are built off the available segments. If `window_type` is :obj:`'causal'`,
`lobe_size` is subtracted from all segments if :obj:`'future'`, `lobe_size` is added
to all ends; if :obj:`'symmetric'`, both are applied. A segment may be discarded a
few ways: if either the start or end frame is less than 0 (indicating missing
segment information); if `in_lens` is set and the token segment is indexed past that
length (``input[n, t]`` for any ``t >= in_lens[n]``); the starting frame of a
segment (after padding) matches or exceeds the ending frame after padding (no empty
or invalid slices); if :obj:`valid_only` is :obj:`True` and the padded start begins
before index :obj:`0` or the padded end ends after `other_lens`; and if
:obj:`valid_only` is :obj:`False` and the padded start begins after `other_lens` or
ends at or before :obj:`0`. For example, with ``input[n] = [[1, 0, 0], [2, 2, 3],
[3, -1, 1], [4, 0, -1], [5, 3, 5], [6, 4, 4]``, `in_lens[n] = 5``, ``other_lens[n] =
6``, and `lobe_size` of :obj:`2`, the following are the slices under different
configurations of the :obj:`'ref'` policy::
[[0, 5]] # symmetric, valid_only
[[0, 3], [1, 5]] # causal, valid_only
[[0, 2], [2, 5]] # future, valid_only
[[-2, 2], [0, 5], [1, 7]] # symmetric, not valid_only
[[0, 3], [1, 5]] # causal, not valid_only
[[0, 2], [2, 5], [3, 7]] # future, not valid_only
"""
__constants__ = "policy", "window_type", "valid_only", "lobe_size"
policy: str
window_type: str
valid_only: bool
lobe_size: int
def __init__(
self,
policy: Policy = "fixed",
window_type: WindowType = "symmetric",
valid_only: bool = True,
lobe_size: int = 0,
) -> None:
policy = argcheck.is_in(policy, get_args(Policy), "policy")
window_type = argcheck.is_in(window_type, get_args(WindowType), "window_type")
valid_only = argcheck.is_bool(valid_only, "valid_only")
lobe_size = argcheck.is_nonnegi(lobe_size, "lobe_size")
super().__init__()
self.policy, self.window_type, self.lobe_size = policy, window_type, lobe_size
self.valid_only = valid_only
def extra_repr(self) -> str:
return (
f"policy={self.policy}, window_type={self.window_type}, "
f"lobe_size={self.lobe_size}, valid_only={self.valid_only}"
)
def forward(
self,
input: torch.Tensor,
in_lens: Optional[torch.Tensor] = None,
other_lens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
return slice_spect_data(
input,
in_lens,
other_lens,
self.policy,
self.window_type,
self.valid_only,
self.lobe_size,
)
__call__ = proxy(forward)
[docs]
@script
@functional_wrapper("ChunkTokenSequencesBySlices")
def chunk_token_sequences_by_slices(
refs: torch.Tensor,
slices: torch.Tensor,
ref_lens: Optional[torch.Tensor] = None,
partial: bool = False,
retain: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if refs.ndim == 2:
return refs.new_empty((0, refs.size(1),)), slices.new_empty((0,))
elif refs.ndim != 3 or refs.size(2) != 3:
raise RuntimeError(
"Expected refs to be 2-dimensional or 3-dimensional with final "
f"dimension size 3. Got shape '{refs.shape}'"
)
N, R = refs.size(0), refs.size(1)
if slices.shape != (N, 2):
raise RuntimeError(
f"Expected slices to be a tensor of shape ({N}, 2), got {slices.shape}"
)
arange = torch.arange(R, device=refs.device)
if ref_lens is None:
mask = torch.ones((N, R), device=refs.device, dtype=torch.bool)
elif ref_lens.shape != (N,):
raise RuntimeError(
f"Expected ref_lens to be a tensor of shape ({N},), got {ref_lens.shape}"
)
else:
mask = ref_lens.unsqueeze(1) > arange
mask = mask & (refs[..., 1:] >= 0).all(2) & (refs[..., 2] >= refs[..., 1])
if partial:
# slice_start < ref_end and slice_end > ref_start
mask = (
mask & (slices[..., :1] < refs[..., 2]) & (slices[..., 1:] > refs[..., 1])
)
else:
# slice_start <= ref_start and slice_end >= ref_end
mask = (
mask & (slices[..., :1] <= refs[..., 1]) & (slices[..., 1:] >= refs[..., 2])
)
chunked_lens = mask.long().sum(1)
refs = refs[mask.unsqueeze(2).expand_as(refs)]
mask = (chunked_lens.unsqueeze(1) > arange).unsqueeze(2).expand(N, R, 3)
chunked = refs.new_empty((N, R, 3)).masked_scatter_(mask, refs)
if not retain:
chunked[..., 1:] += slices[..., 0].view(N, 1, 1).expand(N, R, 2)
return chunked, chunked_lens
[docs]
class ChunkTokenSequencesBySlices(torch.nn.Module):
"""Chunk token sequences with segments in slices
Parameters
----------
partial
If :obj:`True`, a segment of `refs` whose interval partially overlaps with the
slice will be included in `chunked`. Otherwise, segments in `ref` must fully
overlap with slices (i.e. be contained within).
retain
If :obj:`True`, tokens kept from `refs` will retain their original boundary
values. Otherwise, boundaries will become relative to the start frame of
`slices`.
Call Parameters
---------------
refs : torch.Tensor
A long tensor of shape ``(N, R, 3)`` containing triples ``tok, start, end``,
where `tok` is the token id, `start` is the start frame (inclusive) of the
segment, and `end` is its end frame (exclusive). A negative `start` or `end` is
treated as a missing boundary and will automatically exclude the triple from the
chunk. `ref` may also be a 2-dimensional long tensor ``(N, R)`` of tokens,
excluding segment boundaried. However, the return values will always be empty.
slices : torch.Tensor
A long tensor of shape ``(N, 2)`` containing pairs ``start, end``, where `start`
and `end` are the start (inclusive) and end (exclusive) indices, respectively.
ref_lens : torch.Tensor, optional
An optional long tensor of shape ``(N,)`` specifying the token sequence lengths.
Only the values in the range ``refs[n, :ref_lens[n]]`` are considered part of
the sequence of batch element ``n``. If unspecified, all token sequences of
`refs` are assumed to be of length ``R``.
Returns
-------
chunked : torch.Tensor
A long tensor of shape ``(N, R', 3)`` of the chunked token sequences.
chunked_lens : torch.Tensor
A long tensor of shape ``(N,)`` with the same interpretation as `ref_lens`, but
for `chunked` instead.
Warnings
--------
Negative indices in slices in Python are usually interpreted as an offset left from
the end of the sequence. In `slices`, however, negative indices indicate an offset
left from the start of the sequence. In `refs`, negative indices indicate a missing
boundary and are thrown out. Negative indices in `slices` can impact the returned
segment boundaries in `chunked`.
See Also
--------
SliceSpectData
Can be used to determine appropriate `slices`. In this case, ``refs =
refs[sources]`` and ``ref_lens = ref_lens[sources]`` should be passed to this
module (using the return value `sources` from :class:`SliceSpectData`).
ChunkBySlices
A similar purpose, but for input with an explicit dimension for slicing, such as
`feats` or `alis` from :class:`SpectDataSet`.
"""
__constants__ = "partial", "retain"
partial: bool
retain: bool
def __init__(self, partial: bool = False, retain: bool = False) -> None:
partial = argcheck.is_bool(partial, "partial")
retain = argcheck.is_bool(retain, "retain")
super().__init__()
self.partial, self.retain = partial, retain
def extra_repr(self) -> str:
if self.partial and self.retain:
return "partial, retain"
elif self.partial:
return "partial"
elif self.retain:
return "retain"
else:
return ""
def forward(
self,
ref: torch.Tensor,
slices: torch.Tensor,
ref_lens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
return chunk_token_sequences_by_slices(
ref, slices, ref_lens, self.partial, self.retain
)
# __call__ = proxy(forward)