# Copyright 2022 Sean Robertson
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import warnings
from typing import Optional, overload
from typing_extensions import Literal, get_args
import torch
from . import config, argcheck
from ._compat import script
from ._wrappers import functional_wrapper, proxy
# Literal type of the ``reduction`` argument accepted by the loss functions in this
# module (e.g. hard_optimal_completion_distillation_loss): "mean", "sum", or "none".
Reduction = Literal["mean", "sum", "none"]
[docs]
@functional_wrapper("FillAfterEndOfSequence")
def fill_after_eos(
    tokens: torch.Tensor,
    eos: int,
    dim: int = 0,
    fill: Optional[float] = None,
    value: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Functional counterpart of FillAfterEndOfSequence: every position strictly
    # after the first `eos` in `tokens` along `dim` is replaced by `fill` (which
    # defaults to `eos`) in `value` (which defaults to `tokens`).
    if fill is None:
        fill_value = float(eos)
    else:
        fill_value = fill
    if value is None:
        target = tokens
    else:
        target = value
    # Running count of eos tokens capped at 1: 0 before the first eos, 1 from it
    # onward. The cap reduces the chances of overflow in the next running sum.
    reached_eos = torch.cumsum(tokens.eq(eos).long(), dim).clamp_max(1)
    # A second running sum exceeds 1 only strictly after the first eos.
    past_eos = torch.cumsum(reached_eos, dim) > 1
    return target.masked_fill(past_eos, fill_value)
[docs]
class FillAfterEndOfSequence(torch.nn.Module):
    """Fill after the first end-of-sequence token with a value

    Many Natural Language Processing tasks involve variable-length sequences ending with
    special "end-of-sequence" (`eos`) tokens. This module finds the first instance of
    `eos` and pads everything after that along the `dim` dimension with the value of
    `fill`.

    Parameters
    ----------
    eos
        The id of the end-of-sequence token.
    dim
        The sequence dimension of `tokens`.
    fill
        The value to fill with. If unset, set to `eos`.

    Call Parameters
    ---------------
    tokens : torch.Tensor
        The token sequences. Of arbitrary shape, but must have dimension `dim`.
    value : Optional[torch.Tensor], optional
        `value` may be optionally specified as a tensor other than `tokens` to fill. It
        must broadcast with `tokens` if specified. Otherwise `value` will be assumed to
        be `tokens`.

    Returns
    -------
    out : torch.Tensor
        A tensor matching `tokens` (or `value` broadcasted with `tokens`, if `value`
        was specified) except beyond the first instance of `eos` in `tokens`, after
        which is `fill`.

    Examples
    --------

    >>> T = 10
    >>> tokens = torch.arange(T)
    >>> tokens
    tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    >>> fill_after_eos = FillAfterEndOfSequence(eos=T // 2, fill=-1)
    >>> out = fill_after_eos(tokens)
    >>> out
    tensor([ 0,  1,  2,  3,  4,  5, -1, -1, -1, -1])
    >>> logits = torch.eye(T)
    >>> logits
    tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
    >>> out = fill_after_eos(tokens.unsqueeze(1), logits)
    >>> out
    tensor([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.],
            [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
            [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
            [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
            [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]])
    """

    __constants__ = "eos", "dim", "fill"

    dim: int
    eos: int
    fill: float

    def __init__(self, eos: int, dim: int = 0, fill: Optional[float] = None) -> None:
        eos = argcheck.is_int(eos, "eos")
        dim = argcheck.is_int(dim, "dim")
        # an unset fill value defaults to the eos id itself
        if fill is None:
            fill = float(eos)
        else:
            fill = argcheck.is_float(fill, "fill")
        super().__init__()
        self.eos, self.dim, self.fill = eos, dim, fill

    def forward(
        self, tokens: torch.Tensor, value: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return fill_after_eos(tokens, self.eos, self.dim, self.fill, value)

    __call__ = proxy(forward)
@torch.jit.script
def _lens_from_eos(tok: torch.Tensor, eos: int, dim: int) -> torch.Tensor:
# length to first eos (exclusive)
mask = tok.eq(eos)
x = torch.cumsum(mask, dim, dtype=torch.long)
max_, argmax = (x.eq(1) & mask).max(dim)
return argmax.masked_fill(max_.eq(0), tok.shape[dim])
@script
def _string_matching(
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int],
    include_eos: bool,
    batch_first: bool,
    ins_cost: float,
    del_cost: float,
    sub_cost: float,
    warn: bool,
    norm: bool = False,
    return_mask: bool = False,
    return_prf_dsts: bool = False,
    exclude_last: bool = False,
    padding: int = config.INDEX_PAD_VALUE,
    return_mistakes: bool = False,
):
    """Levenshtein dynamic program shared by the string-matching functions

    Computes, one hypothesis step at a time, the rows of a Levenshtein matrix between
    each reference and hypothesis sequence in a batch. Depending on the flags, it
    returns final distances/mistake counts, per-prefix values, or the mask used for
    optimal completion.

    Parameters
    ----------
    ref
        2-D tensor of reference tokens, ``(R, N)`` (``(N, R)`` if `batch_first`).
    hyp
        2-D tensor of hypothesis tokens, ``(H, N)`` (``(N, H)`` if `batch_first`).
    eos
        If not :obj:`None`, a sequence's length is the number of tokens before its
        first `eos` (or its full length if it contains no `eos`).
    include_eos
        If :obj:`True` (and `eos` is set), lengths are extended by one to count the
        first `eos`, except for sequences that lack one.
    batch_first
        Whether the batch dimension of `ref` and `hyp` comes first.
    ins_cost, del_cost, sub_cost
        Costs of an insertion, a deletion, and a mismatched substitution.
    warn
        Whether to emit warnings (non-uniform costs with `return_mistakes`, a missing
        `eos` with `include_eos`, empty references with `norm`).
    norm
        Divide the result by the reference lengths. Zero-length references yield 0
        for an empty hypothesis (prefix) and 1 otherwise.
    return_mask
        Return a boolean mask of shape ``(H + 1, R, N)`` (``(H, R, N)`` if
        `exclude_last`) flagging, for each hypothesis prefix, the reference positions
        whose tokens are optimal next tokens. Not transposed even if `batch_first`.
        Mutually exclusive with `return_prf_dsts` (asserted).
    return_prf_dsts
        Return per-prefix values of shape ``(H + 1, N)`` (``(H, N)`` if
        `exclude_last`), with entries for prefixes longer than a hypothesis set to
        `padding`. Transposed if `batch_first`.
    exclude_last
        Drop the prefix consisting of the entire hypothesis. Only valid with
        `return_mask` or `return_prf_dsts` (asserted).
    padding
        Fill value for invalid prefixes when `return_prf_dsts` is set.
    return_mistakes
        Count edits along the selected minimal-cost alignment (preferring
        substitutions, then insertions, then deletions on ties) instead of summing
        their costs.

    Returns
    -------
    out : torch.Tensor
        As described by the flags above; by default a tensor of shape ``(N,)``.

    Raises
    ------
    RuntimeError
        If `ref` or `hyp` is not 2-D, or if their batch sizes differ.

    Notes
    -----
    When ``ins_cost == del_cost == sub_cost > 0``, the program runs with unit costs.
    Distances are then rescaled by the shared cost, while mistake counts equal the
    unit-cost distance, so the cheaper distance path is used for them as well.
    """
    assert not return_mask or not return_prf_dsts
    assert not exclude_last or (return_mask or return_prf_dsts)
    if ref.dim() != 2 or hyp.dim() != 2:
        raise RuntimeError("ref and hyp must be 2 dimensional")
    # factor applied to distances at the end (see the uniform-cost shortcut below)
    mult = 1.0
    if ins_cost == del_cost == sub_cost > 0.0:
        # results are equivalent and faster to return
        if not return_mistakes:
            mult = ins_cost
        ins_cost = del_cost = sub_cost = 1.0
        return_mistakes = False
    elif return_mistakes and warn:
        warnings.warn(
            "The behaviour for non-uniform error rates has changed after v0.3.0. "
            "Please switch to edit_distance functions for old behaviour. Set "
            "warn=False to suppress this warning"
        )
    # from here on, ref is (R, N) and hyp is (H, N)
    if batch_first:
        ref = ref.t()
        hyp = hyp.t()
    # placeholders, presumably so these names are defined on every code path as
    # TorchScript requires; each is only meaningful under the flags that set it
    mistakes = del_mat = prefix_ers = torch.empty(0)
    masks = []
    ref = ref.detach()
    hyp = hyp.detach()
    max_ref_steps, batch_size = ref.shape
    max_hyp_steps, batch_size_ = hyp.shape
    device = ref.device
    if batch_size != batch_size_:
        raise RuntimeError(
            "ref has batch size {}, but hyp has {}".format(batch_size, batch_size_)
        )
    if eos is not None:
        # lengths up to (but excluding) the first eos of each sequence
        ref_lens = _lens_from_eos(ref, eos, 0)
        hyp_lens = _lens_from_eos(hyp, eos, 0)
        if include_eos:
            # count the eos too, unless the sequence never contained one (in which
            # case the length already equals the full sequence length)
            ref_eq_mask = ref_lens == max_ref_steps
            ref_lens = ref_lens + 1
            if ref_eq_mask.any():
                if warn:
                    warnings.warn(
                        "include_eos=True, but a transcription in ref did not "
                        "contain the eos symbol ({}). To suppress this "
                        "warning, set warn=False".format(eos)
                    )
                ref_lens = ref_lens - ref_eq_mask.to(ref_lens.dtype)
            hyp_eq_mask = hyp_lens == max_hyp_steps
            hyp_lens = hyp_lens + 1
            if hyp_eq_mask.any():
                if warn:
                    warnings.warn(
                        "include_eos=True, but a transcription in hyp did not "
                        "contain the eos symbol ({}). To suppress this "
                        "warning, set warn=False".format(eos)
                    )
                hyp_lens = hyp_lens - hyp_eq_mask.to(hyp_lens.dtype)
        # for n in range(batch_size):
        #     print(n, "ref", ref[: ref_lens[n], n])
        #     print(n, "hyp", hyp[: hyp_lens[n], n])
    else:
        # without eos, every sequence spans its full padded length
        ref_lens = torch.full(
            (batch_size,), max_ref_steps, device=ref.device, dtype=torch.long
        )
        hyp_lens = torch.full(
            (batch_size,), max_hyp_steps, device=ref.device, dtype=torch.long
        )
    # direct row down corresponds to insertion
    # direct col right corresponds to a deletion
    #
    # we vectorize as much as we can. Neither substitutions nor insertions require
    # values from the current row to be computed, and since the last row can't be
    # altered, we can easily vectorize there. To vectorize deletions, we use del_matrix.
    # It has entries
    #
    # 0 inf inf inf ...
    # d 0 inf inf ...
    # 2d d 0 inf ...
    # ...
    #
    # Where "d" is del_cost. When we sum with the intermediate values of the next row
    # "v" (containing the minimum of insertion and subs costs), we get
    #
    # v[0] inf inf inf ...
    # v[0]+d v[1] inf inf ...
    # v[0]+2d v[1]+d v[2] inf ...
    # ...
    #
    # And we take the minimum of each row. The dynamic programming algorithm for
    # levenshtein would usually handle deletions as:
    #
    # for i=1..|v|:
    #   v[i] = min(v[i], v[i-1]+d)
    #
    # if we unroll the loop, we get the minimum of the elements of each row of the above
    # matrix
    rrange = torch.arange(max_ref_steps + 1, device=device, dtype=torch.float)
    if return_mistakes:
        # for the empty hypothesis prefix, reaching ref position i takes i deletions
        mistakes = rrange.unsqueeze(1).expand(max_ref_steps + 1, batch_size)
        row = rrange * del_cost
    else:
        row = rrange * del_cost
        del_mat = row.unsqueeze(1) - row
        del_mat = del_mat + torch.full_like(del_mat, float("inf")).triu(1)
        del_mat = del_mat.unsqueeze(-1)  # (R + 1, R + 1, 1)
    # print("k", 0)
    row = row.unsqueeze(1).expand(max_ref_steps + 1, batch_size)  # (R + 1, N)
    # for n in range(batch_size):
    #     print(n, "row", row[..., n])
    if return_mask:
        # for the empty prefix, only the first reference token is an optimal target
        row_mask = torch.zeros(
            (max_ref_steps, batch_size), device=device, dtype=torch.bool,
        )
        row_mask[0] = ref_lens > 0
        # for n in range(batch_size):
        #     print(n, "row_mask", row_mask[..., n])
        masks.append(row_mask)
    elif return_prf_dsts:
        prefix_ers = torch.empty(
            (max_hyp_steps + (0 if exclude_last else 1), batch_size),
            device=device,
            dtype=torch.float,
        )
        # empty prefix: every reference token must be deleted
        prefix_ers[0] = ref_lens * (1.0 if return_mistakes else del_cost)
    for hyp_idx in range(1, max_hyp_steps + (0 if exclude_last else 1)):
        # print("k", hyp_idx)
        # batch elements whose hypothesis is already exhausted keep their last row
        not_done = (hyp_idx - (0 if exclude_last else 1)) < hyp_lens
        last_row = row
        ins_mask = (hyp_lens >= hyp_idx).float()  # (N,)
        neq_mask = (ref != hyp[hyp_idx - 1]).float()  # (R, N)
        row = last_row + ins_cost * ins_mask
        sub_row = last_row[:-1] + sub_cost * neq_mask
        if return_mistakes:
            # The kicker is substitutions over insertions over deletions.
            pick_sub = row[1:] >= sub_row
            row[1:] = torch.where(pick_sub, sub_row, row[1:])
            last_mistakes = mistakes
            mistakes = last_mistakes + ins_mask
            msub_row = last_mistakes[:-1] + neq_mask
            mistakes[1:] = torch.where(pick_sub, msub_row, mistakes[1:])
            # FIXME(sdrobert): the min function behaves non-determinically r.n.
            # (regardless of what the 1.7.0 docs say!) so techniques for extracting
            # indices from the min are a wash. If we can get determinism, we can flip
            # the 1 dimension if (del_mat + row) before the min and get the least idx
            # min, which should have the fewest number of deletions.
            # sequential deletion pass so mistakes can follow the chosen operation
            for ref_idx in range(1, max_ref_steps + 1):
                del_ = row[ref_idx - 1] + del_cost
                pick_sub = del_ >= row[ref_idx]
                row[ref_idx] = torch.where(pick_sub, row[ref_idx], del_)
                mistakes[ref_idx] = torch.where(
                    pick_sub, mistakes[ref_idx], mistakes[ref_idx - 1] + 1.0
                )
            mistakes = torch.where(not_done, mistakes, last_mistakes)
        else:
            row[1:] = torch.min(row[1:], sub_row)
            # vectorized deletion pass (see the del_mat explanation above)
            row, _ = (del_mat + row).min(1)
        row = torch.where(not_done, row, last_row)
        if return_mask:
            # As proven in the OCD paper, the optimal targets are always the first
            # character of a suffix of the reference transcript that remains to be
            # aligned. The levenshtein operation corresponding to what we do with that
            # target would be a matched substitution (i.e. hyp's next token is the OCD
            # target, resulting in no change in cost from the prefix). Thus, given a
            # levenshtein matrix for one of these OCD targets (which is this matrix,
            # except for the final row), the minimal values on the final row sit on a
            # diagonal from the minimal values of the current row.
            #
            # N.B. AFAICT this is the only case where we actually care what goes on in
            # the invalid range of the row. The below masking could always be applied,
            # but it's wasted effort otherwise.
            row = row.masked_fill(rrange.unsqueeze(1) > ref_lens, float("inf"))
            mins = row.min(0, keepdim=True)[0]
            row_mask = (row[:-1] == mins) & not_done
            # for n in range(batch_size):
            #     print(n, "row", row[..., n])
            #     print(n, "mins", row_mask[..., n])
            #     print(n, "row_mask", row_mask[..., n])
            masks.append(row_mask)
        elif return_prf_dsts:
            # value of the full (valid) reference against this hypothesis prefix
            if return_mistakes:
                prefix_ers[hyp_idx] = mistakes.gather(0, ref_lens.unsqueeze(0)).squeeze(
                    0
                )
            else:
                prefix_ers[hyp_idx] = row.gather(0, ref_lens.unsqueeze(0)).squeeze(0)
    if return_mask:
        # (H + 1 or H, R, N); positions past each reference's length are cleared
        mask = torch.stack(masks, 0)
        mask = mask & (
            torch.arange(max_ref_steps, device=device)
            .unsqueeze(1)
            .expand(max_ref_steps, batch_size)
            < ref_lens
        ).unsqueeze(0)
        return mask
    elif return_prf_dsts:
        prefix_ers = prefix_ers * mult
        if norm:
            prefix_ers = prefix_ers / ref_lens.to(row.dtype)
            zero_mask = (ref_lens == 0).unsqueeze(0)
            if zero_mask.any():
                if warn:
                    warnings.warn(
                        "ref contains empty transcripts. Error rates will be "
                        "0 for prefixes of length 0, 1 otherwise. To suppress "
                        "this warning, set warn=False"
                    )
                prefix_ers = torch.where(
                    zero_mask,
                    (
                        torch.arange(prefix_ers.size(0), device=device)
                        .gt(0)
                        .to(row.dtype)
                        .unsqueeze(1)
                        .expand_as(prefix_ers)
                    ),
                    prefix_ers,
                )
        # prefixes longer than the hypothesis get the padding value
        prefix_ers = prefix_ers.masked_fill(
            (
                torch.arange(prefix_ers.size(0), device=device)
                .unsqueeze(1)
                .ge(hyp_lens + (0 if exclude_last else 1))
            ),
            padding,
        )
        if batch_first:
            prefix_ers = prefix_ers.t()
        return prefix_ers
    # default: value of the full reference against the full hypothesis, (N,)
    if return_mistakes:
        er = mistakes.gather(0, ref_lens.unsqueeze(0)).squeeze(0)
    else:
        er = row.gather(0, ref_lens.unsqueeze(0)).squeeze(0)
    er = er * mult
    if norm:
        er = er / ref_lens.to(er.dtype)
        zero_mask = ref_lens.eq(0)
        if zero_mask.any():
            if warn:
                warnings.warn(
                    "ref contains empty transcripts. Error rates for entries "
                    "will be 1 if any insertion and 0 otherwise. To suppress "
                    "this warning, set warn=False"
                )
            er = torch.where(zero_mask, hyp_lens.gt(0).to(er.dtype), er)
    return er
[docs]
@functional_wrapper("ErrorRate")
def error_rate(
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int] = None,
    include_eos: bool = False,
    norm: bool = True,
    batch_first: bool = False,
    ins_cost: float = config.DEFT_INS_COST,
    del_cost: float = config.DEFT_DEL_COST,
    sub_cost: float = config.DEFT_SUB_COST,
    warn: bool = True,
) -> torch.Tensor:
    # Functional counterpart of ErrorRate: counts mistakes (rather than summing
    # weighted costs) between each reference and hypothesis, normalized by the
    # reference length when `norm` is set.
    rates = _string_matching(
        ref=ref,
        hyp=hyp,
        eos=eos,
        include_eos=include_eos,
        batch_first=batch_first,
        ins_cost=ins_cost,
        del_cost=del_cost,
        sub_cost=sub_cost,
        warn=warn,
        norm=norm,
        return_mistakes=True,
    )
    return rates
