# Copyright 2022 Sean Robertson
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, overload
from typing_extensions import Literal, get_args
import torch
from . import config, argcheck
from ._compat import script
from ._wrappers import functional_wrapper, proxy
# Names of the padding strategies accepted by the padding and chunking routines
# defined in this file; used to annotate (and, in the modules, validate) `mode`.
PadMode = Literal["constant", "reflect", "replicate"]
# Typing-only signature for pad_variable: it narrows `mode` to the literals in
# PadMode for static checkers. The implementation defined further down in this
# file annotates `mode` as a plain `str`.
@overload
def pad_variable(
    x: torch.Tensor,
    lens: torch.Tensor,
    pad: torch.Tensor,
    mode: PadMode = "constant",
    value: float = config.DEFT_PAD_VALUE,
) -> torch.Tensor:
    ...
@script
def _get_padding_buffers(
    x: torch.Tensor,
    lens: torch.Tensor,
    left_pad: torch.Tensor,
    right_pad: torch.Tensor,
    mode: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collect the values used to fill the left and right padding regions

    Parameters
    ----------
    x
        A tensor of shape ``(N, T, F)``.
    lens
        A tensor of shape ``(N,)`` of sequence lengths (presumably long and no
        greater than ``T``; not checked here).
    left_pad
        A tensor of shape ``(N,)`` with the number of frames to pad on the left of
        each sequence.
    right_pad
        A tensor of shape ``(N,)`` with the number of frames to pad on the right of
        each sequence.
    mode
        One of ``'constant'``, ``'reflect'``, or ``'replicate'``.

    Returns
    -------
    left_buf, right_buf : torch.Tensor
        For ``'reflect'`` and ``'replicate'``, flat tensors holding the padding
        values of every batch element in ``(n, t, f)`` row-major order, i.e. the
        order in which ``masked_scatter`` consumes them. For ``'constant'`` both
        are simply `x`: they are placeholders which the callers in this file never
        scatter, since they pre-fill the output with the constant instead.

    Raises
    ------
    NotImplementedError
        If ``mode == 'reflect'`` and any padding length equals or exceeds the
        matching sequence length.
    RuntimeError
        If ``mode == 'replicate'`` and any sequence length is less than 1.
    ValueError
        If `mode` is not one of the three supported values.
    """
    assert x.ndim == 3
    N, T, F = x.shape
    if mode == "constant":
        # don't actually do anything. It'll be faster if the caller just
        # initializes the output with the fill value
        left_buf = x
        right_buf = x
    elif mode == "reflect" or mode == "replicate":
        if mode == "reflect":
            if (left_pad >= lens).any() or (right_pad >= lens).any():
                raise NotImplementedError(
                    "For reflect padding, all padding lengths must be less than the "
                    "sequence length"
                )
        elif (lens < 1).any():
            raise RuntimeError("For replicate padding, all lens must be > 0")
        left_max = int(left_pad.max().item())
        right_max = int(right_pad.max().item())
        # Replicate padding may be wider than the input itself, so the index
        # range has to cover the widest pad and not just T. Sizing it to T only
        # made the expands/selects below fail whenever a pad exceeded T.
        arange = torch.arange(max(T, max(left_max, right_max)), device=x.device)
        # True at (n, t, :) when frame t is one of the padded frames of sequence n
        left_mask = (
            (left_pad.unsqueeze(1) > arange[:left_max])
            .unsqueeze(2)
            .expand(N, left_max, F)
        )
        right_mask = (
            (right_pad.unsqueeze(1) > arange[:right_max])
            .unsqueeze(2)
            .expand(N, right_max, F)
        )
        if mode == "reflect":
            # the j-th left-padded frame mirrors x[n, left_pad[n] - j]; entries
            # past a sequence's own pad are clamped to a valid index and then
            # dropped by the mask
            left_idxs = (
                (left_pad.unsqueeze(1) - arange[:left_max])
                .clamp_(min=0)
                .unsqueeze(2)
                .expand(N, left_max, F)
            )
            # the j-th right-padded frame mirrors x[n, lens[n] - 2 - j]
            right_idxs = (
                (lens.unsqueeze(1) - arange[:right_max] - 2)
                .clamp_(min=0)
                .unsqueeze(2)
                .expand(N, right_max, F)
            )
            left_vals = x.gather(1, left_idxs)
            right_vals = x.gather(1, right_idxs)
        else:
            # every left-padded frame repeats x[n, 0]; every right-padded frame
            # repeats the last valid frame x[n, lens[n] - 1]
            left_vals = x[:, :1].expand(N, left_max, F)
            last_idxs = (lens - 1).view(N, 1, 1).expand(N, right_max, F)
            right_vals = x.gather(1, last_idxs)
        left_buf = left_vals.masked_select(left_mask)
        right_buf = right_vals.masked_select(right_mask)
    else:
        raise ValueError(
            f"mode must be one of 'constant', 'reflect', 'replicate', got '{mode}'"
        )
    return left_buf, right_buf
@script
@functional_wrapper("PadVariable")
def pad_variable(
    x: torch.Tensor,
    lens: torch.Tensor,
    pad: torch.Tensor,
    mode: str = "constant",
    value: float = config.DEFT_PAD_VALUE,
) -> torch.Tensor:
    # Pads each sequence x[n, :lens[n]] with pad[0, n] frames on its left and
    # pad[1, n] frames on its right, returning a tensor of shape (N, T', *) where
    # T' is the longest padded length. Frames beyond a sequence's padded length
    # hold `value`.
    #
    # Raises ValueError if x has fewer than two dimensions, if lens is not of
    # shape (N,), or if pad is not of shape (2, N). Errors from
    # _get_padding_buffers (bad mode, unsupported pad/length combination)
    # propagate unchanged.
    if x.ndim < 2:
        raise ValueError("Expected x to be at least two dimensional")
    shape = x.shape
    N, T = shape[:2]
    if lens.shape != (N,):
        # the separating space after "got" was missing, gluing the shape onto
        # the preceding word
        raise ValueError(
            f"For x of shape {shape}, lens should have shape ({N},) but got "
            f"{lens.shape}"
        )
    if pad.shape != (2, N):
        raise ValueError(
            f"For x of shape {shape}, pad should have shape (2, {N}), but got "
            f"{pad.shape}"
        )
    # collapse any trailing dimensions into a single feature axis: (N, T, F)
    x = x.unsqueeze(-1).flatten(2)
    F = x.size(2)
    left_buf, right_buf = _get_padding_buffers(x, lens, pad[0], pad[1], mode)
    new_lens = lens + pad.sum(0)
    Tp = int(new_lens.max().item())
    arange = torch.arange(max(Tp, T), device=x.device)
    # Cumulative masks over the output frames: left_mask covers the left padding,
    # mid_mask the left padding plus the sequence, right_mask everything up to
    # the padded length. Differences between them isolate each region.
    left_mask = (pad[0].unsqueeze(1) > arange[:Tp]).unsqueeze(2).expand(N, Tp, F)
    mid_mask = (
        ((pad[0] + lens).unsqueeze(1) > arange[:Tp]).unsqueeze(2).expand(N, Tp, F)
    )
    right_mask = (new_lens.unsqueeze(1) > arange[:Tp]).unsqueeze(2).expand(N, Tp, F)
    # selects the valid frames x[n, :lens[n]] of the input
    len_mask = (lens.unsqueeze(1) > arange[:T]).unsqueeze(2).expand(N, T, F)
    # starting from a tensor full of `value` means constant mode needs no
    # further work for the padded regions
    padded = x.new_full((N, Tp, F), value)
    x = x.masked_select(len_mask)
    padded = padded.masked_scatter(mid_mask & ~left_mask, x)
    if mode != "constant":
        padded = padded.masked_scatter(left_mask, left_buf)
        padded = padded.masked_scatter(right_mask & ~mid_mask, right_buf)
    # restore the trailing dimensions that were flattened above
    return padded.view((N, Tp) + shape[2:])
class PadVariable(torch.nn.Module):
    """Pad variable-length input by a variable amount on each side

    This module attempts to replicate the behaviour of
    :func:`torch.nn.functional.pad` on a tensor containing variable sequence
    lengths with variable amounts of padding.

    Parameters
    ----------
    mode
        How to pad the sequences. :obj:`'constant'`: fill the padding region with
        the value specified by `value`. :obj:`'reflect'`: padded values are
        reflections around the endpoints. For example, the first right-padded
        value of the ``n``-th sequence would be ``x[n, lens[n] - 2]``, the second
        ``x[n, lens[n] - 3]``, and so on. :obj:`'replicate'`: padding duplicates
        the endpoints of each sequence. For example, the left-padded values of the
        ``n``-th sequence would all be ``x[n, 0]``; the right-padded values would
        be ``x[n, lens[n] - 1]``.
    value
        The value to pad with when ``mode == 'constant'``.

    Call Parameters
    ---------------
    x : torch.Tensor
        A tensor of shape ``(N, T, *)`` where ``N`` is the batch index and ``T``
        is the sequence index.
    lens : torch.Tensor
        A long tensor of shape ``(N,)`` specifying the sequence lengths. Only the
        values in the range ``x[n, :lens[n]]`` are considered part of the sequence
        of batch element ``n``.
    pad : torch.Tensor
        A long tensor of shape ``(2, N)`` specifying how many elements to pad at
        the start (``pad[0]``) and end (``pad[1]``) of each sequence.

    Returns
    -------
    padded : torch.Tensor
        A tensor of shape ``(N, T', *)`` such that, for a given batch index
        ``n``::

            padded[n, :pad[0, n]] = left padding
            padded[n, pad[0,n]:pad[0,n] + lens[n]] = x[n, :lens[n]]
            padded[n, pad[0,n] + lens[n]:pad[0,n] + lens[n] + pad[1, n]] = right padding

    Raises
    ------
    NotImplementedError
        If any value in ``pad[:, n]`` equals or exceeds ``lens[n]`` when
        ``mode == 'reflect'``
    RuntimeError
        If any element in `lens` is less than 1 when ``mode == 'replicate'``
    ValueError
        If `x` has fewer than two dimensions or if `lens` or `pad` do not have the
        shapes listed above

    Examples
    --------

    >>> x = torch.arange(10).view(2, 5)
    >>> x
    tensor([[0, 1, 2, 3, 4],
            [5, 6, 7, 8, 9]])
    >>> lens = torch.tensor([3, 4])
    >>> pad = torch.arange(4).view(2, 2)
    >>> pad.t()  # [[0_left, 0_right], [1_left, 1_right]]
    tensor([[0, 2],
            [1, 3]])
    >>> y = pad_variable(x, lens, pad)  # constant w/ value 0
    >>> y[0, :3 + 0 + 2]
    tensor([0, 1, 2, 0, 0])
    >>> y[1, :4 + 1 + 3]
    tensor([0, 5, 6, 7, 8, 0, 0, 0])
    >>> y = pad_variable(x, lens, pad, 'reflect')
    >>> y[0, :3 + 0 + 2]
    tensor([0, 1, 2, 1, 0])
    >>> y[1, :4 + 1 + 3]
    tensor([6, 5, 6, 7, 8, 7, 6, 5])
    >>> y = pad_variable(x, lens, pad, 'replicate')
    >>> y[0, :3 + 0 + 2]
    tensor([0, 1, 2, 2, 2])
    >>> y[1, :4 + 1 + 3]
    tensor([5, 5, 6, 7, 8, 8, 8, 8])
    """

    __constants__ = ("mode", "value")

    mode: str
    value: float

    def __init__(
        self, mode: PadMode = "constant", value: float = config.DEFT_PAD_VALUE,
    ):
        # arguments are validated before the base class is initialised
        checked_mode = argcheck.is_in(mode, get_args(PadMode), "mode")
        checked_value = argcheck.is_float(value, "value")
        super().__init__()
        self.mode = checked_mode
        self.value = checked_value

    def extra_repr(self) -> str:
        # the fill value is only meaningful (and so only shown) for constant mode
        if self.mode != "constant":
            return f"mode={self.mode}"
        return f"mode={self.mode}, value={self.value}"

    def forward(
        self, x: torch.Tensor, lens: torch.Tensor, pad: torch.Tensor
    ) -> torch.Tensor:
        return pad_variable(x, lens, pad, self.mode, self.value)

    __call__ = proxy(forward)
@script
@functional_wrapper("PadMaskedSequence")
def pad_masked_sequence(
    x: torch.Tensor,
    mask: torch.Tensor,
    batch_first: bool = False,
    padding_value: float = config.DEFT_PAD_VALUE,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Gathers the elements of x flagged by mask and packs them, per batch
    # element, at the front of the sequence dimension; the remainder is filled
    # with padding_value. Returns the packed tensor and the per-batch counts.
    if x.ndim < 2:
        raise RuntimeError(f"expected x to be at least two-dimensional, got {x.ndim}")
    if mask.ndim != 2:
        raise RuntimeError(f"expected mask to be two-dimensional, got {mask.ndim}")
    if not batch_first:
        # work batch-first internally; undone just before returning
        x = x.transpose(0, 1)
        mask = mask.transpose(0, 1)
    lens = mask.sum(1)
    # singleton dimensions appended so the 2-D masks can expand over x's
    # trailing dimensions
    trailing = (1,) * (x.ndim - 2)
    steps = torch.arange(x.size(1), device=lens.device)
    # destination: the first lens[n] positions of each batch element
    dest = lens.unsqueeze(1) > steps
    dest = dest.view(dest.shape + trailing).expand_as(x)
    # source: the positions flagged by the caller
    src = mask.view(mask.shape + trailing).expand_as(x)
    kept = x.masked_select(src)
    x_ = torch.full_like(x, padding_value).masked_scatter(dest, kept)
    if not batch_first:
        x_ = x_.transpose(0, 1)
    return x_, lens
class PadMaskedSequence(torch.nn.Module):
    """Select masked elements of tensor, then scatter into right-padded sequences

    Parameters
    ----------
    batch_first
        Whether the first (or second) dimension of `x` is the batch dimension. The
        sequence dimension will be the second (or first).
    padding_value
        The value to right-pad the remaining elements with along the sequence
        dimension.

    Call Parameters
    ---------------
    x : torch.Tensor
        The input tensor. At least two dimensional.
    mask : torch.Tensor
        A boolean tensor whose :obj:`True` values indicate that the associated
        element(s) of `x` should be included in the sequence. Broadcasts with the
        first two dimensions of `x`.

    Returns
    -------
    x_ : torch.Tensor
        A tensor of the same shape as `x` such that, supposing ``i`` indexes the
        ``j``-th :obj:`True` element of `mask` for batch index :obj:`n`::

            x_[j, n] = x[i, n]

        with the remaining values of `x_` being `padding_value`.
    lens : torch.Tensor
        A vector of the length of the batch dimension which counts the number of
        elements of `x` stored in `x_` per batch element.

    Examples
    --------

    >>> x = torch.arange(100).view(10, 10)
    >>> mask = (x % 3) == 0
    >>> pad_masked_sequence = PadMaskedSequence(True, -1)
    >>> x_, lens = pad_masked_sequence(x, mask)
    >>> x_
    tensor([[ 0,  3,  6,  9, -1, -1, -1, -1, -1, -1],
            [12, 15, 18, -1, -1, -1, -1, -1, -1, -1],
            [21, 24, 27, -1, -1, -1, -1, -1, -1, -1],
            [30, 33, 36, 39, -1, -1, -1, -1, -1, -1],
            [42, 45, 48, -1, -1, -1, -1, -1, -1, -1],
            [51, 54, 57, -1, -1, -1, -1, -1, -1, -1],
            [60, 63, 66, 69, -1, -1, -1, -1, -1, -1],
            [72, 75, 78, -1, -1, -1, -1, -1, -1, -1],
            [81, 84, 87, -1, -1, -1, -1, -1, -1, -1],
            [90, 93, 96, 99, -1, -1, -1, -1, -1, -1]])
    >>> lens
    tensor([4, 3, 3, 4, 3, 3, 4, 3, 3, 4])
    >>> x = (x * 2).unsqueeze(2) + torch.arange(2)
    >>> x_, lens = pad_masked_sequence(x, mask)
    >>> x_[:1]
    tensor([[[ 0,  1],
             [ 6,  7],
             [12, 13],
             [18, 19],
             [-1, -1],
             [-1, -1],
             [-1, -1],
             [-1, -1],
             [-1, -1],
             [-1, -1]]])
    """

    __constants__ = ("batch_first", "padding_value")

    batch_first: bool
    padding_value: float

    def __init__(
        self, batch_first: bool = False, padding_value: float = config.DEFT_PAD_VALUE
    ):
        # arguments are validated before the base class is initialised
        checked_first = argcheck.is_bool(batch_first, "batch_first")
        checked_value = argcheck.is_float(padding_value, "padding_value")
        super().__init__()
        self.batch_first = checked_first
        self.padding_value = checked_value

    def extra_repr(self) -> str:
        fields = [
            f"batch_first={self.batch_first}",
            f"padding_value={self.padding_value}",
        ]
        return ", ".join(fields)

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return pad_masked_sequence(x, mask, self.batch_first, self.padding_value)

    __call__ = proxy(forward)
# Typing-only signature for chunk_by_slices: it narrows `mode` to the literals
# in PadMode for static checkers. The implementation that follows annotates
# `mode` as a plain `str`.
@overload
def chunk_by_slices(
    x: torch.Tensor,
    slices: torch.Tensor,
    lens: Optional[torch.Tensor] = None,
    mode: PadMode = "constant",
    value: float = config.DEFT_PAD_VALUE,
) -> Tuple[torch.Tensor, torch.Tensor]:
    ...
@script
@functional_wrapper("ChunkBySlices")
def chunk_by_slices(
    x: torch.Tensor,
    slices: torch.Tensor,
    lens: Optional[torch.Tensor] = None,
    mode: str = "constant",
    value: float = config.DEFT_PAD_VALUE,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Extracts one chunk x[n, start:end] per batch element, where
    # (start, end) = slices[n]. Wherever a slice reaches before index 0 or past
    # lens[n], the chunk is padded according to `mode`. Returns the chunks, of
    # shape (N, T', *), and their lengths, of shape (N,).
    #
    # Raises RuntimeError if x has fewer than two dimensions or if lens is given
    # but is not of shape (N,). Errors from _get_padding_buffers propagate.
    if x.ndim < 2:
        raise RuntimeError(f"Expected x to be at least 2-dimensional; got {x.ndim}")
    N, T = x.size(0), x.size(1)
    if not N * T:
        # nothing to chunk: no batch elements or no frames
        return x.new_empty(x.shape), slices.new_zeros((N,))
    rest = x.shape[2:]
    # collapse any trailing dimensions into a single feature axis: (N, T, F)
    x = x.unsqueeze(-1).flatten(2)
    device = x.device
    if lens is None:
        lens = torch.full((1,), T, dtype=torch.long, device=device).expand(N)
    elif lens.shape != (N,):
        # "{N,}" used to format the tuple (N,) inside the literal parentheses,
        # rendering e.g. "((3,))" instead of "(3,)"
        raise RuntimeError(f"Expected lens to be of shape ({N},); got {lens.shape}")
    F = x.size(2)
    start, end = slices[..., 0].contiguous(), slices[..., 1].contiguous()
    # requested chunk length; slices with end <= start yield empty chunks
    chunk_lens = (end - start).clamp_min_(0)
    empty = chunk_lens == 0
    # frames requested before index 0 / at or beyond lens (none for empty chunks)
    left_pad = (-start).clamp_min_(0).masked_fill_(empty, 0)
    right_pad = (end - lens).clamp_min_(0).masked_fill_(empty, 0)
    # slice bounds clipped to the valid range [0, lens]
    start_ = start.clamp_min(0)
    end_ = torch.min(end, lens)
    # number of frames actually taken from x
    slice_lens = (end_ - start_).clamp_min(0)
    left_buf, right_buf = _get_padding_buffers(x, lens, left_pad, right_pad, mode)
    # The output must be wide enough for the longest chunk and also for the
    # padding buffers, which are scattered in full and can be longer than the
    # chunk when a slice lies entirely inside the padding.
    Tp = int(
        torch.max(torch.max(left_pad.max(), chunk_lens.max()), right_pad.max()).item()
    )
    arange = torch.arange(max(T, Tp), device=device)
    # selects x[n, start_[n]:end_[n]]; comparing against the unclipped start is
    # equivalent since frame indices are never negative
    slice_mask = (
        ((start.unsqueeze(1) <= arange[:T]) & (end_.unsqueeze(1) > arange[:T]))
        .unsqueeze(-1)
        .expand(N, T, F)
    )
    x = x.masked_select(slice_mask)
    # cumulative masks over the output frames: left padding, then left padding
    # plus the slice
    left_mask = (left_pad.unsqueeze(1) > arange[:Tp]).unsqueeze(2).expand(N, Tp, F)
    mid_mask = (
        ((left_pad + slice_lens).unsqueeze(1) > arange[:Tp])
        .unsqueeze(2)
        .expand(N, Tp, F)
    )
    # starting from a tensor full of `value` means constant mode needs no
    # further work for the padded regions
    chunks = x.new_full((N, Tp, F), value)
    if mode != "constant":
        chunks = chunks.masked_scatter(left_mask, left_buf)
        right_mask = (
            ((left_pad + slice_lens + right_pad).unsqueeze(1) > arange[:Tp])
            .unsqueeze(2)
            .expand(N, Tp, F)
        )
        right_mask = right_mask & ~mid_mask
        chunks = chunks.masked_scatter(right_mask, right_buf)
        if mode == "reflect":
            # we have to do some extra work for a special case. When the start and
            # end indices are completely contained in the right padding, the slice
            # may start at an offset within the padding. If so, we want to move the
            # start of the slice within the padding to the start of the sequence
            offset = (start_ - lens).clamp_min_(0)
            keep = (offset > 0).view(N, 1, 1)
            right_pad -= offset
            # pick out the part of the right padding lying at or past the offset
            right_mask = (
                right_mask
                & (
                    ((left_pad + slice_lens + offset).unsqueeze(1) <= arange[:Tp])
                    .unsqueeze(2)
                    .expand(N, Tp, F)
                )
                & keep
            )
            right_buf = chunks[right_mask]
            # ... and write it back starting from frame 0 of the affected chunks
            right_mask = (
                (right_pad.unsqueeze(1) > arange[:Tp]).unsqueeze(2).expand(N, Tp, F)
            ) & keep
            chunks = chunks.masked_scatter(right_mask & ~mid_mask, right_buf)
    # finally place the frames taken from x between the two padding regions
    chunks = chunks.masked_scatter(mid_mask & ~left_mask, x)
    # restore the trailing dimensions that were flattened above
    return chunks.view((N, Tp) + rest), chunk_lens
class ChunkBySlices(torch.nn.Module):
    """Chunk input using slices, padding where necessary

    Parameters
    ----------
    mode
        How to pad slices that go beyond the sequence lengths. See
        :class:`PadVariable` for more information on the modes.
    value
        The value to pad with when ``mode == 'constant'``.

    Call Parameters
    ---------------
    x : torch.Tensor
        A tensor of shape ``(N, T, *)`` where ``N`` is the batch index and ``T``
        is the sequence index.
    slices : torch.Tensor
        A long tensor of shape ``(N, 2)`` containing pairs ``start, end``, where
        `start` and `end` are the start (inclusive) and end (exclusive) indices,
        respectively. Any slices exceeding segment boundaries will be padded
        according to the `mode` specified.
    lens : torch.Tensor, optional
        An optional long tensor of shape ``(N,)`` specifying the sequence lengths.
        Only the values in the range ``x[n, :lens[n]]`` are considered part of the
        sequence of batch element ``n``. If unspecified, all sequences of `x` are
        assumed to be of length ``T``.

    Returns
    -------
    chunked : torch.Tensor
        A tensor of shape ``(N, T', *)`` of chunks of `x`. Besides ``T'``,
        `chunked` matches the shape of `x`.
    chunked_lens : torch.Tensor
        A long tensor of shape ``(N,)`` with the same interpretation as `lens`,
        but for `chunked` instead.

    Warnings
    --------
    Negative indices in slices in Python are usually interpreted as an offset left
    from the end of the sequence. Here, however, negative indices indicate an
    offset left from the start of the sequence. Those values will be interpreted
    as padding and be added to the chunk.

    See Also
    --------
    PadVariable
        For more details on how padding works.
    SliceSpectData
        Can be used to determine `slices` for :class:`SpectDataSet` features. In
        this case, ``x = x[sources]`` and ``lens = lens[sources]`` should be
        passed to this module (using the return value `sources` from
        :class:`SliceSpectData`).
    ChunkTokenSequenceBySlices
        Serves a similar purpose, but specifically for token sequences from a
        :class:`SpectDataSet`.
    """

    __constants__ = ("mode", "value")

    mode: str
    value: float

    def __init__(
        self, mode: PadMode = "constant", value: float = config.DEFT_PAD_VALUE,
    ) -> None:
        # arguments are validated before the base class is initialised
        checked_mode = argcheck.is_in(mode, get_args(PadMode), "mode")
        checked_value = argcheck.is_float(value, "value")
        super().__init__()
        self.mode = checked_mode
        self.value = checked_value

    def extra_repr(self) -> str:
        # the fill value is only meaningful (and so only shown) for constant mode
        if self.mode != "constant":
            return f"mode={self.mode}"
        return f"mode={self.mode}, value={self.value}"

    def forward(
        self,
        x: torch.Tensor,
        slices: torch.Tensor,
        lens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return chunk_by_slices(x, slices, lens, self.mode, self.value)

    __call__ = proxy(forward)