# Copyright 2022 Sean Robertson
#
# Code for polyharmonic_spline is converted from tensorflow code
# https://github.com/tensorflow/addons/blob/v0.11.2/tensorflow_addons/image/interpolate_spline.py
# code for sparse_image_warp is derived from tensorflow code, though it's not identical
# https://github.com/tensorflow/addons/blob/v0.11.2/tensorflow_addons/image/sparse_image_warp.py
#
# Which are also Apache 2.0 Licensed:
#
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional, Tuple, Union, overload
from typing_extensions import Literal, get_args
import torch
from . import argcheck
from ._pad import pad_variable
from ._compat import meshgrid, script, linalg_solve
from ._wrappers import functional_wrapper, proxy
DenseInterpolationMode = Literal["bilinear", "nearest"]
DensePaddingMode = Literal["border", "zero", "reflection"]
SparseIndexing = Literal["hw", "wh"]
RandomShiftMode = Literal["reflect", "constant", "replicate"]
@script
def _get_tensor_eps(
x: torch.Tensor,
eps16: float = torch.finfo(torch.float16).eps,
eps32: float = torch.finfo(torch.float32).eps,
eps64: float = torch.finfo(torch.float64).eps,
) -> float:
if x.dtype == torch.float16:
return eps16
elif x.dtype == torch.float32:
return eps32
elif x.dtype == torch.float64:
return eps64
else:
raise RuntimeError(f"Expected x to be floating-point, got {x.dtype}")
@script
def _phi(r: torch.Tensor, k: int) -> torch.Tensor:
if k % 2:
return r ** k
else:
return (r ** k) * (torch.clamp(r, min=_get_tensor_eps(r))).log()
@script
def _apply_interpolation(
w: torch.Tensor, v: torch.Tensor, c: torch.Tensor, x: torch.Tensor, k: int
) -> torch.Tensor:
r = torch.cdist(x, c) # (N, Q, T)
phi_r = _phi(r, k) # (N, Q, T)
phi_r_w = torch.bmm(phi_r, w) # (N, Q, O)
x1 = torch.cat([x, torch.ones_like(x[..., :1])], 2) # (N, Q, I+1)
x1_v = torch.bmm(x1, v) # (N, Q, O)
return phi_r_w + x1_v
@script
def _solve_interpolation(
c: torch.Tensor, f: torch.Tensor, k: int, reg: float, full: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
# based on
# https://mathematica.stackexchange.com/questions/65763/understanding-polyharmonic-splines
# Symbol map (theirs => ours)
# x,y => c (N, T, I)
# A => A (N, T, T)
# W => B (N, T, I+1)
# v => w (N, T, O)
# bb => v (N, I+1, O)
# wa => f (N, T, O)
r_cc = torch.cdist(c, c) # (N, T, T)
A = _phi(r_cc, k) # (N, T, T)
if reg > 0.0:
A = A + torch.eye(A.shape[1], dtype=A.dtype, device=A.device).unsqueeze(0) * reg
B = torch.cat([c, torch.ones_like(c[..., :1])], 2) # (N, T, I+1)
if full:
# full matrix method (TF)
ABt = torch.cat([A, B.transpose(1, 2)], 1) # (N, T+I+1, T)
zeros = torch.zeros(
(B.shape[0], B.shape[2], B.shape[2]), device=B.device, dtype=B.dtype
)
B0 = torch.cat([B, zeros], 1,) # (N, T+I+1, I+1)
ABtB0 = torch.cat([ABt, B0], 2) # (N, T+I+1, T+I+1)
zeros = torch.zeros(
(B.shape[0], B.shape[2], f.shape[2]), device=f.device, dtype=f.dtype
)
f0 = torch.cat([f, zeros], 1,) # (N, T+I+1, O)
wv = linalg_solve(ABtB0, f0)
w, v = wv[:, : B.shape[1]], wv[:, B.shape[1] :]
else:
# block decomposition
Ainv = torch.inverse(A) # (N, T, T)
Ainv_f = torch.bmm(Ainv, f) # (N, T, O)
Ainv_B = torch.bmm(Ainv, B) # (N, T, I+1)
Bt = B.transpose(1, 2) # (N, I+1, T)
Bt_Ainv_B = torch.bmm(Bt, Ainv_B) # (N, I+1, I+1)
Bt_Ainv_f = torch.bmm(Bt, Ainv_f) # (N, I+1, O)
v = linalg_solve(Bt_Ainv_B, Bt_Ainv_f)
Ainv_B_v = torch.bmm(Ainv_B, v) # (N, T, O)
w = Ainv_f - Ainv_B_v # (N, T, O)
# orthagonality constraints
# assert torch.allclose(w.sum(1), torch.tensor(0.0, device=w.device)), w.sum()
# assert torch.allclose(
# torch.bmm(w.transpose(1, 2), c), torch.tensor(0.0, device=w.device)
# ), torch.bmm(w.transpose(1, 2), c).sum()
return w, v
[docs]
@functional_wrapper("PolyharmonicSpline")
def polyharmonic_spline(
train_points: torch.Tensor,
train_values: torch.Tensor,
query_points: torch.Tensor,
order: int,
regularization_weight: float = 0.0,
full_matrix: bool = True,
) -> torch.Tensor:
train_points = train_points.float()
query_points = query_points.float()
w, v = _solve_interpolation(
train_points, train_values, order, regularization_weight, full_matrix
)
query_values = _apply_interpolation(w, v, train_points, query_points, order)
return query_values
[docs]
class PolyharmonicSpline(torch.nn.Module):
"""Guess values at query points using a learned polyharmonic spline
A spline estimates a function ``f : points -> values`` from a fixed number of
training points/knots and the values of ``f`` at those points. It does that by
solving a series of piecewise linear equations between knots such that the values at
the knots match the given values (and some additional constraints depending on the
spline).
This module is based on the `interpolate_spline
<https://www.tensorflow.org/addons/api_docs/python/tfa/image/interpolate_spline>`__
function from Tensorflow, which implements a `Polyharmonic Spline
<https://en.wikipedia.org/wiki/Polyharmonic_spline>`__. For technical details,
consult the TF documentation.
Parameters
----------
order
Order of the spline (> 0). 1 = linear. 2 = thin plate spline.
regularization_weight
Weight placed on the regularization term. See TF for more info.
full_matrix
Whether to solve linear equations via a full concatenated matrix or a block
decomposition. Setting to :obj:`True` better matches TF and appears to slightly
improve numerical accuracy at the cost of twice the run time and more memory
usage.
Call Parameters
---------------
train_points : torch.Tensor
A tensor of shape ``(N, T, I)`` representing the training points/knots for ``N``
different functions. ``N`` is the batch dimension, ``T`` is the number of
training points, and ``I`` is the size of the vector input to ``f``.
train_values : torch.Tensor
A tensor of shape ``(N, T, O)`` of ``f`` evaluated on `train_points`. ``O`` is
the size of the output vector of ``f``.
query_points : torch.Tensor
A tensor of shape ``(N, Q, I)`` representing the points you wish to have
estimates for.
Returns
-------
query_values : torch.Tensor
A tensor of shape ``(N, Q, O)`` consisting of the values estimated for
`query_points`.
Raises
------
RuntimeError
This module can return a :class`RuntimeError` when no unique spline can be
estimated. In general, the spline will require at least ``I+1`` non-degenerate
points (linearly independent). See the Wikipedia entry on splnes for more info.
"""
__constants__ = "order", "regularization_weight", "full_matrix"
order: int
regularization_weight: float
full_matrix: bool
def __init__(
self, order: int, regularization_weight: float = 0.0, full_matrix: bool = True
):
order = argcheck.is_posi(order, "order")
regularization_weight = argcheck.is_float(
regularization_weight, "regularization_weight"
)
full_matrix = argcheck.is_bool(full_matrix, "full_matrix")
super().__init__()
self.order = order
self.regularization_weight = regularization_weight
self.full_matrix = full_matrix
def forward(
self,
train_points: torch.Tensor,
train_values: torch.Tensor,
query_points: torch.Tensor,
) -> torch.Tensor:
return polyharmonic_spline(
train_points,
train_values,
query_points,
self.order,
self.regularization_weight,
self.full_matrix,
)
__call__ = proxy(forward)
@script
def _deterimine_pinned_points(k: int, sizes: torch.Tensor) -> torch.Tensor:
w_max = (sizes[:, :1] - 1).expand(-1, k + 1) # (N, k+1)
h_max = (sizes[:, 1:] - 1).expand(-1, k + 1) # (N, k+1)
range_ = torch.linspace(
0.0, 1.0, k + 1, dtype=sizes.dtype, device=sizes.device
) # (k+1,)
w_range = w_max * range_ # (N, k+1)
h_range = h_max * range_ # (N, k+1)
zeros = torch.zeros_like(w_range) # (N, k+1)
# (0, 0) -> (W - 1, 0) inclusive
bottom_edge = torch.stack([w_range, zeros], 2) # (N, k+1, 2)
# (0, 0) -> (0, H - 1) exclusive
left_edge = torch.stack([zeros[:, 1:-1], h_range[:, 1:-1]], 2) # (N, k-1, 2)
# (0, H - 1) -> (W - 1, H - 1) inclusive
top_edge = torch.stack([w_range, h_max], 2) # (N, k+1, 2)
# (W - 1, 0) -> (W - 1, H - 1) exclusive
right_edge = torch.stack([w_max[:, 1:-1], h_range[:, 1:-1]], 2) # (N, k-1, 2)
return torch.cat([bottom_edge, left_edge, top_edge, right_edge], 1) # (N, 4k, 2)
[docs]
@script
@functional_wrapper("Warp1DGrid")
def warp_1d_grid(
src: torch.Tensor,
flow: torch.Tensor,
lengths: torch.Tensor,
max_length: Optional[int] = None,
interpolation_order: int = 1,
) -> torch.Tensor:
device = src.device
N = src.shape[0]
if max_length is None:
T = int(math.ceil(lengths.max().item())) if lengths.numel() else 0
else:
T = max_length
src, flow, lengths = src.float(), flow.float(), lengths.float()
eps = _get_tensor_eps(src) # the epsilon avoids singular matrices
src = torch.min(src, lengths - 1).clamp_min(0)
dst = torch.min(src + flow, lengths - 1).clamp_min(0)
src = (2.0 * src + 1.0) / T - 1.0
dst = (2.0 * dst + 1.0) / T - 1.0
lowers = torch.full((N,), 1 / T - 1 - eps, dtype=torch.float, device=device)
uppers = (2 * lengths - 1) / T - 1.0 + eps
src = torch.stack([lowers, src, uppers], 1) # (N, 3)
dst = torch.stack([lowers, dst, uppers], 1) # (N, 3)
# sparse_grid = (2.0 * src + 1.0) / T - 1.0 # (N,3)
t = (2.0 * torch.arange(T, device=device) + 1.0) / T - 1.0
grid = polyharmonic_spline(
dst.unsqueeze(-1), # dst (N, 3, 1)
src.unsqueeze(-1), # (N, 3, 1)
t.unsqueeze(0).expand(N, T).unsqueeze(-1), # (N, T, 1)
interpolation_order,
).squeeze(
-1
) # (N, T)
return grid
[docs]
class Warp1DGrid(torch.nn.Module):
"""Interpolate grid values for a dimension of a grid_sample
This module determines a grid along a single dimension of a signal, image, volume,
whatever.
Parameters
----------
max_length
A maximum length to which the grid will be padded. If unspecified, it will be
taken to be ``lengths.max().ceil()``. If `grid` is being plugged in to
:func:`grid_sample`, ensure `max_length` matches the size of the dimension of
the image being warped.
interpolation_order
The degree of the spline used ot interpolate the grid.
Call Parameters
---------------
src : torch.Tensor
A tensor of shape ``(N,)`` containing the source points.
flow : torch.Tensor
A tensor of shape ``(N,)`` contianing the corresponding flow fields for `src`.
lengths : torch.Tensor
A tensor of shape ``(N,)`` specifying the number of valid indices along the
dimension in question.
Returns
-------
grid : torch.Tensor
A tensor of shape ``(N, max_length)`` which provides coodinates for one
dimension of the grid passed to :func:`torch.nn.functional.grid_sample`. See the
example below.
Notes
-----
The return value `grid` assumes `align_corners` has been set to :obj:`False` in
:func:`grid_sample`.
The values in `grid` depend on the value of `max_length`. `grid` does not contain
absolute pixel indices, instead mapping the range ``[0, max_length - 1]`` to the
real values in ``[-1, 1]``. Therefore, unless `max_length` is set to some fixed
value, the ``n``-th batch element ``grid[n]`` can differ according to the remaining
values in `length`. However, the ``n``-th batched image passed to
:func:`grid_sample` should still be warped in a way (roughly) agnostic to
surrounding batched images.
"""
__constants__ = "max_length", "interpolation_order"
interpolation_order: int
max_length: Optional[int]
def __init__(self, max_length: Optional[int] = None, interpolation_order: int = 1):
if max_length is not None:
max_length = argcheck.is_nonnegi(max_length, "max_length")
interpolation_order = argcheck.is_posi(
interpolation_order, "interpolation_order"
)
super().__init__()
self.max_length = max_length
self.interpolation_order = interpolation_order
def extra_repr(self) -> str:
s = f"interpolation_order={self.interpolation_order}"
if self.max_length is not None:
s = f"max_length={self.max_length}, " + s
return s
def forward(
self, src: torch.Tensor, flow: torch.Tensor, lengths: torch.Tensor
) -> torch.Tensor:
return warp_1d_grid(
src, flow, lengths, self.max_length, self.interpolation_order
)
@overload
def dense_image_warp(
image: torch.Tensor,
flow: torch.Tensor,
indexing: str = "hw",
mode: DenseInterpolationMode = "bilinear",
padding_mode: DensePaddingMode = "border",
) -> torch.Tensor:
...
[docs]
@script
@functional_wrapper("DenseImageWarp")
def dense_image_warp(
image: torch.Tensor,
flow: torch.Tensor,
indexing: str = "hw",
mode: str = "bilinear",
padding_mode: str = "border",
) -> torch.Tensor:
# from tfa.image.dense_image_warp
# output[n, c, h, w] = image[n, c, h - flow[n, h, w, 0], w - flow[n, h, w, 1]]
# outside of image uses border
# from torch.nn.functional.grid_sample
# output[n, c, h, w] = image[n, c, h, f(grid[n, h, w, 1], H),
# f(grid[n, h, w, 0], W)]
# where
# f(x, X) = ((x + 1) * X - 1) / 2
# therefore
# output[n, c, h, w] = image[n, c, ((grid[n, h, w, 1] + 1) * H - 1) / 2,
# ((grid[n, h, w, 0] + 1) * W - 1) / 2]
#
# ((grid[n, h, w, 1] + 1) * H - 1) / 2 = h - flow[n, h, w, 0]
# grid[n, h, w, 1] = (2 * h - 2 * flow[n, h, w, 0] + 1) / H - 1
# likewise
# grid[n, h, w, 0] = (2 * w - 2 * flow[n, h, w, 1] + 1) / W - 1
flow = flow.float()
N, C, H, W = image.shape
h = torch.arange(H, dtype=image.dtype, device=image.device) # (H,)
w = torch.arange(W, dtype=image.dtype, device=image.device) # (W,)
h, w = meshgrid(h, w) # (H, W), (H, W)
if indexing == "hw":
# grid_sample uses wh sampling, so we flip both the flow and hw along final axis
hw = torch.stack((w, h), 2).unsqueeze(0) # (1, H, W, 2)
flow = flow.flip(-1)
elif indexing == "wh":
hw = torch.stack((w, h), 2).unsqueeze(0) # (1, H, W, 2)
else:
raise ValueError("Invalid indexing! must be one of 'wh' or 'hw'")
HW = torch.tensor([[[[W, H]]]], dtype=image.dtype, device=image.device) # (1,1,1,2)
grid = (2 * hw - 2 * flow + 1.0) / HW - 1.0
return torch.nn.functional.grid_sample(
image, grid, mode=mode, padding_mode=padding_mode, align_corners=False
)
[docs]
class DenseImageWarp(torch.nn.Module):
"""Warp an input image with per-pixel flow vectors
This reproduces the functionality of Tensorflow's `dense_image_warp
<https://www.tensorflow.org/addons/api_docs/python/tfa/image/dense_image_warp>`__,
except `image` is in ``NCHW`` order instead of ``NHWC`` order. It wraps
:func:`torch.nn.functional.grid_sample`.
Parameters
----------
indexing
If `indexing` is ``"hw"``, ``flow[..., 0] = h``, the height index, and
``flow[..., 1] = w`` is the width index. If ``"wh"``, ``flow[..., 0] = w``
and ``flow[..., 1] = h``. The default in TF is ``"hw"``, whereas torch's
`grid_sample` is ``"wh"``
mode
The method of interpolation. Either use bilinear interpolation or the nearest
pixel value. The TF default is ``"bilinear"``
padding_mode
Controls how points outside of the image boundaries are interpreted.
``"border"``: copy points at around the border of the image. ``"zero"``: use
zero-valued pixels. ``"reflection"``: reflect pixels into the image starting
from the boundaries.
Call Parameters
---------------
image : torch.Tensor
A tensor of shape ``(N, C, H, W)`` where ``N`` is the batch dimension, ``C`` is
the channel dimension, ``H`` is the height dimension, and ``W`` is the width
dimension.
flow : torch.Tensor
A ftensor of shape ``(N, H, W, 2)``.
Returns
-------
warped : torch.Tensor
A warped `image` of shape ``(N, C, H, W)`` such that::
warped[n, c, h, w] = image[n, c, h - flow[n, h, w, 0], w - flow[n, h, w, 1]]
If the reference indices ``h - ...`` and ``w - ...`` are not integers, the value
is interpolated from the neighboring pixel values.
Warning
-------
`flow` is not an optical flow. Please consult the TF documentation for more details.
"""
__constants__ = "indexing", "mode", "padding_mode"
indexing: str
mode: str
padding_mode: str
def __init__(
self,
indexing: SparseIndexing = "hw",
mode: DenseInterpolationMode = "bilinear",
padding_mode: DensePaddingMode = "border",
):
indexing = argcheck.is_in(indexing, get_args(SparseIndexing), "indexing")
mode = argcheck.is_in(mode, get_args(DenseInterpolationMode), "mode")
padding_mode = argcheck.is_in(
padding_mode, get_args(DensePaddingMode), "padding_mode"
)
super().__init__()
self.indexing, self.mode, self.padding_mode = indexing, mode, padding_mode
def forward(self, image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
return dense_image_warp(
image, flow, self.indexing, self.mode, self.padding_mode
)
__call__ = proxy(forward)
# N.B. We do this ugly thing so that a trace can be aware of the returned type
# (rather than just "Any")
@script
def _sparse_image_warp_flow(
image,
source_points,
dest_points,
indexing: str,
field_interpolation_order: int,
field_regularization_weight: float,
field_full_matrix: bool,
pinned_boundary_points: int,
dense_interpolation_mode: str,
dense_padding_mode: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
if indexing == "hw":
source_points = source_points.flip(-1)
dest_points = dest_points.flip(-1)
source_points = source_points.float()
dest_points = dest_points.float()
N, C, H, W = image.shape
WH = torch.tensor([[W, H]] * N, dtype=image.dtype, device=image.device)
M = source_points.shape[1]
if not M:
return (
image,
torch.zeros((N, H, W, 2), dtype=torch.float, device=image.device),
)
if pinned_boundary_points > 0:
pinned_points = _deterimine_pinned_points(pinned_boundary_points, WH)
source_points = torch.cat([source_points, pinned_points], 1) # (N,M',2)
dest_points = torch.cat([dest_points, pinned_points], 1) # (N,M+4k=M',2)
# now just pretend M' was M all along
H_range = torch.arange(H, dtype=image.dtype, device=image.device) # (H,)
W_range = torch.arange(W, dtype=image.dtype, device=image.device) # (W,)
h, w = meshgrid(H_range, W_range) # (H, W), (H, W)
query_points = torch.stack([w.flatten(), h.flatten()], 1) # (H * W, 2)
train_points = dest_points
train_values = dest_points - source_points
flow = polyharmonic_spline(
train_points,
train_values,
query_points.unsqueeze(0).expand(N, H * W, 2),
field_interpolation_order,
regularization_weight=field_regularization_weight,
full_matrix=field_full_matrix,
)
flow = flow.view(N, H, W, 2)
warped = dense_image_warp(
image,
flow,
indexing="wh",
mode=dense_interpolation_mode,
padding_mode=dense_padding_mode,
)
if indexing == "hw":
flow = flow.flip(-1)
return warped, flow
@script
def _sparse_image_warp_noflow(
image,
source_points,
dest_points,
indexing: str,
field_interpolation_order: int,
field_regularization_weight: float,
field_full_matrix: bool,
pinned_boundary_points: int,
dense_interpolation_mode: str,
dense_padding_mode: str,
) -> torch.Tensor:
# all our computations assume "wh" ordering, so we flip it here if necessary.
# Though unintuitive, we need this for our call to grid_sample
if indexing == "hw":
source_points = source_points.flip(-1)
dest_points = dest_points.flip(-1)
source_points = source_points.float()
dest_points = dest_points.float()
N, C, H, W = image.shape
WH = torch.tensor([[W, H]] * N, dtype=image.dtype, device=image.device)
M = source_points.shape[1]
if not M:
return image
if pinned_boundary_points > 0:
pinned_points = _deterimine_pinned_points(pinned_boundary_points, WH)
source_points = torch.cat([source_points, pinned_points], 1) # (N,M',2)
dest_points = torch.cat([dest_points, pinned_points], 1) # (N,M+4k=M',2)
# now just pretend M' was M all along
H_range = torch.arange(H, dtype=image.dtype, device=image.device) # (H,)
W_range = torch.arange(W, dtype=image.dtype, device=image.device) # (W,)
h, w = meshgrid(H_range, W_range) # (H, W), (H, W)
query_points = torch.stack([w.flatten(), h.flatten()], 1) # (H * W, 2)
# If we can return just the warped image, we can bypass our call to dense_image_warp
# by interpolating the 'grid' parameter of 'grid_sample' instead of the 'flow'
# parameter of 'dense_image_warp'
# coord = ((grid + 1) * size - 1) / 2
# grid = (2 coord + 1) / size - 1
train_points = dest_points # (N, M, 2)
train_values = (2.0 * source_points + 1.0) / WH.unsqueeze(1) - 1.0 # (N, M, 2)
grid = polyharmonic_spline(
train_points,
train_values,
query_points.unsqueeze(0).expand(N, H * W, 2),
field_interpolation_order,
regularization_weight=field_regularization_weight,
full_matrix=field_full_matrix,
)
grid = grid.view(N, H, W, 2)
warped = torch.nn.functional.grid_sample(
image,
grid,
mode=dense_interpolation_mode,
padding_mode=dense_padding_mode,
align_corners=False,
)
return warped
@overload
def sparse_image_warp(
image: torch.Tensor,
source_points: torch.Tensor,
dest_points: torch.Tensor,
indexing: SparseIndexing = "hw",
field_interpolation_order: int = 2,
field_regularization_weight: float = 0.0,
field_full_matrix: bool = True,
pinned_boundary_points: int = 0,
dense_interpolation_mode: DenseInterpolationMode = "bilinear",
dense_padding_mode: DensePaddingMode = "border",
include_flow: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
...
[docs]
@functional_wrapper("SparseImageWarp")
def sparse_image_warp(
image: torch.Tensor,
source_points: torch.Tensor,
dest_points: torch.Tensor,
indexing: str = "hw",
field_interpolation_order: int = 2,
field_regularization_weight: float = 0.0,
field_full_matrix: bool = True,
pinned_boundary_points: int = 0,
dense_interpolation_mode: str = "bilinear",
dense_padding_mode: str = "border",
include_flow: bool = True,
) -> Any:
if include_flow:
return _sparse_image_warp_flow(
image,
source_points,
dest_points,
indexing,
field_interpolation_order,
field_regularization_weight,
field_full_matrix,
pinned_boundary_points,
dense_interpolation_mode,
dense_padding_mode,
)
else:
return _sparse_image_warp_noflow(
image,
source_points,
dest_points,
indexing,
field_interpolation_order,
field_regularization_weight,
field_full_matrix,
pinned_boundary_points,
dense_interpolation_mode,
dense_padding_mode,
)
[docs]
class SparseImageWarp(torch.nn.Module):
r"""Warp an image by specifying mappings between few control points
This module mirrors the behaviour of Tensorflow's `sparse_image_warp
<https://www.tensorflow.org/addons/api_docs/python/tfa/image/sparse_image_warp>`__,
except `image` is in ``NCHW`` order instead of ``NHWC`` order. For more details,
please consult their documentation.
Parameters
----------
indexing
If `indexing` is ``"hw"``, ``source_points[n, m, 0]`` and ``dest_points[n, m,
0]`` index the height dimension in `image` and `warped`, respectively, and
``source_points[n, m, 1]`` and ``dest_points[n, m, 1]`` the width dimension. If
`indexing` is ``"wh"``, the width dimension is the 0-index and height the 1.
field_interpolation_order
The order of the polyharmonic spline used to interpolate the rest of the points
from the control. See :func:`polyharmonic_spline` for more info.
field_regularization_weight
The regularization weight of the polyharmonic spline used to interpolate the
rest of the points from the control. See :func:`polyharmonic_spline` for more
info.
field_full_matrix
Determines the method of calculating the polyharmonic spline used to interpolate
the rest of the points from the control. See :func:`polyharmonic_spline` for
more info.
pinned_boundary_points
Dictates whether and how many points along the boundary of `image` are mapped
identically to points in `warped`. This keeps the boundary of the `image` from
being pulled into the interior of `warped`. When :obj:`0`, no points are added.
When :obj:`1`, four points are added, one in each corner of the image. When ``k
> 2``, one point in each corner of the image is added, then ``k - 1``
equidistant points along each of the four edges, totaling ``4 * k`` points.
dense_interpolation_mode
The method with which partial indices in the derived mapping are interpolated.
See :func:`dense_image_warp` for more info.
dense_padding_mode
What to do when points in the derived mapping fall outside of the boundaries.
See :func:`dense_image_warp` for more info.
include_flow
If :obj:`True`, include the flow field `flow` interpolated from the control
points in the return value.
Call Parameters
---------------
image : torch.Tensor
A source image of shape ``(N, C, H, W)`` where ``N`` is the batch dimension,
``C`` the channel dimension, ``H`` the image height, and ``W`` the image width.
source_points, dest_points : torch.Tensor
Tensors of shape ``(N, M, 2)``, where ``M`` is the number of control points.
Returns
-------
warped : torch.Tensor or tuple of torch.Tensor
If `include_flow` is :obj:`False`, `warped` is a warped `image` of shape ``(N,
C, H, W)``. The point ``source_points[n, m, :]`` in `image` will be mapped to
``dest_points[n, m, :]`` in `warped`. If `include_flow` is :obj:`True`, `warped`
is a pair of tensors ``warped_, flow`` where `warped_` has the same definition
as `warped` when `include_flow` is :obj:`True` and `flow` is a tensor of shape
``(N, H, W, 2)``. ``flow[n, h, w, :]`` is the flow for coordinates ``h, w`` in
whatever order was specified by `indexing`. See :class:`DenseImageWarp` for more
details about `flow`.
Warnings
--------
When this module is scripted, its return type will be :class:`typing.Any`. This
reflects the fact that either `warped` is returned on its own (a tensor) or both
`warped_` and `flow` (a tuple). Use :func:`torch.jit.isinstance` for type refinement
in subsequent scripting. Tracing will infer the correct type.
"""
__constants__ = [
"indexing",
"field_interpolation_order",
"field_regularization_weight",
"field_full_matrix",
"pinned_boundary_points",
"dense_interpolation_mode",
"dense_padding_mode",
"include_flow",
]
field_interpolation_order: int
field_regularization_weight: float
field_full_matrix: bool
pinned_boundary_points: int
dense_interpolation_mode: str
dense_padding_mode: str
include_flow: bool
def __init__(
self,
indexing: SparseIndexing = "hw",
field_interpolation_order: int = 2,
field_regularization_weight: float = 0.0,
field_full_matrix: bool = True,
pinned_boundary_points: int = 0,
dense_interpolation_mode: DenseInterpolationMode = "bilinear",
dense_padding_mode: DensePaddingMode = "border",
include_flow: bool = True,
):
indexing = argcheck.is_in(indexing, get_args(SparseIndexing), "indexing")
field_interpolation_order = argcheck.is_posi(
field_interpolation_order, "field_interpolation_order"
)
field_regularization_weight = argcheck.is_float(
field_regularization_weight, "field_regularization_weight"
)
field_full_matrix = argcheck.is_bool(field_full_matrix, "field_ful_matrix")
pinned_boundary_points = argcheck.is_nonnegi(
pinned_boundary_points, "pinned_boundary_points"
)
dense_interpolation_mode = argcheck.is_in(
dense_interpolation_mode,
get_args(DenseInterpolationMode),
"dense_interpolation_mode",
)
dense_padding_mode = argcheck.is_in(
dense_padding_mode, get_args(DensePaddingMode), dense_padding_mode
)
include_flow = argcheck.is_bool(include_flow, "include_flow")
super().__init__()
self.indexing = indexing
self.field_interpolation_order = field_interpolation_order
self.field_regularization_weight = field_regularization_weight
self.field_full_matrix = field_full_matrix
self.pinned_boundary_points = pinned_boundary_points
self.dense_interpolation_mode = dense_interpolation_mode
self.dense_padding_mode = dense_padding_mode
self.include_flow = include_flow
def extra_repr(self) -> str:
return ", ".join(f"{x}={getattr(self, x)}" for x in self.__constants__)
@overload
def forward(
self,
image: torch.Tensor,
source_points: torch.Tensor,
dest_points: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
pass
def forward(
self,
image: torch.Tensor,
source_points: torch.Tensor,
dest_points: torch.Tensor,
) -> Any:
return sparse_image_warp(
image,
source_points,
dest_points,
self.indexing,
self.field_interpolation_order,
self.field_regularization_weight,
self.field_full_matrix,
self.pinned_boundary_points,
self.dense_interpolation_mode,
self.dense_padding_mode,
self.include_flow,
)
__call__ = proxy(forward)
[docs]
@script
@functional_wrapper("RandomShift")
def random_shift(
input: torch.Tensor,
in_lens: torch.Tensor,
prop: Tuple[float, float],
mode: str,
value: float,
training: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
if input.dim() < 2:
raise RuntimeError(f"input must be at least 2 dimensional")
if in_lens.dim() != 1 or in_lens.size(0) != input.size(0):
raise RuntimeError(
f"For input of shape {input.shape}, expected in_lens to be of shape "
f"({input.size(0)}), got {in_lens.shape}"
)
if training:
in_lens_ = in_lens.float()
pad = torch.stack([prop[0] * in_lens_, prop[1] * in_lens_])
pad *= torch.rand_like(pad)
pad = pad.long()
out_lens = in_lens + pad.sum(0)
return pad_variable(input, in_lens, pad, mode, value), out_lens
else:
return input, in_lens
[docs]
class RandomShift(torch.nn.Module):
"""Pad to the left and right of each sequence by a random amount
This layer is intended for training models which are robust to small shifts in some
variable-length sequence dimension (e.g. speech recognition). It pads each input
sequence with some number of elements at its beginning and end. The number of
elements is randomly chosen but bounded above by some proportion of the input length
specified by the user.
The amount of padding is dictated by the parameter `prop` this layer is initialized
with. A proportion is a non-negative float dictating the maximum ratio of the
original sequence length which may be padded, exclusive. `prop` can be a pair
``left, right`` for separate ratios of padding at the beginning and end of a
sequence, or just one float if the proportions are the same. For example,
``prop=0.5`` of a sequence of length ``10`` could result in a sequence of length
between ``10`` and ``18`` inclusive since each side of the sequence could be padded
with ``0-4`` elements (``0.5 * 10 = 5`` is an exclusive bound).
Parameters
----------
prop
mode
The method with which to pad the input sequence.
value
The constant with which to pad the sequence if `mode` is set to
:obj:`'constant'`.
Call Parameters
---------------
input : torch.Tensor
A tensor of shape ``(N, T, *)`` where ``N`` is the batch dimension and ``T`` is
the sequence dimension; `in_lens` is a long tensor of shape ``(N,)``.
in_lens : torch.Tensor
A tensor of shape ``(N,)`` containing the lengths of the sequences in `in_`.
For batch element ``n``, only the values ``input[n, in_lens[n]]`` are valid.
Returns
-------
out : torch.Tensor
A tensor of the same type as ``input`` of shape ``(N, T', *)``.
out_lens : torch.Tensor
A tensor of shape ``(N,)`` containing the lengths of the sequences in `out`.
Raises
------
NotImplementedError
On initialization if `mode` is :obj:`'reflect'` and a value in `prop` exceeds
``1.0``. Reflection currently requires the amount of padding does not exceed
the original sequence length.
Notes
-----
Padding is only applied if this layer is in training mode. If testing, ``out,
out_lens = input, in_lens``.
See Also
--------
pydrobert.torch.util.pad_variable
For more details on the different types of padding. Note the default `mode` is
different between this and the function.
"""
__constants__ = ["prop", "mode", "value"]
prop: Tuple[float, float]
mode: str
value: float
def __init__(
self,
prop: Union[float, Tuple[float, float]],
mode: RandomShiftMode = "reflect",
value: float = 0.0,
):
try:
prop = (argcheck.is_float(prop, "prop"), float(prop))
except TypeError:
prop = tuple(prop)
if len(prop) != 2:
raise ValueError(
f"prop must be a single or pair of floating points, got '{prop}'"
)
if prop[0] < 0.0 or prop[1] < 0.0:
raise ValueError("prop values must be non-negative")
mode = argcheck.is_in(mode, get_args(RandomShiftMode), "mode")
if mode == "reflect":
if prop[0] > 1.0 or prop[1] > 1.0:
raise NotImplementedError(
"if 'mode' is 'reflect', values in 'prop' must be <= 1"
)
value = argcheck.is_float(value, "value")
super().__init__()
self.mode, self.prop, self.value = mode, prop, value
def extra_repr(self) -> str:
return f"prop={self.prop}, mode={self.mode}, value={self.value}"
def reset_parameters(self) -> None:
pass
def forward(
self, input: torch.Tensor, in_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
return random_shift(
input, in_lens, self.prop, self.mode, self.value, self.training
)
__call__ = proxy(forward)
@script
def _spec_augment_check_input(
feats: torch.Tensor, lengths: Optional[torch.Tensor] = None
):
if feats.dim() != 3:
raise RuntimeError(
f"Expected feats to have three dimensions, got {feats.dim()}"
)
N, T, _ = feats.shape
if lengths is not None:
if lengths.dim() != 1:
raise RuntimeError(
f"Expected lengths to be one dimensional, got {lengths.dim()}"
)
if lengths.size(0) != N:
raise RuntimeError(
f"Batch dimension of feats ({N}) and lengths ({lengths.size(0)}) "
"do not match"
)
if not torch.all((lengths <= T) & (lengths > 0)):
raise RuntimeError(f"values of lengths must be between (1, {T})")
SpecAugmentParams = Tuple[
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
]
[docs]
@script
@functional_wrapper("SpecAugment.draw_parameters")
def spec_augment_draw_parameters(
feats: torch.Tensor,
max_time_warp: float,
max_freq_warp: float,
max_time_mask: int,
max_freq_mask: int,
max_time_mask_proportion: float,
num_time_mask: int,
num_time_mask_proportion: float,
num_freq_mask: int,
lengths: Optional[torch.Tensor] = None,
) -> SpecAugmentParams:
_spec_augment_check_input(feats, lengths)
N, T, F = feats.shape
device = feats.device
eps = _get_tensor_eps(feats)
omeps = 1 - eps
if lengths is None:
lengths = torch.full((N,), T, dtype=torch.float, device=device)
else:
lengths = lengths.to(device).float()
# note that order matters slightly in whether we draw widths or positions first.
# The paper specifies that position is drawn first for warps, whereas widths
# are drawn first for masks
if max_time_warp:
# we want the range (W, length - W) exclusive to be where w_0 can come
# from. If W >= length / 2, this is impossible. Rather than giving up,
# we limit the maximum length to W < length / 2
# N.B. We don't worry about going outside the valid range by a bit b/c
# warp_1d_grid clamps.
W = (lengths / 2 - eps).clamp(0, max_time_warp)
w_0 = torch.rand((N,), device=device) * (lengths - 2 * W) + W
w = torch.rand([N], device=device) * (2 * W) - W
else:
w_0 = w = torch.empty(0)
if max_freq_warp:
V = min(max(F / 2 - eps, 0), max_freq_warp)
v_0 = torch.rand([N], device=device) * (F - 2 * V) + V
v = torch.rand([N], device=device) * (2 * V) - V
else:
v_0 = v = torch.empty(0)
if (
max_time_mask
and max_time_mask_proportion
and num_time_mask
and num_time_mask_proportion
):
max_ = (
torch.clamp(lengths * max_time_mask_proportion, max=max_time_mask,)
.floor()
.to(device)
)
nums_ = (
torch.clamp(lengths * num_time_mask_proportion, max=num_time_mask,)
.floor()
.to(device)
)
t = (
(
torch.rand([N, num_time_mask], device=device)
* (max_ + omeps).unsqueeze(1)
)
.long()
.masked_fill(
nums_.unsqueeze(1)
<= torch.arange(num_time_mask, dtype=lengths.dtype, device=device),
0,
)
)
t_0 = (
torch.rand([N, num_time_mask], device=device)
* (lengths.unsqueeze(1) - t + omeps)
).long()
else:
t = t_0 = torch.empty(0)
if max_freq_mask and num_freq_mask:
max_ = min(max_freq_mask, F)
f = (torch.rand([N, num_freq_mask], device=device) * (max_ + omeps)).long()
f_0 = (torch.rand([N, num_freq_mask], device=device) * (F - f + omeps)).long()
else:
f = f_0 = torch.empty(0)
return w_0, w, v_0, v, t_0, t, f_0, f
[docs]
@script
@functional_wrapper("SpecAugment.apply_parameters")
def spec_augment_apply_parameters(
feats: torch.Tensor,
params: SpecAugmentParams,
interpolation_order: int,
lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
_spec_augment_check_input(feats, lengths)
N, T, F = feats.shape
device = feats.device
if lengths is None:
lengths = torch.full((N,), T, dtype=torch.long, device=device)
lengths = lengths.to(feats.dtype)
w_0, w, v_0, v, t_0, t, f_0, f = params
new_feats = feats
time_grid: Optional[torch.Tensor] = None
freq_grid: Optional[torch.Tensor] = None
do_warp = False
if w_0 is not None and w_0.numel() and w is not None and w.numel():
time_grid = warp_1d_grid(w_0, w, lengths, T, interpolation_order)
do_warp = True
if v_0 is not None and v_0.numel() and v is not None and v.numel():
freq_grid = warp_1d_grid(
v_0,
v,
torch.full((N,), F, dtype=torch.long, device=device),
F,
interpolation_order,
)
do_warp = True
if do_warp:
if time_grid is None:
time_grid = torch.arange(T, device=device, dtype=torch.float)
time_grid = (2 * time_grid + 1) / T - 1
time_grid = time_grid.unsqueeze(0).expand(N, T)
if freq_grid is None:
freq_grid = torch.arange(F, device=device, dtype=torch.float)
freq_grid = (2 * freq_grid + 1) / F - 1
freq_grid = freq_grid.unsqueeze(0).expand(N, F)
time_grid = time_grid.unsqueeze(2).expand(N, T, F)
freq_grid = freq_grid.unsqueeze(1).expand(N, T, F)
# note: grid coordinate are (freq, time) rather than (time, freq)
grid = torch.stack([freq_grid, time_grid], 3) # (N, T, F, 2)
new_feats = torch.nn.functional.grid_sample(
new_feats.unsqueeze(1),
grid,
mode="bilinear",
padding_mode="border",
align_corners=False,
).squeeze(1)
tmask: Optional[torch.Tensor] = None
fmask: Optional[torch.Tensor] = None
if t_0 is not None and t_0.numel() and t is not None and t.numel():
tmask = torch.arange(T, device=device).unsqueeze(0).unsqueeze(2) # (1, T,1)
t_1 = t_0 + t # (N, MT)
tmask = (tmask >= t_0.unsqueeze(1)) & (tmask < t_1.unsqueeze(1)) # (N,T,MT)
tmask = tmask.any(2, keepdim=True) # (N, T, 1)
if f_0 is not None and f_0.numel() and f is not None and f.numel():
fmask = torch.arange(F, device=device).unsqueeze(0).unsqueeze(2) # (1, F,1)
f_1 = f_0 + f # (N, MF)
fmask = (fmask >= f_0.unsqueeze(1)) & (fmask < f_1.unsqueeze(1)) # (N,F,MF)
fmask = fmask.any(2).unsqueeze(1) # (N, 1, F)
if tmask is not None:
if fmask is not None:
tmask = tmask | fmask
new_feats = new_feats.masked_fill(tmask, 0.0)
elif fmask is not None:
new_feats = new_feats.masked_fill(fmask, 0.0)
return new_feats
[docs]
@script
@functional_wrapper("SpecAugment")
def spec_augment(
feats: torch.Tensor,
max_time_warp: float,
max_freq_warp: float,
max_time_mask: int,
max_freq_mask: int,
max_time_mask_proportion: float,
num_time_mask: int,
num_time_mask_proportion: float,
num_freq_mask: int,
interpolation_order: int,
lengths: Optional[torch.Tensor] = None,
training: bool = True,
) -> torch.Tensor:
_spec_augment_check_input(feats, lengths)
if not training:
return feats
params = spec_augment_draw_parameters(
feats,
max_time_warp,
max_freq_warp,
max_time_mask,
max_freq_mask,
max_time_mask_proportion,
num_time_mask,
num_time_mask_proportion,
num_freq_mask,
lengths,
)
return spec_augment_apply_parameters(feats, params, interpolation_order, lengths)
[docs]
class SpecAugment(torch.nn.Module):
r"""Perform warping/masking of time/frequency dimensions of filter bank features
SpecAugment [park2019]_ (and later [park2020]_) is a series of data transformations
for training data augmentation of time-frequency features such as Mel-scaled
triangular filter bank coefficients.
Default parameter values are from [park2020]_.
Parameters
----------
max_time_warp
A non-negative float specifying the maximum number of frames the chosen
random frame can be shifted left or right by in step 1. Setting to :obj:`0`
disables step 1.
max_freq_warp
A non-negative float specifying the maximum number of coefficients the chosen
random frequency coefficient index will be shifted up or down by in step 2.
Setting to :obj:`0` disables step 2.
max_time_mask
A non-negative integer specifying an absolute upper bound on the number of
sequential frames in time that can be masked out by a single mask. The minimum
of this upper bound and that from `max_time_mask_proportion` specifies the
actual maximum. Setting this, `max_time_mask_proportion`, `num_time_mask`,
or `num_time_mask_proportion` to :obj:`0` disables step 3.
max_freq_mask
A non-negative integer specifying the maximum number of sequential coefficients
in frequency that can be masked out by a single mask. Setting this or
`num_freq_mask` to :obj:`0` disables step 4.
max_time_mask_proportion
A value in the range :math:`[0, 1]` specifying a relative upper bound on the
number of squential frames in time that can be masked out by a single mask. For
batch element ``n``, the upper bound is ``int(max_time_mask_poportion *
length[n])``. The minimum of this upper bound and that from `max_time_mask`
specifies the actual maximum. Setting this, `max_time_mask`, `num_time_mask`,
or `num_time_mask_proportion` to :obj:`0` disables step 4.
num_time_mask
A non-negative integer specifying an absolute upper bound number of random masks
in time per batch element to create. Setting this, `num_time_mask_proportion`,
`max_time_mask`, or `max_time_mask_proportion` to :obj:`0` disables step 3.
Drawn i.i.d. and may overlap.
num_time_mask_proportion
A value in the range :math:`[0, 1]` specifying a relative upper bound on the
number of time masks per element in the batch to create. For batch element
``n``, the upper bound is ``int(num_time_mask_proportion * length[n])``. The
minimum of this upper bound and that from `num_time_mask` specifies the
actual maximum. Setting this, `num_time_mask`, `max_time_mask`, or
`max_time_mask_proportion` to :obj:`0` disables step 3. Drawn i.i.d. and may
overlap.
num_freq_mask
The total number of random masks in frequency per batch element to create.
Setting this or `max_freq_mask` to :obj:`0` disables step 4. Drawn i.i.d. and
may overlap.
interpolation_order
Controls order of interpolation of warping. 1 = linear (default for
[park2020]_). 2 = thin plate (default for [park2019]_). Higher orders are
possible at increased computational cost.
Call Parameters
---------------
feats : torch.Tensor
A tensor of shape ``(N, T, F)`` where ``N`` is the batch dimension, ``T`` is the
time (frames) dimension, and ``F`` is the frequency (coefficients per frame)
dimension. The original feature tensor.
lengths : Optional[torch.Tensor], optional
A long tensor of shape ``(N,)`` specifying the actual number of frames before
right-padding per batch element. That is, for batch index ``n``, only ``feats[n,
:lengths[n]]`` are valid.
Returns
-------
new_feats : torch.Tensor
The warped `feats` of shape ``(N, T, F)`` with some or all of the following
operations performed in order independently per batch index:
1. Choose a random frame along the time dimension. Warp `feats` such that
``feats[n, 0]`` and ``feats[n, lengths[n] - 1]`` are fixed, but that random
frame gets mapped to a random new location a few frames to the left or right.
2. Do the same for the frequency dimension.
3. Mask out (zero) one or more random-width ranges of frames in a random
location along the time dimension.
4. Do the same for the frequency dimension.
The original SpecAugment implementation only performs steps 1, 3, and 4; step 2
is a trivial extension.
Notes
-----
SpecAugment is only applied in training mode; in eval mode, ``new_feats == feats``.
There are a few differences between this implementation of warping and those you
might find online or described in the source paper [park2019]_. These require some
knowledge of what's happening under the hood and are unlikely to change the way you
use this function. We assume we're warping in time, though the following applies to
frequency warping as well.
First, the warp parameters are real- rather than integer-valued. You can set
`max_time_warp` or `max_freq_warp` to 0.5 if you'd like. The shift value drawn
between ``[0, max_time_warp]`` is also real-valued. Since the underlying warp
relies on interpolation between partial indices anyways (the vast majority of tensor
values will be the result of interpolation), there is no preference for
integer-valued parameters from a computational standpoint. Further, real-valued warp
parameters allow for a virtually infinite number of warps instead of just a few.
Finally, time warping is implemented by determining the transformation in one
dimension (time) and broadcasting it across the other (frequency), rather than
performing a two-dimensional warp. This is not in line with [park2019]_, but is
with [park2020]_. I have confirmed with the first author that the slight warping
of frequency that occurred due to the 2D warp was unintentional.
"""
__constants__ = (
"max_time_warp",
"max_freq_warp",
"max_time_mask",
"max_freq_mask",
"max_time_mask_proportion",
"num_time_mask",
"num_time_mask_proportion",
"num_freq_mask",
"interpolation_order",
)
max_time_warp: float
max_freq_warp: float
max_time_mask: int
max_freq_mask: int
max_time_mask_proportion: float
num_time_mask: int
num_time_mask_proportion: float
num_freq_mask: int
interpolation_order: int
def __init__(
self,
max_time_warp: float = 80.0,
max_freq_warp: float = 0.0,
max_time_mask: int = 100,
max_freq_mask: int = 27,
max_time_mask_proportion: float = 0.04,
num_time_mask: int = 20,
num_time_mask_proportion: float = 0.04,
num_freq_mask: int = 2,
interpolation_order: int = 1,
):
max_time_warp = argcheck.is_nonnegf(max_time_warp, "max_time_warp")
max_freq_warp = argcheck.is_nonnegf(max_freq_warp, "max_freq_warp")
max_time_mask = argcheck.is_nonnegi(max_time_mask, "max_time_mask")
max_freq_mask = argcheck.is_nonnegi(max_freq_mask, "max_freq_mask")
max_time_mask_proportion = argcheck.is_closed01(
max_time_mask_proportion, "max_time_mask_proportion"
)
num_time_mask = argcheck.is_nonnegi(num_time_mask, "num_time_mask")
num_time_mask_proportion = argcheck.is_closed01(
num_time_mask_proportion, "num_time_mask_proportion"
)
num_freq_mask = argcheck.is_nonnegi(num_freq_mask, "num_freq_mask")
interpolation_order = argcheck.is_posi(
interpolation_order, "interpolation_order"
)
super().__init__()
self.max_time_warp = max_time_warp
self.max_freq_warp = max_freq_warp
self.max_time_mask = max_time_mask
self.max_freq_mask = max_freq_mask
self.max_time_mask_proportion = max_time_mask_proportion
self.num_time_mask = num_time_mask
self.num_time_mask_proportion = num_time_mask_proportion
self.num_freq_mask = num_freq_mask
self.interpolation_order = interpolation_order
def extra_repr(self) -> str:
s = "warp_t={},max_f={},num_f={},max_t={},max_t_p={:.2f},num_t={}".format(
self.max_time_warp,
self.max_freq_mask,
self.num_freq_mask,
self.max_time_mask,
self.max_time_mask_proportion,
self.num_time_mask,
)
if self.max_freq_warp:
s += ",warp_f={}".format(self.max_freq_warp)
return s
def draw_parameters(
self, feats: torch.Tensor, lengths: Optional[torch.Tensor] = None
) -> SpecAugmentParams:
"""Randomly draw parameterizations of augmentations
Called as part of this layer's :func:`__call__` method.
Parameters
----------
feats
Time-frequency features of shape ``(N, T, F)``.
lengths
Long tensor of shape ``(N,)`` containing the number of frames before
padding.
Returns
-------
w_0 : torch.Tensor
If step 1 is enabled, of shape ``(N,)`` containing the source points in the
time warp (floatint-point). Otherwise, is empty.
w : torch.Tensor
If step 1 is enabled, of shape ``(N,)`` containing the number of frames to
shift the source point by (positive or negative) in the destination in time.
Positive values indicate a right shift. Otherwise is empty.
v_0 : torch.Tensor
If step 2 is enabled, of shape ``(N,)`` containing the source points in the
frequency warp (floating point). Otherwise is empty.
v : torch.Tensor
If step 2 is enabled, of shape ``(N,)`` containing the number of
coefficients to shift the source point by (positive or negative) in the
destination in time. Positive values indicate a right shift. Otherwise is
empty.
t_0 : torch.Tensor
If step 3 is enabled, of shape ``(N, M_T)`` where ``M_T`` is the number of
time masks specifying the lower index (inclusive) of the time masks.
Otherwise is empty.
t : torch.Tensor
If step 3 is enabled, of shape ``(N, M_T)`` specifying the number of frames
per time mask. Otherise is empty.
f_0 : torch.Tensor
If step 4 is enabled, of shape ``(N, M_F)`` where ``M_F`` is the number of
frequency masks specifying the lower index (inclusive) of the frequency
masks. Otherwise is empty.
f : torch.Tensor
If step 4 is enabled, of shape ``(N, M_F)`` specifying the number of
frequency coefficients per frequency mask. Otherwise is empty.
"""
return spec_augment_draw_parameters(
feats,
self.max_time_warp,
self.max_freq_warp,
self.max_time_mask,
self.max_freq_mask,
self.max_time_mask_proportion,
self.num_time_mask,
self.num_time_mask_proportion,
self.num_freq_mask,
lengths,
)
def apply_parameters(
self,
feats: torch.Tensor,
params: SpecAugmentParams,
lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Use drawn parameters to apply augmentations
Called as part of this layer's :func:`__call__` method.
Parameters
----------
feats
Time-frequency features of shape ``(N, T, F)``.
params
All parameter tensors returned by :func:`draw_parameters`.
lengths
Tensor of shape ``(N,)`` containing the number of frames before padding.
Returns
-------
new_feats : torch.Tensor
Augmented time-frequency features of same shape as `feats`.
"""
return spec_augment_apply_parameters(
feats, params, self.interpolation_order, lengths
)
def reset_parameters(self) -> None:
pass
def forward(
self, feats: torch.Tensor, lengths: Optional[torch.Tensor] = None
) -> torch.Tensor:
if lengths is None:
# _spec_augment_check_input(feats)
lengths = torch.full(
(feats.size(0),), feats.size(1), dtype=torch.long, device=feats.device
)
if not self.training:
return feats
params = self.draw_parameters(feats, lengths)
return self.apply_parameters(feats, params, lengths)
__call__ = proxy(forward)