[docs]
@functional_wrapper("EditDistance")
def edit_distance(
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int] = None,
    include_eos: bool = False,
    norm: bool = False,
    batch_first: bool = False,
    ins_cost: float = config.DEFT_INS_COST,
    del_cost: float = config.DEFT_DEL_COST,
    sub_cost: float = config.DEFT_SUB_COST,
    warn: bool = True,
) -> torch.Tensor:
    # Functional counterpart of EditDistance: the total weighted cost of the
    # cheapest alignment between each reference and hypothesis, optionally
    # normalized by the reference length.
    distances = _string_matching(
        ref=ref,
        hyp=hyp,
        eos=eos,
        include_eos=include_eos,
        batch_first=batch_first,
        ins_cost=ins_cost,
        del_cost=del_cost,
        sub_cost=sub_cost,
        warn=warn,
        norm=norm,
    )
    return distances
[docs]
@script
@functional_wrapper("OptimalCompletion")
def optimal_completion(
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int] = None,
    include_eos: bool = True,
    batch_first: bool = False,
    ins_cost: float = config.DEFT_INS_COST,
    del_cost: float = config.DEFT_DEL_COST,
    sub_cost: float = config.DEFT_SUB_COST,
    padding: int = config.INDEX_PAD_VALUE,
    exclude_last: bool = False,
    warn: bool = True,
) -> torch.Tensor:
    # Functional counterpart of OptimalCompletion. Converts the per-prefix mask of
    # optimal reference positions computed by _string_matching into a tensor of the
    # unique optimal next tokens, right-padded with `padding`.
    # mask has shape (P, R, N), where P is H + 1 (or H if exclude_last); it is not
    # transposed by _string_matching even when batch_first is set.
    mask = _string_matching(
        ref,
        hyp,
        eos,
        include_eos,
        batch_first,
        ins_cost,
        del_cost,
        sub_cost,
        warn,
        return_mask=True,
        exclude_last=exclude_last,
    )
    # bring ref to (N, R) regardless of batch_first
    if not batch_first:
        ref = ref.t()
    # note: H here is the number of prefixes (P above), not the hypothesis length
    H, R, N = mask.shape
    device = ref.device
    # if a token is set to true once, set all duplicates in the transcription to true
    # (H, N, 1, R) & (N, R, R) -> (H, N, R, R), reduced over the last dim
    mask = (
        mask.transpose(1, 2).unsqueeze(2) & (ref.unsqueeze(1) == ref.unsqueeze(2))
    ).any(
        3
    )  # (H, N, R)
    # sort the transcriptions and the mask
    ref, src = ref.sort(1)
    mask = mask.gather(2, src.expand_as(mask))
    # set the mask to false for every duplicate token
    # (after sorting, only the last element of each run of equal tokens is kept)
    mask_ = mask[..., :-1] & (ref[:, :-1] != ref[:, 1:]).expand(H, -1, -1)
    mask = torch.cat([mask_, mask[..., -1:]], 2)
    # scatter the tokens into the target buffer
    targets_flat = ref.expand_as(mask).masked_select(mask)
    counts = mask.sum(2)  # (H, N)
    C = int(counts.max().item())  # width of the padded output
    # NOTE(review): masked_scatter_ into this long buffer presumably relies on ref
    # being a long tensor, as documented for the public API -- confirm
    targets = torch.full((H, N, C), padding, dtype=torch.long, device=device)
    # the first counts[h, n] slots of each (h, n) row receive that row's tokens in
    # ascending order; the remaining slots keep `padding`
    target_mask = counts.unsqueeze(-1) > torch.arange(C, device=device)
    targets.masked_scatter_(target_mask, targets_flat)
    if batch_first:
        targets = targets.transpose(0, 1)
    return targets
[docs]
@functional_wrapper("PrefixErrorRates")
def prefix_error_rates(
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int] = None,
    include_eos: bool = True,
    norm: bool = True,
    batch_first: bool = False,
    ins_cost: float = config.DEFT_INS_COST,
    del_cost: float = config.DEFT_DEL_COST,
    sub_cost: float = config.DEFT_SUB_COST,
    padding: int = config.INDEX_PAD_VALUE,
    exclude_last: bool = False,
    warn: bool = True,
) -> torch.Tensor:
    # Functional counterpart of PrefixErrorRates: mistake counts between each
    # reference and every prefix of its hypothesis, starting from the empty prefix.
    per_prefix = _string_matching(
        ref=ref,
        hyp=hyp,
        eos=eos,
        include_eos=include_eos,
        batch_first=batch_first,
        ins_cost=ins_cost,
        del_cost=del_cost,
        sub_cost=sub_cost,
        warn=warn,
        norm=norm,
        return_prf_dsts=True,
        exclude_last=exclude_last,
        padding=padding,
        return_mistakes=True,
    )
    return per_prefix
[docs]
@functional_wrapper("PrefixEditDistances")
def prefix_edit_distances(
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int] = None,
    include_eos: bool = True,
    norm: bool = False,
    batch_first: bool = False,
    ins_cost: float = config.DEFT_INS_COST,
    del_cost: float = config.DEFT_DEL_COST,
    sub_cost: float = config.DEFT_SUB_COST,
    padding: int = config.INDEX_PAD_VALUE,
    exclude_last: bool = False,
    warn: bool = True,
) -> torch.Tensor:
    # Functional counterpart of PrefixEditDistances: weighted edit distances
    # between each reference and every prefix of its hypothesis.
    per_prefix = _string_matching(
        ref=ref,
        hyp=hyp,
        eos=eos,
        include_eos=include_eos,
        batch_first=batch_first,
        ins_cost=ins_cost,
        del_cost=del_cost,
        sub_cost=sub_cost,
        warn=warn,
        norm=norm,
        return_prf_dsts=True,
        exclude_last=exclude_last,
        padding=padding,
        return_mistakes=False,
    )
    return per_prefix
_SM_PARAM_DICT = {
"ref": """\
ref : torch.Tensor
A long tensor of shape ``(R, N)`` where ``R`` is the reference sequence
dimension and ``N`` is the batch dimension. Stores the reference (gold-standard)
sequences.
""",
"hyp": """\
hyp : torch.Tensor
A long tensor of shape ``(H, N)`` where ``H`` is the hypothesis sequence
dimension. Stores the hypothesis (machine-generated) sequences.
""",
"eos": """\
eos
A special token in `ref` and `hyp` whose first occurrence in each batch
indicates the end of a transcript. This allows for variable-length transcripts
in the batch.
""",
"include_eos": """\
include_eos
Whether to include the first instance of `eos` found in both `ref` and `hyp` as
valid tokens to be computed as part of the rate. This is useful when gauging
if a model is learning to emit the `eos` properly, but is not usually included
in an evaluation. Only the first `eos` per transcript is included.
""",
"norm": """\
norm
If :obj:`True`, will normalize the distance by the number of tokens in the
reference sequence (making the returned value a divergence).
""",
"batch_first": """\
batch_first
If :obj:`True`, the first two dimensions of `ref`, `hyp`, and the return value
are transposed from those above.
""",
"ins_cost": """\
ins_cost
The cost of an adding an extra token to a sequence in `ref`.
""",
"del_cost": """\
del_cost
The cost of removing a token from a sequence in `ref`.
""",
"sub_cost": """\
sub_cost
The cost of swapping a token from `ref` with one from `hyp`.
""",
"warn": """\
warn : bool, optional
Whether to display warnings on irregularities. Currently, this can happen in
three ways:
1. If :obj:`True` and `ins_cost`, `del_cost`, or `sub_cost` is not 1, a warning
about a difference in computations will be raised. See the below warning for
more info.
2. If :obj:`True` and `norm` is :obj:`True`, will warn when a reference
transcription has zero length
3. If `eos` is set and `include_eos` is :obj:`True`, will warn when a transcript
does not include an `eos` symbol.
""",
"padding": """\
padding
The value to right-pad unequal-length sequences with. Defauls to
:obj:`pydrobert.torch.config.INDEX_PAD_VALUE`.
""",
"exclude_last": """\
exclude_last
If true, will exclude the final prefix, consisting of the entire transcript,
from the return value. It will be of shape ``(H, N, U)``
""",
"reduction": """\
reduction
Specifies the reduction to be applied to the output. ``'none'``: no
reduction will be applied. ``'sum'``: the output will be summed. ``'mean'``:
the output will be averaged.
""",
"ignore_index": """\
ignore_index
Specify a target value that is ignored and does not contribute to the input
gradient. Should not be set to `eos` when `include_eos` is :obj:`True`.
""",
"weight": """\
weight
A tensor of shape ``(V,)`` specifying the rescaling weight to assign to each
class. If unset, no rescaling is performed.
""",
"sub_avg": """\
sub_avg
Whether to subtract the average error rate from each pathwise error
rate.
""",
}
class _StringMatching(torch.nn.Module, metaclass=abc.ABCMeta):
__constants__ = (
"eos",
"include_eos",
"batch_first",
"ins_cost",
"del_cost",
"sub_cost",
"warn",
)
eos: Optional[int]
include_eos: bool
batch_first: bool
ins_cost: float
del_cost: float
sub_cost: float
warn: bool
def __init__(
self, eos, include_eos, batch_first, ins_cost, del_cost, sub_cost, warn
):
eos = argcheck.is_int(eos, "eos", True)
include_eos = argcheck.is_bool(include_eos, "include_eos")
batch_first = argcheck.is_bool(batch_first, "batch_first")
ins_cost = argcheck.is_float(ins_cost, "ins_cost")
del_cost = argcheck.is_float(del_cost, "del_cost")
sub_cost = argcheck.is_float(sub_cost, "sub_cost")
warn = argcheck.is_bool(warn, "warn")
super().__init__()
self.eos, self.include_eos, self.batch_first = eos, include_eos, batch_first
self.ins_cost, self.del_cost, self.sub_cost = ins_cost, del_cost, sub_cost
self.warn = warn
def extra_repr(self) -> str:
return ", ".join(f"{x}={getattr(self, x)}" for x in self.__constants__)
@abc.abstractmethod
def forward(self, ref: torch.Tensor, hyp: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
[docs]
class EditDistance(_StringMatching):
    __constants__ = (
        "eos",
        "include_eos",
        "norm",
        "batch_first",
        "ins_cost",
        "del_cost",
        "sub_cost",
        "warn",
    )

    norm: bool

    def __init__(
        self,
        eos: Optional[int] = None,
        include_eos: bool = False,
        norm: bool = False,
        batch_first: bool = False,
        ins_cost: float = config.DEFT_INS_COST,
        del_cost: float = config.DEFT_DEL_COST,
        sub_cost: float = config.DEFT_SUB_COST,
        warn: bool = True,
    ):
        norm = argcheck.is_bool(norm, "norm")
        super().__init__(
            eos, include_eos, batch_first, ins_cost, del_cost, sub_cost, warn
        )
        self.norm = norm

    __doc__ = f"""Compute an edit distance over a batch of references and hypotheses

    An `Edit Distance <https://en.wikipedia.org/wiki/Edit_distance>`__ quantifies how
    dissimilar two token sequences are as the total cost of transforming a reference
    sequence into a hypothesis sequence. There are three operations that can be
    performed, each with an associated cost: adding an extra token to the reference,
    removing a token from the reference, or swapping a token in the reference with a
    token in the hypothesis.

    Parameters
    ----------
{"".join(_SM_PARAM_DICT[c] for c in __constants__)}
    Call Parameters
    ---------------
{"".join(_SM_PARAM_DICT[c] for c in ('ref', 'hyp'))}
    Returns
    -------
    ed : torch.Tensor
        A tensor of shape ``(N,)`` of the edit distances.

    Notes
    -----
    This module returns identical values (modulo a bug fix) to :func:`error_rate` up
    to `v0.3.0` (though the default of `norm` has changed to :obj:`False`). For more
    details on the distinction between this module and the new :class:`ErrorRate`,
    please see that module's documentation.
    """

    def forward(self, ref: torch.Tensor, hyp: torch.Tensor) -> torch.Tensor:
        return edit_distance(
            ref,
            hyp,
            self.eos,
            self.include_eos,
            self.norm,
            self.batch_first,
            self.ins_cost,
            self.del_cost,
            self.sub_cost,
            self.warn,
        )

    __call__ = proxy(forward)
[docs]
class PrefixEditDistances(_StringMatching):
    __constants__ = (
        "eos",
        "include_eos",
        "norm",
        "batch_first",
        "ins_cost",
        "del_cost",
        "sub_cost",
        "padding",
        "exclude_last",
        "warn",
    )

    norm: bool
    padding: int
    exclude_last: bool

    def __init__(
        self,
        eos: Optional[int] = None,
        include_eos: bool = True,
        norm: bool = False,
        batch_first: bool = False,
        ins_cost: float = config.DEFT_INS_COST,
        del_cost: float = config.DEFT_DEL_COST,
        sub_cost: float = config.DEFT_SUB_COST,
        padding: int = config.INDEX_PAD_VALUE,
        exclude_last: bool = False,
        warn: bool = True,
    ):
        norm = argcheck.is_bool(norm, "norm")
        padding = argcheck.is_int(padding, "padding")
        exclude_last = argcheck.is_bool(exclude_last, "exclude_last")
        super().__init__(
            eos, include_eos, batch_first, ins_cost, del_cost, sub_cost, warn
        )
        self.norm, self.padding, self.exclude_last = norm, padding, exclude_last

    __doc__ = f"""Compute the edit distance between ref and each prefix of hyp

    Parameters
    ----------
{"".join(_SM_PARAM_DICT[c] for c in __constants__)}
    Call Parameters
    ---------------
{"".join(_SM_PARAM_DICT[c] for c in ('ref', 'hyp'))}
    Returns
    -------
    prefix_eds : torch.Tensor
        A tensor of shape ``(H + 1, N)`` of the edit distances for each prefix of each
        hypothesis, starting from the empty prefix. If `exclude_last` is :obj:`True`,
        the final prefix (the entire hypothesis) is omitted and the shape is
        ``(H, N)``. Entries for prefixes longer than a hypothesis are set to
        `padding`.

    Notes
    -----
    This module returns identical values (modulo a bug fix) to
    :func:`prefix_error_rates` (and :class:`PrefixErrorRates`) up to `v0.3.0` (though
    the default of `norm` has changed to :obj:`False`). For more details on the
    distinction between this module and the new :func:`prefix_error_rates`, please
    consult the documentation of :class:`ErrorRate`.
    """

    def forward(self, ref: torch.Tensor, hyp: torch.Tensor) -> torch.Tensor:
        return prefix_edit_distances(
            ref,
            hyp,
            self.eos,
            self.include_eos,
            self.norm,
            self.batch_first,
            self.ins_cost,
            self.del_cost,
            self.sub_cost,
            self.padding,
            self.exclude_last,
            self.warn,
        )

    __call__ = proxy(forward)
[docs]
class ErrorRate(_StringMatching):
    __constants__ = (
        "eos",
        "include_eos",
        "norm",
        "batch_first",
        "ins_cost",
        "del_cost",
        "sub_cost",
        "warn",
    )

    norm: bool

    def __init__(
        self,
        eos: Optional[int] = None,
        include_eos: bool = False,
        norm: bool = True,
        batch_first: bool = False,
        ins_cost: float = config.DEFT_INS_COST,
        del_cost: float = config.DEFT_DEL_COST,
        sub_cost: float = config.DEFT_SUB_COST,
        warn: bool = True,
    ):
        norm = argcheck.is_bool(norm, "norm")
        super().__init__(
            eos, include_eos, batch_first, ins_cost, del_cost, sub_cost, warn
        )
        self.norm = norm

    __doc__ = f"""Calculate error rates over a batch of references and hypotheses

    An error rate is the total number of insertions, deletions, and substitutions
    between a reference (gold-standard) and hypothesis (generated) transcription,
    normalized by the number of elements in a reference. Consult the Wikipedia article
    on the `Levenshtein distance <https://en.wikipedia.org/wiki/Levenshtein_distance>`__
    for more information.

    Parameters
    ----------
{"".join(_SM_PARAM_DICT[c] for c in __constants__)}
    Call Parameters
    ---------------
{"".join(_SM_PARAM_DICT[c] for c in ('ref', 'hyp'))}
    Returns
    -------
    er : torch.Tensor
        A tensor of shape ``(N,)`` of the error rates.

    Warnings
    --------
    Up to and including `v0.3.0`, :func:`error_rate` computed a normalized `Edit
    distance <https://en.wikipedia.org/wiki/Edit_distance>`__ instead of an error rate.
    The former can be considered the total weighted cost of insertions, deletions, and
    substitutions (as per `ins_cost`, `del_cost`, and `sub_cost`), whereas the latter is
    the sum of the number of mistakes. The old behaviour of returning the cost is now in
    :func:`edit_distance` and :class:`EditDistance` (though `norm` is :obj:`False` by
    default). For speech recognition evaluation, this module or :func:`error_rate` is
    the one to use. However, if you are using the default costs, ``ins_cost == del_cost
    == sub_cost == 1``, there should be no numerical difference between the two.
    """

    def forward(self, ref: torch.Tensor, hyp: torch.Tensor) -> torch.Tensor:
        return error_rate(
            ref,
            hyp,
            self.eos,
            self.include_eos,
            self.norm,
            self.batch_first,
            self.ins_cost,
            self.del_cost,
            self.sub_cost,
            self.warn,
        )

    __call__ = proxy(forward)
[docs]
class PrefixErrorRates(_StringMatching):
    # a tuple, like every other string-matching module
    __constants__ = (
        "eos",
        "include_eos",
        "norm",
        "batch_first",
        "ins_cost",
        "del_cost",
        "sub_cost",
        "padding",
        "exclude_last",
        "warn",
    )

    norm: bool
    padding: int
    exclude_last: bool

    def __init__(
        self,
        eos: Optional[int] = None,
        include_eos: bool = True,
        norm: bool = True,
        batch_first: bool = False,
        ins_cost: float = config.DEFT_INS_COST,
        del_cost: float = config.DEFT_DEL_COST,
        sub_cost: float = config.DEFT_SUB_COST,
        padding: int = config.INDEX_PAD_VALUE,
        exclude_last: bool = False,
        warn: bool = True,
    ):
        norm = argcheck.is_bool(norm, "norm")
        padding = argcheck.is_int(padding, "padding")
        exclude_last = argcheck.is_bool(exclude_last, "exclude_last")
        super().__init__(
            eos, include_eos, batch_first, ins_cost, del_cost, sub_cost, warn
        )
        self.norm, self.padding, self.exclude_last = norm, padding, exclude_last

    __doc__ = f"""Compute the error rate between ref and each prefix of hyp

    Parameters
    ----------
{"".join(_SM_PARAM_DICT[c] for c in __constants__)}
    Call Parameters
    ---------------
{"".join(_SM_PARAM_DICT[c] for c in ('ref', 'hyp'))}
    Returns
    -------
    prefix_ers : torch.Tensor
        A tensor of shape ``(H + 1, N)`` containing the error rates for each prefix of
        each hypothesis, starting from the empty prefix. If `exclude_last` is
        :obj:`True`, the final prefix (the entire hypothesis) is omitted and the shape
        is ``(H, N)``. Entries for prefixes longer than a hypothesis are set to
        `padding`.

    Warnings
    --------
    The values returned by :func:`prefix_error_rates` (and thus this module) changed
    after `v0.3.0`. The old behaviour can be found in :class:`PrefixEditDistances`
    (though with `norm` defaulting to :obj:`False`). Consult the warning in
    :class:`ErrorRate` for more info.
    """

    def forward(self, ref: torch.Tensor, hyp: torch.Tensor) -> torch.Tensor:
        return prefix_error_rates(
            ref,
            hyp,
            self.eos,
            self.include_eos,
            self.norm,
            self.batch_first,
            self.ins_cost,
            self.del_cost,
            self.sub_cost,
            self.padding,
            self.exclude_last,
            self.warn,
        )

    __call__ = proxy(forward)
[docs]
class OptimalCompletion(_StringMatching):
    __constants__ = (
        "eos",
        "include_eos",
        "batch_first",
        "ins_cost",
        "del_cost",
        "sub_cost",
        "padding",
        "exclude_last",
        "warn",
    )

    padding: int
    exclude_last: bool

    def __init__(
        self,
        eos: Optional[int] = None,
        include_eos: bool = True,
        batch_first: bool = False,
        ins_cost: float = config.DEFT_INS_COST,
        del_cost: float = config.DEFT_DEL_COST,
        sub_cost: float = config.DEFT_SUB_COST,
        padding: int = config.INDEX_PAD_VALUE,
        exclude_last: bool = False,
        warn: bool = True,
    ):
        padding = argcheck.is_int(padding, "padding")
        exclude_last = argcheck.is_bool(exclude_last, "exclude_last")
        super().__init__(
            eos, include_eos, batch_first, ins_cost, del_cost, sub_cost, warn
        )
        self.padding, self.exclude_last = padding, exclude_last

    __doc__ = f"""Return the optimal next tokens of each minimum edit distance prefix

    Parameters
    ----------
{"".join(_SM_PARAM_DICT[c] for c in __constants__)}
    Call Parameters
    ---------------
{"".join(_SM_PARAM_DICT[c] for c in ('ref', 'hyp'))}
    Returns
    -------
    optimals : torch.Tensor
        A long tensor of shape ``(H + 1, N, U)`` where ``U <= R`` of the unique tokens
        that could be added to each prefix of the hypothesis such that some remaining
        suffix concatenated to the prefix would result in a minimal edit distance.
        Prefixes with fewer than ``U`` such tokens are right-padded with `padding`.
        If `exclude_last` is :obj:`True`, the final prefix is omitted and the shape is
        ``(H, N, U)``. See below for an example.

    Examples
    --------
    Consider the reference text "foot" and the hypothesis text "bot". The below shows
    the matrix used to calculate edit distances between them::

        \\ _ f o o t
        _ 0 1 2 3 4
        b 1 1 2 3 4
        o 2 2 1 2 3
        t 3 3 2 2 2

    If ``prefix_len == 0``, then the prefix is "", and "f" (from the suffix "foot") is
    the only subsequent token that would not increase the edit distance from that of the
    prefix (0). If ``prefix_len == 1``, then the prefix is "b". To arrive at the minimum
    edit distance for "b", one either treats "b" as an insertion or a substitution for
    "f", yielding suffixes "foot" and "oot". Thus, the subsequent token could be "f" or
    "o". For the prefix "bo", the minimum edit distance is achieved by first
    substituting "f" for "b", then substituting "o" for "o", resulting in the suffix
    "ot" and the next optimal character "o". Finally, for ``prefix_len == 3`` and prefix
    "bot", there are many operations that can produce the minimum edit distance of 2,
    resulting in one of the suffixes "ot", "t", and "". The latter suffix requires no
    more tokens and so any operation would increase the edit distance. Thus the optimal
    next tokens could be "o" or "t".

    Plugging "foot" and "bot" into this function, we get the prefixes:

    >>> ref_text, hyp_text = "foot", "bot"
    >>> ref = torch.tensor([ord(c) for c in ref_text]).unsqueeze(1)
    >>> hyp = torch.tensor([ord(c) for c in hyp_text]).unsqueeze(1)
    >>> optimal = optimal_completion(ref, hyp).squeeze(1)
    >>> for prefix_len, o_for_pr in enumerate(optimal):
    ...     o_for_pr = o_for_pr.masked_select(o_for_pr.ge(0)).tolist()
    ...     print('prefix={{}}: {{}}'.format(
    ...         hyp_text[:prefix_len], ','.join([chr(i) for i in o_for_pr])))
    prefix=: f
    prefix=b: f,o
    prefix=bo: o
    prefix=bot: o,t

    See Also
    --------
    pydrobert.torch.layers.HardOptimalCompletionDistillationLoss
        A loss function that uses these optimal completions to train a model
    """

    def forward(self, ref: torch.Tensor, hyp: torch.Tensor) -> torch.Tensor:
        return optimal_completion(
            ref,
            hyp,
            self.eos,
            self.include_eos,
            self.batch_first,
            self.ins_cost,
            self.del_cost,
            self.sub_cost,
            self.padding,
            self.exclude_last,
            self.warn,
        )

    __call__ = proxy(forward)
# Typing-only signature. Static checkers see `reduction` as the `Reduction` literal,
# whereas the scripted implementation that follows declares it as a plain `str`
# (presumably because TorchScript cannot compile `Literal` -- TODO confirm).
@overload
def hard_optimal_completion_distillation_loss(
    logits: torch.Tensor,
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int] = None,
    include_eos: bool = True,
    batch_first: bool = False,
    ins_cost: float = config.DEFT_INS_COST,
    del_cost: float = config.DEFT_DEL_COST,
    sub_cost: float = config.DEFT_SUB_COST,
    weight: Optional[torch.Tensor] = None,
    reduction: Reduction = "mean",
    ignore_index: int = -2,
    warn: bool = True,
) -> torch.Tensor:
    ...
[docs]
@script
@functional_wrapper("HardOptimalCompletionDistillationLoss")
def hard_optimal_completion_distillation_loss(
    logits: torch.Tensor,
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int] = None,
    include_eos: bool = True,
    batch_first: bool = False,
    ins_cost: float = config.DEFT_INS_COST,
    del_cost: float = config.DEFT_DEL_COST,
    sub_cost: float = config.DEFT_SUB_COST,
    weight: Optional[torch.Tensor] = None,
    reduction: str = "mean",
    ignore_index: int = -2,
    warn: bool = True,
) -> torch.Tensor:
    # Functional form of HardOptimalCompletionDistillationLoss.
    #
    # `logits` is 3D and its first two dims must equal `hyp`'s shape; the sequence
    # dim is 0, or 1 if `batch_first`. For each prefix of `hyp`, the set of optimal
    # next tokens w.r.t. `ref` is computed; each step's loss is the cross-entropy of
    # its logits against every optimal token, averaged over the number of optimal
    # tokens. Returns a tensor shaped like `hyp` if reduction == "none", otherwise
    # a scalar. Raises RuntimeError on malformed shapes, an eos unusable as a
    # target, or an unknown reduction.
    if logits.dim() != 3:
        raise RuntimeError("logits must be 3 dimensional")
    if logits.shape[:-1] != hyp.shape:
        raise RuntimeError("first two dims of logits must match hyp shape")
    if include_eos:
        # eos may appear as a cross-entropy target, so it must index a class...
        if eos is not None and ((eos < 0) or (eos >= logits.size(-1))):
            raise RuntimeError(f"If include_eos=True, eos ({eos}) must be a class idx")
        # ...and must not collide with ignore_index, which cross_entropy skips.
        if eos is not None and eos == ignore_index:
            raise RuntimeError(
                f"If include_eos=True, eos cannot equal ignore_index ({eos})"
            )
    # One row of optimal next tokens per prefix (excluding the full hypothesis),
    # along the last dim; unused slots are filled with ignore_index.
    optimals = optimal_completion(
        ref,
        hyp,
        eos=eos,
        include_eos=include_eos,
        batch_first=batch_first,
        ins_cost=ins_cost,
        del_cost=del_cost,
        sub_cost=sub_cost,
        padding=ignore_index,
        exclude_last=True,
        warn=warn,
    )
    max_unique_next = optimals.size(-1)
    # Replicate each step's logits once per candidate token so that a single
    # cross_entropy call scores every (step, candidate) pair.
    logits = logits.unsqueeze(2).expand(-1, -1, max_unique_next, -1)
    logits = logits.contiguous()
    loss = torch.nn.functional.cross_entropy(
        logits.flatten(0, -2),
        optimals.flatten(),
        weight=weight,
        ignore_index=ignore_index,
        reduction="none",
    ).view_as(optimals)
    padding_mask = optimals == ignore_index
    # Sum over the real candidates, then divide by their count; the clamp makes
    # steps without any candidate contribute 0 instead of NaN.
    loss = loss.masked_fill(padding_mask, 0.0).sum(2)
    loss = loss / (~padding_mask).sum(2).clamp_min(1)
    if reduction == "mean":
        seq_dim = 1 if batch_first else 0
        # Average over each sequence's steps that had candidates, then the batch.
        loss = (
            loss.sum(seq_dim) / (~padding_mask).any(2).sum(seq_dim).clamp_min(1)
        ).mean()
    elif reduction == "sum":
        loss = loss.sum()
    elif reduction != "none":
        raise RuntimeError(f"'{reduction}' is not a valid value for reduction")
    return loss
[docs]
class HardOptimalCompletionDistillationLoss(torch.nn.Module):
    # Construction-time settings; each name also keys a parameter description in
    # _SM_PARAM_DICT, which is spliced into the class docstring below.
    __constants__ = (
        "eos",
        "include_eos",
        "batch_first",
        "ins_cost",
        "del_cost",
        "sub_cost",
        "reduction",
        "ignore_index",
    )

    __doc__ = f"""A categorical loss function over optimal next tokens

    Optimal Completion Distillation (OCD) [sabour2018]_ tries to minimize the train/test
    discrepancy in transcriptions by allowing seq2seq models to generate whatever
    sequences they want, then assigns a per-step loss according to whatever next token
    would set the model on a path that minimizes the edit distance in the future.

    In its "hard" version, the version used in the paper, the OCD loss function is
    simply a categorical cross-entropy loss of each hypothesis token's distribution
    versus those optimal next tokens, averaged over the number of optimal next tokens:

    .. math::

        loss(logits_h) = -\\frac{{1}}{{|S_h|}} \\sum_{{s_h \\in S_h}} \\log Pr(s_h|logits_h)

    Where :math:`s_h \\in S_h` are tokens from the set of optimal next tokens given
    :math:`hyp_{{\\leq h}}` and `ref`. The loss is decoupled from an exact prefix of
    `ref`, meaning that `hyp` can be longer or shorter than `ref`. If `weight` is
    specified, each term of the sum is additionally scaled by the weight of its class
    :math:`s_h`, as in :func:`torch.nn.functional.cross_entropy`.

    Parameters
    ----------
    {"".join(_SM_PARAM_DICT[c] for c in __constants__)}
    weight : torch.Tensor or None, optional
        Per-class rescaling weights (one per vocabulary entry), forwarded to
        :func:`torch.nn.functional.cross_entropy`.

    Call Parameters
    ---------------
    logits : torch.Tensor
        A tensor of shape ``(H, N, V)`` where ``H`` is the hypothesis sequence
        dimension, ``N`` is the batch dimension, and ``V`` is the vocabulary size.
        Stores the unnormalized log-probabilities over the next token of each prefix
        (except the last) within `hyp`.
    ref : torch.Tensor
        A long tensor of shape ``(R, N)`` where ``R`` is the reference sequence
        dimension. Stores the reference (gold-standard) sequences.
    hyp : torch.Tensor
        A long tensor of shape ``(H, N)``. Stores the hypothesis (machine-generated)
        sequences.
    {_SM_PARAM_DICT["warn"]}

    Returns
    -------
    loss : torch.Tensor
        The loss. If `reduction` is ``'sum'`` or ``'mean'``, it is a scalar value.
        Otherwise of shape ``(H, N)``.

    See Also
    --------
    pydrobert.torch.util.optimal_completion
        Used to determine the optimal next token set :math:`S`
    pydrobert.torch.util.random_walk_advance
        For producing a random `hyp` based on `logits` if the underlying
        model producing `logits` is auto-regressive. Also provides an example
        of sampling non-auto-regressive models
    """

    eos: Optional[int]
    include_eos: bool
    batch_first: bool
    ins_cost: float
    del_cost: float
    sub_cost: float
    reduction: str
    ignore_index: int

    def __init__(
        self,
        eos: Optional[int] = None,
        include_eos: bool = True,
        batch_first: bool = False,
        ins_cost: float = config.DEFT_INS_COST,
        del_cost: float = config.DEFT_DEL_COST,
        sub_cost: float = config.DEFT_SUB_COST,
        weight: Optional[torch.Tensor] = None,
        reduction: Reduction = "mean",
        ignore_index: int = config.INDEX_PAD_VALUE,
    ):
        # Validate/normalize every argument before nn.Module initialization.
        eos = argcheck.is_int(eos, "eos", True)
        include_eos = argcheck.is_bool(include_eos, "include_eos")
        batch_first = argcheck.is_bool(batch_first, "batch_first")
        ins_cost = argcheck.is_float(ins_cost, "ins_cost")
        del_cost = argcheck.is_float(del_cost, "del_cost")
        sub_cost = argcheck.is_float(sub_cost, "sub_cost")
        weight = argcheck.is_tensor(weight, "weight", True)
        reduction = argcheck.is_in(reduction, get_args(Reduction), "reduction")
        ignore_index = argcheck.is_int(ignore_index, "ignore_index")
        super().__init__()
        self.eos, self.include_eos, self.batch_first = eos, include_eos, batch_first
        self.ins_cost, self.del_cost, self.sub_cost = ins_cost, del_cost, sub_cost
        self.reduction, self.ignore_index = reduction, ignore_index
        # A buffer (possibly None) rather than a plain attribute, so a weight
        # tensor follows the module through .to(...) and into its state dict.
        self.register_buffer("weight", weight)

    def forward(
        self,
        logits: torch.Tensor,
        ref: torch.Tensor,
        hyp: torch.Tensor,
        warn: bool = True,
    ) -> torch.Tensor:
        """Compute the hard OCD loss using this module's stored configuration"""
        return hard_optimal_completion_distillation_loss(
            logits,
            ref,
            hyp,
            self.eos,
            self.include_eos,
            self.batch_first,
            self.ins_cost,
            self.del_cost,
            self.sub_cost,
            self.weight,
            self.reduction,
            self.ignore_index,
            warn,
        )

    __call__ = proxy(forward)
# Typing-only signature. Static checkers see `reduction` as the `Reduction` literal,
# whereas the scripted implementation that follows declares it as a plain `str`
# (presumably because TorchScript cannot compile `Literal` -- TODO confirm).
@overload
def minimum_error_rate_loss(
    log_probs: torch.Tensor,
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int] = None,
    include_eos: bool = True,
    sub_avg: bool = True,
    batch_first: bool = False,
    norm: bool = True,
    ins_cost: float = config.DEFT_INS_COST,
    del_cost: float = config.DEFT_DEL_COST,
    sub_cost: float = config.DEFT_SUB_COST,
    reduction: Reduction = "mean",
    warn: bool = True,
) -> torch.Tensor:
    ...
[docs]
@script
@functional_wrapper("MinimumErrorRateLoss")
def minimum_error_rate_loss(
    log_probs: torch.Tensor,
    ref: torch.Tensor,
    hyp: torch.Tensor,
    eos: Optional[int] = None,
    include_eos: bool = True,
    sub_avg: bool = True,
    batch_first: bool = False,
    norm: bool = True,
    ins_cost: float = config.DEFT_INS_COST,
    del_cost: float = config.DEFT_DEL_COST,
    sub_cost: float = config.DEFT_SUB_COST,
    reduction: str = "mean",
    warn: bool = True,
) -> torch.Tensor:
    # Functional form of MinimumErrorRateLoss.
    #
    # log_probs is (N, M); hyp is (H, N, M), or (N, M, H) if batch_first; ref is
    # (R, N) or (R, N, M), or (N, R) / (N, M, R) if batch_first. A 2D ref is shared
    # by all M samples of its batch element. Each sample's error rate (optionally
    # minus the batch element's mean error rate, when sub_avg) is weighted by the
    # softmax of log_probs over the M samples. Returns (N, M) if
    # reduction == "none", otherwise a scalar. Raises RuntimeError on malformed
    # shapes, fewer than two samples, or an unknown reduction.
    if log_probs.dim() != 2:
        raise RuntimeError("log_probs must be 2 dimensional")
    if hyp.dim() != 3:
        raise RuntimeError("hyp must be 3 dimensional")
    if ref.dim() not in (2, 3):
        raise RuntimeError("ref must be 2 or 3 dimensional")
    if batch_first:
        batch_size, samples, max_hyp_steps = hyp.shape
        if ref.dim() == 2:
            # share each batch element's reference among all of its samples
            ref = ref.unsqueeze(1).repeat(1, samples, 1)
        if (ref.shape[:2] != (batch_size, samples)) or (
            ref.shape[:2] != log_probs.shape
        ):
            raise RuntimeError(
                "ref and hyp batch_size and sample dimensions must match"
            )
        max_ref_steps = ref.size(-1)
        # Merge batch and sample dims so error_rate sees N * M sequence pairs.
        # reshape (not view) so non-contiguous inputs, e.g. transposed, work too.
        ref = ref.reshape(-1, max_ref_steps)
        hyp = hyp.reshape(-1, max_hyp_steps)
    else:
        max_hyp_steps, batch_size, samples = hyp.shape
        if ref.dim() == 2:
            # share each batch element's reference among all of its samples
            ref = ref.unsqueeze(-1).repeat(1, 1, samples)
        if (ref.shape[1:] != (batch_size, samples)) or (
            ref.shape[1:] != log_probs.shape
        ):
            raise RuntimeError(
                "ref and hyp batch_size and sample dimensions must match"
            )
        max_ref_steps = ref.size(0)
        # Merge batch and sample dims so error_rate sees N * M sequence pairs.
        # reshape (not view) so non-contiguous inputs, e.g. transposed, work too.
        ref = ref.reshape(max_ref_steps, -1)
        hyp = hyp.reshape(max_hyp_steps, -1)
    if samples < 2:
        raise RuntimeError(f"Batch must have at least two samples, got {samples}")
    er = error_rate(
        ref,
        hyp,
        eos=eos,
        include_eos=include_eos,
        norm=norm,
        batch_first=batch_first,
        ins_cost=ins_cost,
        del_cost=del_cost,
        sub_cost=sub_cost,
        warn=warn,
    ).reshape(batch_size, samples)
    if sub_avg:
        # baseline: subtract each batch element's mean error rate over its samples
        er = er - er.mean(1, keepdim=True)
    # renormalize the sequence probabilities over the M samples of each element
    loss = er * torch.nn.functional.softmax(log_probs, 1)
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    elif reduction != "none":
        raise RuntimeError(f"'{reduction}' is not a valid value for reduction")
    return loss
[docs]
class MinimumErrorRateLoss(torch.nn.Module):
    # Construction-time settings; each name also keys a parameter description in
    # _SM_PARAM_DICT, which is spliced into the class docstring below.
    __constants__ = (
        "eos",
        "include_eos",
        "sub_avg",
        "batch_first",
        "norm",
        "ins_cost",
        "del_cost",
        "sub_cost",
        "reduction",
    )

    __doc__ = f"""Error rate expectation normalized over some number of transcripts

    Proposed in [prabhavalkar2018]_ though similar ideas had been explored previously.
    Given a subset of all possible token sequences and their associated probability mass
    over that population, this loss calculates the probability mass normalized over the
    subset, then calculates the expected error rate over that normalized distribution.
    That is, given some sequences :math:`s \\in S \\subseteq P`, the loss for a given
    reference transcription :math:`s^*` is

    .. math::

        \\mathcal{{L}}(s, s^*) = \\frac{{Pr(s) ER(s, s^*)}}{{\\sum_{{s'}} Pr(s')}}

    This is an exact expectation over :math:`S` but not over :math:`P`. The larger the
    mass covered by :math:`S`, the closer the expectation is to the population -
    especially so for an n-best list (though it would be biased).

    Parameters
    ----------
    {"".join(_SM_PARAM_DICT[c] for c in __constants__)}

    Call Parameters
    ---------------
    log_probs : torch.Tensor
        A tensor of shape ``(N, M)`` where ``N`` is the batch size and ``M`` is the
        number of samples providing the log joint probabilities of every sample path.
    ref : torch.Tensor
        A tensor of either of shape ``(R, N)`` or ``(R, N, M)`` where ``R`` is the
        maximum reference length containing the reference (gold-standard)
        transcriptions. Whether `ref` is 2D or 3D changes how the loss is calculated.
    hyp : torch.Tensor
        A long tensor of shape ``(H, N, M)`` where ``H`` is the maximum hypothesis size
        containing the hypothesis (machine-generated) transcriptions. If `batch_first`
        is :obj:`True`, the sequence dimension of `ref` and `hyp` comes last instead,
        i.e. ``(N, R)`` or ``(N, M, R)`` and ``(N, M, H)``.
    {_SM_PARAM_DICT["warn"]}

    Returns
    -------
    loss : torch.Tensor
        The loss. If `reduction` is ``'sum'`` or ``'mean'``, it is a scalar value.
        Otherwise of shape ``(N,M)``. If `ref` is 2D, the loss for sample ``m`` of
        batch element ``n`` is

        .. math::

            loss_{{n, m}} = SoftMax(log\\_probs)[ER(hyp_{{n, m}}, ref_n) - \\mu_n]

        where :math:`\\mu_n` is the average error rate for the ``M`` hypotheses
        in batch element ``n``. :math:`\\mu_n` is dropped if `sub_avg` is
        :obj:`False`. If `ref` is 3D, each hypothesis is compared against a unique
        reference:

        .. math::

            loss_{{n, m}} = SoftMax(log\\_probs)[ER(hyp_{{n, m}}, ref_{{n,m}}) - \\mu_n]

    Notes
    -----
    A previous version of this module incorporated a Maximum Likelihood Estimate (MLE)
    into the loss as in [prabhavalkar2018]_, which required `logits` instead of
    `log_probs`. This was overly complicated, given the user can easily incorporate the
    additional loss term herself by using :class:`torch.nn.CrossEntropyLoss`. Take a
    look at the example below for how to recreate this.

    Examples
    --------
    Assume here that `logits` is the output of some neural network, and that `hyp` has
    somehow been produced from that (e.g. a beam search or random walk). We combine this
    loss function with a cross-entropy/MLE term to sort-of recreate [prabhavalkar2018]_.

    >>> from pydrobert.torch.util import sequence_log_probs
    >>> steps, batch_size, num_classes, eos, padding = 30, 20, 10, 0, -1
    >>> samples, lmb = 10, .01
    >>> logits = torch.randn(
    ...     steps, batch_size, samples, num_classes, requires_grad=True)
    >>> hyp = torch.randint(num_classes, (steps, batch_size, samples))
    >>> ref_lens = torch.randint(1, steps + 1, (batch_size,))
    >>> ref_lens[0] = steps
    >>> ref = torch.nn.utils.rnn.pad_sequence(
    ...     [torch.randint(1, num_classes, (x,)) for x in ref_lens],
    ...     padding_value=padding,
    ... )
    >>> ref[ref_lens - 1, range(batch_size)] = eos
    >>> ref = ref.unsqueeze(-1).repeat(1, 1, samples)
    >>> mer = MinimumErrorRateLoss(eos=eos)
    >>> mle = torch.nn.CrossEntropyLoss(ignore_index=padding)
    >>> log_probs = sequence_log_probs(logits, hyp, eos=eos)
    >>> l = mer(log_probs, ref, hyp)
    >>> l = l + lmb * mle(logits.view(-1, num_classes), ref.flatten())
    >>> l.backward()

    See Also
    --------
    pydrobert.torch.util.beam_search_advance
        For getting an n-best list into `hyp` and some `log_probs`.
    pydrobert.torch.util.random_walk_advance
        For getting a random sample into `hyp`
    pydrobert.torch.util.sequence_log_probs
        For converting token log probs (or logits) to sequence log probs
    """

    eos: Optional[int]
    include_eos: bool
    sub_avg: bool
    batch_first: bool
    norm: bool
    ins_cost: float
    del_cost: float
    sub_cost: float
    reduction: str

    def __init__(
        self,
        eos: Optional[int] = None,
        include_eos: bool = True,
        sub_avg: bool = True,
        batch_first: bool = False,
        norm: bool = True,
        ins_cost: float = config.DEFT_INS_COST,
        del_cost: float = config.DEFT_DEL_COST,
        sub_cost: float = config.DEFT_SUB_COST,
        reduction: Reduction = "mean",
    ):
        # Validate/normalize every argument before nn.Module initialization.
        eos = argcheck.is_int(eos, "eos", True)
        include_eos = argcheck.is_bool(include_eos, "include_eos")
        sub_avg = argcheck.is_bool(sub_avg, "sub_avg")
        batch_first = argcheck.is_bool(batch_first, "batch_first")
        norm = argcheck.is_bool(norm, "norm")
        ins_cost = argcheck.is_float(ins_cost, "ins_cost")
        del_cost = argcheck.is_float(del_cost, "del_cost")
        sub_cost = argcheck.is_float(sub_cost, "sub_cost")
        reduction = argcheck.is_in(reduction, get_args(Reduction), "reduction")
        super().__init__()
        self.eos, self.include_eos, self.sub_avg = eos, include_eos, sub_avg
        self.batch_first, self.norm, self.reduction = batch_first, norm, reduction
        self.ins_cost, self.del_cost, self.sub_cost = ins_cost, del_cost, sub_cost

    def forward(
        self,
        log_probs: torch.Tensor,
        ref: torch.Tensor,
        hyp: torch.Tensor,
        warn: bool = True,
    ) -> torch.Tensor:
        """Compute the MER loss using this module's stored configuration"""
        return minimum_error_rate_loss(
            log_probs,
            ref,
            hyp,
            self.eos,
            self.include_eos,
            self.sub_avg,
            self.batch_first,
            self.norm,
            self.ins_cost,
            self.del_cost,
            self.sub_cost,
            self.reduction,
            warn,
        )

    __call__ = proxy(forward)