modules

Custom PyTorch modules

Notes

To document torch.nn.Module subclasses, we add a special heading called “Call Parameters” to the docstring which, along with “Returns”, specifies the signature of the module’s __call__() method. The heading “Parameters” refers to the values the module is initialized with. The general usage pattern is:

>>> module = Module(*params)
>>> returns = module(*call_params)

Attention

class pydrobert.torch.modules.GlobalSoftAttention(query_size, key_size, dim=0)

Parent class for soft attention mechanisms on an entire input sequence

Global soft attention mechanisms [bahdanau2015] are a way of getting rid of one variable-length sequence dimension T in an input key using a weighted sum of a tensor value that is informed by some other tensor, query. The weights are dictated by the function score(). This is usually done in the context of encoder-decoder architectures, as in the example below.

Assume query is a tensor of shape (batch_size, query_size) representing a single hidden state of a decoder RNN. Assume key is a tensor of shape (T, batch_size, key_size) representing the encoder output, dim == 0 to specify that the variable-length dimension of key is the zero-th dimension, and value == key. The output out will be a tensor of shape (batch_size, key_size). Letting \(t\) index the dim-th dimension:

\[out = \sum_t a_t value_t\]

a is the attention vector. In our example, a will be of shape (T, batch_size). a is the result of a softmax over the dim-th dimension of another tensor e of shape (T, batch_size) with an optional mask

\[a = softmax(e * mask - (1 - mask) \infty, dim)\]

mask (if specified) is of shape (T, batch_size) and will set a to zero wherever the mask is zero. mask can be used to indicate padded values when key consists of variable-length sequences.

e is the result of a score function over key and query

\[e = score(query, key)\]

score() is implemented by subclasses of GlobalSoftAttention.
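The following is a minimal pure-PyTorch sketch of the computation above for the running example (dim == 0, value == key). It is not the library's implementation: the dot-product score is a stand-in chosen only for illustration (it assumes query_size == key_size), and the helper name global_soft_attention is hypothetical. Masked positions receive a score of -inf, so the softmax assigns them zero weight.

```python
import torch


def global_soft_attention(query, key, value, mask=None, dim=0):
    # Stand-in score: a dot product over the last axis (assumes query_size == key_size).
    # query gains a singleton sequence dimension at `dim` so it broadcasts against key.
    e = (query.unsqueeze(dim) * key).sum(-1)  # (T, batch_size)
    if mask is not None:
        # False entries get -inf, hence exactly zero weight after the softmax
        e = e.masked_fill(~mask, -float("inf"))
    a = torch.softmax(e, dim)  # attention weights; they sum to one along `dim`
    # out = sum_t a_t value_t
    return (a.unsqueeze(-1) * value).sum(dim)


torch.manual_seed(0)
T, batch_size, size = 7, 3, 4
query = torch.randn(batch_size, size)
key = torch.randn(T, batch_size, size)
lens = torch.tensor([7, 3, 1])
mask = torch.arange(T).unsqueeze(-1) < lens  # (T, batch_size); True marks valid steps
out = global_soft_attention(query, key, key, mask)
print(out.shape)  # torch.Size([3, 4])
```

Because the third batch element has length 1, all of its weight falls on the first step and its output equals key[0, 2]. A batch element whose mask is entirely False would produce NaNs in this sketch, since the softmax would have nothing to normalize over.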

Parameters:
  • query_size (int) – The length of the last dimension of the query argument

  • key_size (int) – The length of the last dimension of the key argument

  • dim (int, optional) – The sequence dimension of the key argument

Call Parameters:
  • query (torch.Tensor) – A tensor of shape (A*, query_size) representing the queries. (A*) must broadcast with (B*, C*) from key, value, and mask.

  • key (torch.Tensor) – A tensor of shape (B*, T, C*, key_size) representing the keys. (B*, C*) must broadcast with (A*) from query.

  • value (torch.Tensor) – A tensor of shape (B*, T, C*, D*) representing the values. (B*, C*) must broadcast with (A*) from query.

  • mask (Optional[torch.Tensor]) – An optional boolean tensor of shape (B*, T, C*) which indicates which values of the key should be kept (False means zero-out). If unset, assumed to be entirely True.

Returns:

out (torch.Tensor) – The output tensor of shape (E*, D*), where (E*) is the result of broadcasting (A*) with (B*, C*).

Examples

A simple auto-regressive decoder using soft attention on encoder outputs with “concat”-style attention

>>> import torch
>>> from pydrobert.torch.modules import ConcatSoftAttention
>>> T, batch_size, encoded_size, hidden_size = 100, 5, 30, 124
>>> num_classes, start, eos, max_decoder_steps = 20, -1, 0, 100
>>> encoded_lens = torch.randint(1, T + 1, (batch_size,))
>>> # boolean mask of shape (T, batch_size); True marks valid (non-padded) steps
>>> len_mask = torch.arange(T).unsqueeze(-1) < encoded_lens
>>> encoded = torch.randn(T, batch_size, encoded_size)
>>> rnn = torch.nn.RNNCell(encoded_size + 1, hidden_size)
>>> ff = torch.nn.Linear(hidden_size, num_classes)
>>> attention = ConcatSoftAttention(hidden_size, encoded_size)
>>> h = torch.zeros((batch_size, hidden_size))
>>> y = torch.full((1, batch_size), start, dtype=torch.long)
>>> for _ in range(max_decoder_steps):
...     if y[-1].eq(eos).all():
...         break
...     context = attention(h, encoded, encoded, len_mask)
...     cat = torch.cat([context, y[-1].unsqueeze(-1).float()], 1)
...     h = rnn(cat, h)  # feed the previous hidden state back into the cell
...     logit = ff(h)
...     y_next = logit.argmax(-1).masked_fill(y[-1].eq(eos), eos)
...     y = torch.cat([y, y_next.unsqueeze(0)], 0)

See also

Attention and Transformer Networks

GlobalSoftAttention is compatible with a variety of inputs. This tutorial gives a toy transformer network to illustrate broadcasting semantics.

class pydrobert.torch.modules.ConcatSoftAttention(query_size, key_size, dim=0, bias=False, hidden_size=1000)

Attention where query and key are concatenated, then fed into an MLP

Proposed in [luong2015], though quite similar to that proposed in [bahdanau2015], the score function for this layer is:

\[e = \sum_i v_i \tanh(\sum_c W_{ic} [query, key]_c)\]

Here, \(W\) is a learned matrix, \(v\) is a learned vector, and \([query, key]\) indicates concatenation along the last axis. query and key will be expanded to fit their broadcast dimensions. \(W\) has shape (hidden_size, query_size + key_size) and \(v\) has shape (hidden_size,).
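A sketch of this score function under stated assumptions: ConcatScoreSketch is a hypothetical name, \(W\) is stored as a torch.nn.Linear over the concatenated last axis (so its weight has shape (hidden_size, query_size + key_size)), and \(v\) as a bias-free torch.nn.Linear with a single output. The real module's parameter layout may differ.

```python
import torch


class ConcatScoreSketch(torch.nn.Module):
    """Hypothetical concat ("MLP") score; not the library's implementation."""

    def __init__(self, query_size, key_size, hidden_size=1000, bias=False, dim=0):
        super().__init__()
        self.dim = dim
        # W: maps the (query_size + key_size)-long concatenation to hidden_size
        self.W = torch.nn.Linear(query_size + key_size, hidden_size, bias=bias)
        # v: reduces the hidden_size tanh activations to one scalar per position
        self.v = torch.nn.Linear(hidden_size, 1, bias=False)

    def forward(self, query, key):
        query = query.unsqueeze(self.dim)  # add the sequence dimension to query
        # Expand both to their common broadcast shape (ignoring the last axis)
        shape = torch.broadcast_shapes(query.shape[:-1], key.shape[:-1])
        query = query.expand(shape + query.shape[-1:])
        key = key.expand(shape + key.shape[-1:])
        cat = torch.cat([query, key], -1)  # [query, key]
        return self.v(torch.tanh(self.W(cat))).squeeze(-1)  # e


torch.manual_seed(0)
score = ConcatScoreSketch(query_size=6, key_size=3, hidden_size=8)
query = torch.randn(2, 6)  # (batch_size, query_size)
key = torch.randn(5, 2, 3)  # (T, batch_size, key_size)
e = score(query, key)  # (T, batch_size)
```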

Parameters:
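  • query_size (int) – The length of the last dimension of the query argument.

  • key_size (int) – The length of the last dimension of the key argument.

  • dim (int) – The sequence dimension of the key argument.

  • bias (bool) – Whether to add a bias term b: \(W [query, key] + b\).

  • hidden_size (int) – The length of \(v\), i.e. the number of hidden units of the MLP.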

Call Parameters:
  • query, key, value, mask – As described in GlobalSoftAttention.

Returns:

out (torch.Tensor)

See also

GlobalSoftAttention

For a description of how to call this module, how it works, etc.

class pydrobert.torch.modules.DotProductSoftAttention(size, dim=0, scale_factor=1.0)

Global soft attention with dot product score function

From [luong2015], the score function for this attention mechanism is

\[e = scale\_factor \sum_i query_i key_i\]

Where \(i\) indexes the last dimension of both the query and key
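A minimal sketch of this score (dot_product_score is a hypothetical helper, not the module's API). With all-ones inputs of size 4, the unscaled score is 4; scaling by \(1 / \sqrt{size}\), the factor used in [vaswani2017], gives 2.

```python
import math

import torch


def dot_product_score(query, key, scale_factor=1.0, dim=0):
    # e = scale_factor * sum_i query_i key_i, with query broadcast over the sequence dim
    return scale_factor * (query.unsqueeze(dim) * key).sum(-1)


size = 4
query = torch.ones(2, size)  # (batch_size, size)
key = torch.ones(3, 2, size)  # (T, batch_size, size)
e = dot_product_score(query, key)  # every entry equals 4.0
e_scaled = dot_product_score(query, key, 1 / math.sqrt(size))  # every entry equals 2.0
```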

Parameters:
  • size (int) – The size of the final dimension of both query and key.

  • dim (int) – The sequence dimension of the key argument.

  • scale_factor (float) – A floating-point value to multiply each \(e\) by. Usually 1, but if set to \(1 / \sqrt{size}\), you’ll get the scaled dot-product attention of [vaswani2017].

Call Parameters:
  • query, key, value, mask – As described in GlobalSoftAttention.

Returns:

out (torch.Tensor)

See also

GlobalSoftAttention

For a description of how to call this module, how it works, etc.

class pydrobert.torch.modules.GeneralizedDotProductSoftAttention(query_size, key_size, dim=0, bias=False)

Dot product soft attention with a learned matrix in between

This attention mechanism uses the “general” score function from [luong2015]:

\[e = \sum_q query_q \sum_k W_{qk} key_k\]

Here, \(W\) is a learned matrix, \(q\) indexes the last dimension of query, and \(k\) indexes the last dimension of key.
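The double sum can be written as a single torch.einsum call. This sketch uses a random matrix in place of the learned \(W\) (shape (query_size, key_size)) and checks it against the equivalent form query · (W key), which is where the optional bias b would be added.

```python
import torch

query_size, key_size, T, batch_size = 3, 5, 4, 2
torch.manual_seed(0)
W = torch.randn(query_size, key_size)  # stand-in for the learned matrix
query = torch.randn(batch_size, query_size)
key = torch.randn(T, batch_size, key_size)

# e[t, n] = sum_q query[n, q] * sum_k W[q, k] * key[t, n, k]
e = torch.einsum("nq,qk,tnk->tn", query, W, key)

# Same score as a dot product between query and the projected key W key
Wkey = key @ W.T  # (T, batch_size, query_size)
e_alt = (query.unsqueeze(0) * Wkey).sum(-1)
```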

Parameters:
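  • query_size (int) – The length of the last dimension of the query argument.

  • key_size (int) – The length of the last dimension of the key argument.

  • dim (int) – The sequence dimension of the key argument.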

  • bias (bool) – Whether to add a bias term b: \(W key + b\)

Call Parameters:
  • query, key, value, mask – As described in GlobalSoftAttention.

Returns:

out (torch.Tensor)

See also

GlobalSoftAttention

For a description of how to call this module, how it works, etc.

class pydrobert.torch.modules.MultiHeadedAttention(query_size, key_size, value_size, num_heads, single_head_attention, out_size=None, d_v=None, bias_WQ=False, bias_WK=False, bias_WV=False, bias_WC=False)

Perform attention over a number of heads, concatenate, and project

Multi-headed attention was proposed in [vaswani2017]. It can be considered a wrapper around a standard GlobalSoftAttention module that itself performs global soft attention, but with more parameters. The idea is to replicate transformed versions of the query, key, and value num_heads times. Letting \(h\) index the head:

\[\begin{split}query_h = W^Q_h query \\ key_h = W^K_h key \\ value_h = W^V_h value\end{split}\]

If query is of shape (..., query_size), \(W^Q_h\) is a learned matrix of shape (query_size, d_q) that acts on the final dimension of query. Likewise, \(W^K_h\) is of shape (key_size, d_k) and \(W^V_h\) is of shape (value_size, d_v). Note here that the last dimension of value must also be provided in value_size, unlike in other attention layers.

Each head is then determined via a wrapped GlobalSoftAttention instance, single_head_attention:

\[head_h = single\_head\_attention(query_h, key_h, value_h, mask)\]

Where mask is repeated over all \(h\).

Since each \(head_h\) has the same shape, they can be concatenated along the last dimension to get the tensor \(cat\) of shape (..., d_v * num_heads), which is linearly transformed into the output

\[out = W^C cat\]

With a learnable matrix \(W^C\) of shape (d_v * num_heads, out_size). out has a shape (..., out_size)
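A compact sketch of the pipeline above, assuming a scaled dot-product score for every head and dim == 0. MultiHeadSketch is hypothetical: it stacks all heads' \(W^Q_h\), \(W^K_h\), and \(W^V_h\) into single torch.nn.Linear layers, which is one common way to implement the per-head projections but not necessarily how MultiHeadedAttention stores them. Here d_k plays the role of both d_q and d_k, as a dot product requires.

```python
import torch


class MultiHeadSketch(torch.nn.Module):
    """Hypothetical multi-headed attention with scaled dot-product heads, dim == 0."""

    def __init__(self, query_size, key_size, value_size, num_heads, d_k, d_v, out_size):
        super().__init__()
        self.num_heads, self.d_k, self.d_v = num_heads, d_k, d_v
        # All heads' W^Q_h (resp. W^K_h, W^V_h) stacked into one linear layer
        self.WQ = torch.nn.Linear(query_size, num_heads * d_k, bias=False)
        self.WK = torch.nn.Linear(key_size, num_heads * d_k, bias=False)
        self.WV = torch.nn.Linear(value_size, num_heads * d_v, bias=False)
        self.WC = torch.nn.Linear(num_heads * d_v, out_size, bias=False)

    def forward(self, query, key, value, mask=None):
        # query: (N, query_size); key: (T, N, key_size); value: (T, N, value_size)
        N, H = query.shape[0], self.num_heads
        q = self.WQ(query).view(N, H, self.d_k)
        k = self.WK(key).view(-1, N, H, self.d_k)
        v = self.WV(value).view(-1, N, H, self.d_v)
        e = (q.unsqueeze(0) * k).sum(-1) / self.d_k ** 0.5  # (T, N, H)
        if mask is not None:
            # the same (T, N) mask is repeated over every head
            e = e.masked_fill(~mask.unsqueeze(-1), -float("inf"))
        a = torch.softmax(e, 0)
        heads = (a.unsqueeze(-1) * v).sum(0)  # (N, H, d_v)
        cat = heads.reshape(N, H * self.d_v)  # concatenate heads along the last axis
        return self.WC(cat)  # out = W^C cat, shape (N, out_size)


torch.manual_seed(0)
mha = MultiHeadSketch(6, 5, 7, num_heads=2, d_k=3, d_v=4, out_size=7)
query, key, value = torch.randn(2, 6), torch.randn(10, 2, 5), torch.randn(10, 2, 7)
mask = torch.arange(10).unsqueeze(-1) < torch.tensor([10, 4])
out = mha(query, key, value, mask)
```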

Parameters:
  • query_size (int) – The size of the last dimension of the query being passed to this module (not the size of a head’s query).

  • key_size (int) – The size of the last dimension of the key being passed to this module (not the size of a head’s key).

  • value_size (int) – The size of the last dimension of the value being passed to this module (not the size of a head’s value).

  • num_heads (int) – The number of heads to spawn.

  • single_head_attention (GlobalSoftAttention) – An instance of a subclass of GlobalSoftAttention responsible for processing a head. single_head_attention will be used to derive the sequence dimension (dim) of key via single_head_attention.dim, the size of a head’s query d_q via single_head_attention.query_size, and the size of a head’s key d_k via single_head_attention.key_size.

  • out_size (Optional[int]) – The size of the last dimension of out. If unset, the default is to match value_size.

  • d_v (Optional[int]) – The size of the last dimension of a head’s value. If unset, will default to max(1, value_size // num_heads).

  • bias_WQ (bool) – Whether to add a bias term to \(W^Q\).

  • bias_WK (bool) – Whether to add a bias term to \(W^K\).

  • bias_WV (bool) – Whether to add a bias term to \(W^V\).

  • bias_WC (bool) – Whether to add a bias term to \(W^C\).

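Call Parameters:
  • query, key, value, mask – As described in GlobalSoftAttention, except that the last dimension of value must equal value_size.
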
Returns:

out (torch.Tensor) – The output tensor of shape (E*, D*, out_size), where (E*) is the result of broadcasting (A*) with (B*, C*). By default, out_size matches value_size.

Decoding

class pydrobert.torch.modules.BeamSearch(lm, width, eos=None, finish_all_paths=False, pad_value=-100)

Perform beam search on the outputs of a SequentialLanguageModel

Beam search is a heuristic algorithm that keeps track of the width most promising paths in the beam, ranked by their probability under the language model lm.

A path continues to be extended until it is either pruned or emits an end-of-sequence (eos) symbol (if set). The search ends for a batch element when its highest probability path ends with an eos or all paths end with an eos (depending on the setting of finish_all_paths). The search ends for the entire batch either when the searches for all batch elements have ended or when max_iters steps have been taken, whichever comes first. It is therefore necessary to set at least one of eos or max_iters.
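To make the pruning and stopping rules concrete, here is a sketch of a beam search for a single batch element that stops as soon as its best path ends with eos (the behaviour described for finish_all_paths=False). The toy bigram language model and the names toy_lm and beam_search_sketch are hypothetical, and the sketch ignores batching, padding, and devices.

```python
import torch

# Hypothetical toy bigram LM over {0 (eos), 1, 2}. Row 0 doubles as the start
# distribution, which is safe because no path is ever extended past eos.
BIGRAM = torch.tensor([[0.1, 0.6, 0.3], [0.7, 0.2, 0.1], [0.2, 0.1, 0.7]]).log()


def toy_lm(prefix):
    return BIGRAM[prefix[-1] if prefix else 0]  # next-token log-probabilities


def beam_search_sketch(log_probs_fn, vocab_size, width, eos, max_iters):
    beam = [([], 0.0)]  # (prefix, log-probability) pairs
    for _ in range(max_iters):
        candidates = []
        for prefix, lp in beam:
            if prefix and prefix[-1] == eos:
                candidates.append((prefix, lp))  # finished paths carry over unchanged
                continue
            next_lp = log_probs_fn(prefix)
            for v in range(vocab_size):
                candidates.append((prefix + [v], lp + next_lp[v].item()))
        # Prune: keep only the `width` most probable paths
        beam = sorted(candidates, key=lambda c: c[1], reverse=True)[:width]
        # Stop once the most probable path has ended with eos
        if beam[0][0][-1] == eos:
            break
    return beam


beam = beam_search_sketch(toy_lm, vocab_size=3, width=2, eos=0, max_iters=5)
print([prefix for prefix, _ in beam])  # [[1, 0], [2, 2]]
```

With width 2, the paths [1] and [2] survive the first step. In the second step [1, 0] (probability 0.6 × 0.7 = 0.42) outranks [2, 2] (0.3 × 0.7 = 0.21), and since the best path now ends with eos, the search stops.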

Parameters:
  • lm (ExtractableSequentialLanguageModel) – The language model responsible for producing distributions over the next token type.

  • width (int) – The beam width.

  • eos (Optional[int]) – The end of sequence type. If set, must be in-vocabulary (according to lm.vocab_size). Either eos or max_iters must be set.

  • finish_all_paths (bool) – Applicable only when eos is set. If True, waits for all paths in every batch element’s beam to emit an eos symbol before stopping. If False, only the highest probability path needs to end with an eos before stopping.

  • pad_value (int) – The value to pad frozen paths with. See the below note for more information.

Call Parameters:
  • initial_state (Dict[str, torch.Tensor], optional) – Whatever state info must be initially passed to the lm before any sequences are generated.

  • batch_size (Optional[int], optional) – Specifies the batch size (N*,). If set, (N*,) == (batch_size,) and a beam search will be run separately over each of the batch elements. If unset, (N*,) == (,) and a single search will be performed. See the below note for more information.

  • max_iters (Optional[int], optional) – The maximum number of tokens to generate in the paths before returning. Either eos or max_iters must be set.

Returns:

  • y (torch.Tensor) – A long tensor of shape (S, N*, width) containing the width paths.

  • y_lens (torch.Tensor) – A long tensor of shape (N*, width) of the lengths of the corresponding paths including the first instance of eos, if it exists. Only the tokens in y[:y_lens[..., k], ..., k] are valid.

  • y_log_probs (torch.Tensor) – A tensor of shape (N*, width) containing the log probabilities of the paths.

Variables:

device_buffer (torch.Tensor) – An empty tensor which determines the device of return values. The device is inferred on initialization from lm.parameters(), if possible. The device can be forced by using the module’s to() method, which also shifts the parameters of lm to the device.

Warning

Return values will always contain width prefixes, regardless of whether this is possible. The log probabilities of invalid prefixes will be set to -float("inf") and will populate the latter indices of the beam. Since this cannot be distinguished from a zero-probability path (log 0 = -inf), care must be taken by the user to avoid confusing them.

Notes

When batch_size is unset, a single search starting with a single empty prefix is run and its results reported. This is appropriate for language models which do not condition on any batched input in initial_state. Since the search is deterministic, running the search multiple times on the same empty prefix (as if batched) would duplicate results. In this setting, the return values (y, y_lens, and y_log_probs) have no batch dimension.

batch_size is appropriate when the search is conditioned on some batched input via initial_state, such as images or audio features. In these cases, batch_size should match the batch dimension of the batched input. batch_size empty prefixes will be initialized and passed to the lm. The return values will have a corresponding batch dimension.

The introduction of batching via batch_size raises the question of what to do when one or more batch elements have finished the search while others continue. In order to return consistent results regardless of the number of elements in the batch, we freeze the results of completed batch elements while the remainder continue. If batch element n is completed by step t, y_lens[n] and y_log_probs[n] will be kept the same regardless of whether the remaining batch elements require a step t + 1. Because y needs to grow while the remaining batch elements are unfinished, completed sequences will be right-padded with the value pad_value. Such sequences may or may not have ended with an eos (if set) prior to padding, depending on the value of finish_all_paths.

class pydrobert.torch.modules.CTCGreedySearch(blank_idx=-1, batch_first=False, is_probs=False)

CTC greedy search

The CTC greedy search picks the highest-probability class in logits for each element in the sequence. The path probability (log-probability) is the product (sum) of the chosen classes’ probabilities (log-probabilities). The output sequences are the resulting sequences of class labels with consecutive duplicates merged and blanks removed.
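A single-sequence sketch of the greedy rule: take the argmax at every step, merge consecutive repeats, then drop blanks. The helper is hypothetical; it normalizes logits with log_softmax before summing, which is an assumption of this sketch rather than a documented behaviour of CTCGreedySearch.

```python
import torch


def ctc_greedy_sketch(logits, blank_idx=-1):
    """Hypothetical greedy search over a single (T, V) sequence of logits."""
    log_probs = logits.log_softmax(-1)  # assumption: normalize before summing
    max_, best = log_probs.max(-1)  # best class per step and its log-probability
    blank_idx %= logits.shape[-1]  # resolve a negative index such as -1
    path, prev = [], None
    for tok in best.tolist():
        # Merge repeats of the previous step's label, then drop blanks
        if tok != prev and tok != blank_idx:
            path.append(tok)
        prev = tok
    return max_.sum(), path


best_seq = torch.tensor([1, 1, 3, 2, 3, 2, 2])
logits = 5.0 * torch.nn.functional.one_hot(best_seq, 4).float()  # (T=7, V=4)
total, path = ctc_greedy_sketch(logits)  # the blank is class 3
print(path)  # [1, 2, 2]
```

Note that the blank between the two runs of 2 keeps them distinct, while the repeated 1s merge into a single label.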

Parameters:
  • blank_idx (int) – Which index along the class dimension specifies the blank label.

  • batch_first (bool) – If False, logits is of shape (T, N, V) and paths is of shape (T, N). If True, they have the shapes (N, T, V) and (N, T) used in the descriptions below.

  • is_probs (bool) – If True, logits will be considered a normalized probability distribution instead of an unnormalized log-probability distribution. The return value max_ will then be the product of the per-step probabilities instead of the sum of log-probabilities.

Call Parameters:
  • logits (torch.Tensor) – A tensor of shape (N, T, V) where T is the sequence dimension, N is the batch dimension, and V is the number of classes including the blank label. logits[n, t, :] represent the unnormalized log-probabilities of the labels at time t in batch element n.

  • in_lens (torch.Tensor, optional) – A long tensor of shape (N,) providing the lengths of the sequences in the batch. For a given batch element n, only the values of logits in the slice logits[n, :in_lens[n]] will be considered valid.

Returns:

  • max_ (torch.Tensor) – A tensor of shape (N,) containing the total log-probability of the greedy path.

  • paths (torch.Tensor) – A long tensor of shape (N, T) which stores the reduced greedy paths.

  • out_lens (torch.Tensor) – A long tensor of shape (N,) which specifies the lengths of the greedy paths within paths: for a given batch element n, the reduced greedy path is the sequence in the range paths[n, :out_lens[n]]. The values of paths outside this range are undefined.

class pydrobert.torch.modules.CTCPrefixSearch(width, beta=0.2, lm=None, valid_mixture=False)

Perform a CTC prefix search with optional shallow fusion

A Connectionist Temporal Classification [graves2006] prefix search is similar to a beam search, but a fixed number of (reduced) prefixes are maintained in the beam rather than a fixed number of paths. Reduced paths contain no blank labels.

Shallow fusion [gulcehre2015] is enabled by initializing this module with lm. Shallow fusion updates the probability of extending a prefix \(y_{1..t-1}\) with a new token \(v\) (\(v\) is not blank) with the following equation:

\[\log S(y_t=v|y_{1..t-1}) = \log P_{ctc}(y_t=v) + \beta \log P_{lm}(y_t = v|y_{1..t-1})\]

The resulting value \(\log S(y_t=v|y_{1..t-1})\) is not technically a log-probability.
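The fused score is a simple per-token sum. The sketch below (a hypothetical helper, with the blank stored last as in the Call Parameters) also illustrates why \(\log S\) is not a log-probability: since \(P_{lm}^\beta \le 1\) for \(\beta > 0\), the fused non-blank mass can only shrink, so together with the blank it generally sums to less than one.

```python
import torch


def shallow_fusion(log_p_ctc, log_p_lm, beta):
    """Hypothetical: log S for every non-blank v; blank is the last CTC entry."""
    return log_p_ctc[..., :-1] + beta * log_p_lm


torch.manual_seed(0)
V, beta = 4, 0.2
# Hypothetical single-step distributions for one prefix
log_p_ctc = torch.randn(V + 1).log_softmax(-1)  # extended vocabulary, blank last
log_p_lm = torch.randn(V).log_softmax(-1)  # LM over the V non-blank types
log_s = shallow_fusion(log_p_ctc, log_p_lm, beta)
```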

Parameters:
  • width (int) – The number of prefixes to keep track of per step.

  • beta (float) – The mixing coefficient \(\beta\) used when performing shallow fusion.

  • lm (Optional[MixableSequentialLanguageModel]) – If set, the language model used in shallow fusion. Specifying lm will restrict the extended vocabulary size of logits to be one more than that of lm: lm.vocab_size == V.

  • valid_mixture (bool) –

    If True, an alternate equation for shallow fusion is employed:

    \[\begin{split}\begin{multline} S(y_t=v|\ldots) = (1 - \beta) P_{ctc}(y_t=v) \hfill \\ + \beta P_{lm}(y_t=v|\ldots) \sum_{v' \neq blank} P_{ctc}(y_t = v'|\ldots) \end{multline}\end{split}\]

    for \(\beta \in [0, 1]\). Unlike in regular shallow fusion, the resulting value is a log-probability over the extended vocabulary.
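Summing the valid_mixture equation over non-blank \(v\) shows that it redistributes exactly the CTC non-blank mass, since \(\sum_v P_{lm}(y_t=v|\ldots) = 1\). The hypothetical sketch below checks this numerically; it says nothing about how the blank entry is treated, which is not specified here.

```python
import torch


def valid_mixture(p_ctc, p_lm, beta):
    """Hypothetical: S(y_t=v|...) for every non-blank v; blank is the last CTC entry."""
    nonblank_mass = p_ctc[..., :-1].sum(-1, keepdim=True)  # sum_{v' != blank} P_ctc
    return (1 - beta) * p_ctc[..., :-1] + beta * p_lm * nonblank_mass


torch.manual_seed(1)
V = 4
p_ctc = torch.randn(V + 1).softmax(-1)  # extended vocabulary, blank last
p_lm = torch.randn(V).softmax(-1)
s = valid_mixture(p_ctc, p_lm, beta=0.3)
```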

Call Parameters:
  • logits (torch.Tensor) – A tensor of shape (T, N, V + 1) s.t. logits[t, n] represents the unnormalized log-probabilities over the extended vocabulary (including blanks) at step t of batch element n. The blank type logits are assumed to be stored in the final index of the vocabulary: logits[..., V].

  • logit_lens (Optional[torch.Tensor]) – An optional tensor of shape (N,) s.t., for a given batch index n, only the values in the slice logits[:logit_lens[n], n] are valid. If unset then all sequences are assumed to be of length T.

  • initial_state (Optional[Dict[str, torch.Tensor]]) – Whatever state info must be passed to the lm prior to generating sequences, if specified.

Returns:

  • y (torch.Tensor) – A long tensor of shape (S, N, width) containing the width prefixes per batch element, S <= T.

  • y_lens (torch.Tensor) – A long tensor of shape (N, width) of the lengths of the corresponding prefixes: for each batch element n and prefix k, only the tokens y[:y_lens[n, k], n, k] are valid. Note that for all k, y_lens[n, k] <= logit_lens[n].

  • y_probs (torch.Tensor) – A tensor of shape (N, width) containing those prefixes’ estimated (not log) probabilities. Prefixes are ordered in decreasing probability (y_probs[n, k] >= y_probs[n, k + 1]).

Warning

The blank index, effectively V, is different from the default index of torch.nn.CTCLoss, 0. We chose this in order to avoid confusion between the index set of logits and the index set of lm: this way, the interpretation of the indices up to but excluding V in both refer to the same type/label.

Return values will always contain width prefixes, regardless of whether this is possible. The probabilities of invalid prefixes will be set to -float("inf") and will populate the latter indices of the beam.

Notes

The CTC prefix search is often called a beam search in the literature. We stick with the name from [graves2006] as it is entirely possible to apply a normal beam search to CTC logits, only removing blank labels after the search. Doing so would be faster and may not lead to much decrease in performance if logits is sufficiently “peaky.”

class pydrobert.torch.modules.RandomWalk(lm, eos=None)

Perform a random walk on the outputs of a SequentialLanguageModel

A random walk iteratively builds a sequence of tokens by sampling the next token given a prefix of tokens.

A path continues to be extended until it emits an end-of-sequence (eos) symbol (if set). The walk ends for a batch as soon as all paths in the batch have ended or max_iters has been reached, whichever comes first. It is therefore necessary to set at least one of eos or max_iters.
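A sketch of a single unbatched walk, using a hypothetical toy bigram language model over {0 (eos), 1, 2}. Each step samples the next token with torch.multinomial and adds its log-probability to the running total; the first eos is kept and ends the walk. Different generator seeds generally give different paths, which is what distinguishes a random walk from a beam search.

```python
import torch

# Hypothetical toy bigram LM over {0 (eos), 1, 2}; row 0 doubles as the start
# distribution because no walk continues past eos.
BIGRAM = torch.tensor([[0.1, 0.6, 0.3], [0.7, 0.2, 0.1], [0.2, 0.1, 0.7]]).log()


def random_walk_sketch(eos=0, max_iters=20, generator=None):
    path, log_prob = [], 0.0
    for _ in range(max_iters):
        logits = BIGRAM[path[-1] if path else 0]  # next-token log-probabilities
        tok = torch.multinomial(logits.exp(), 1, generator=generator).item()
        log_prob += logits[tok].item()  # log-probability of the sampled token
        path.append(tok)
        if tok == eos:  # the first eos is included, then the walk stops
            break
    return path, log_prob


gen = torch.Generator().manual_seed(0)
path, log_prob = random_walk_sketch(generator=gen)
```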

Parameters:
  • lm (SequentialLanguageModel) – The language model responsible for producing distributions over the next token type.

  • eos (Optional[int]) – The end of sequence type. If set, must be in-vocabulary (according to lm.vocab_size). Either eos or max_iters must be set.

Call Parameters:
  • initial_state (dict, optional) – Whatever state info must be initially passed to the lm before any sequences are generated.

  • batch_size (Optional[int], optional) – Specifies the batch size (N*,). If set, (N*,) == (batch_size,) and a walk will be performed for each batch element independently. If unset, (N*,) == (,) and a single walk will be performed. See the below note for more information.

  • max_iters (Optional[int], optional) – Specifies the maximum number of steps to take in the walk. Either eos or max_iters must be set.

Returns:

  • y (torch.Tensor) – A long tensor of shape (S, N*) containing the paths.

  • y_lens (torch.Tensor) – A long tensor of shape (N*,) of the lengths of the corresponding paths including the first instance of eos, if it exists. For batch element n, only the tokens in y[:y_lens[n], n] are valid.

  • y_log_probs (torch.Tensor) – A tensor of shape (N*,) containing the log probabilities of the paths.

Variables:

device_buffer (torch.Tensor) – An empty tensor which determines the device of return values. The device is inferred on initialization from lm.parameters(), if possible. The device can be forced by using the module’s to() method, which also shifts the parameters of lm to the device.

Notes

The interface for RandomWalk is similar to that of BeamSearch. When batch_size is unset, a single empty prefix is initialized, a single draw/walk is performed, and the return values (y, y_lens, and y_log_probs) have no batch dimension. When batch_size is set, batch_size empty prefixes are initialized, batch_size draws/walks are performed, and the return values have a batch dimension of batch_size.

Setting batch_size remains useful when the language model conditions on some batched input through initial_state; each walk will be assigned some different batch element. Because the results of the walk are random, batch_size can also be used in the unconditioned case to draw batch_size elements from the same distribution. This is in contrast with BeamSearch in which increasing batch_size in the unconditioned case merely repeats the same search batch_size times.

See also

pydrobert.torch.distributions.SequentialLanguageModelDistribution

A wrapper around a RandomWalk instance which allows it to be treated as a distribution.

class pydrobert.torch.modules.SequenceLogProbabilities(dim=0, eos=None)

Calculate joint log probability of sequences

Letting \(t\) index the step dimension and \(b\) index all other shared dimensions of logits and hyp, this module outputs a tensor log_probs of the log-joint probability of sequences in the batch:

\[\log Pr(samp_b = hyp_b) = \log \left( \prod_t Pr(samp_{b,t} = hyp_{b,t}; logits_{b,t})\right)\]

\(logits_{b,t}\) (with the last dimension free) characterizes a categorical distribution over num_classes tokens via a softmax function. We assume \(samp_{b,t}\) is independent of \(samp_{b',t'}\) given \(logits_{b,t}\).

If eos (end-of-sequence) is set, its first occurrence at step \(t\) of sequence \(b\) is included in the sequence, but all later steps \(t' > t\) of \(b\) are ignored.
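A sketch of the computation for plain tensors with dim == 0 (logits of shape (T, N, num_classes) and hyp of shape (T, N)). The helper is hypothetical: out-of-range values of hyp are treated as padding, steps after the first eos are masked out, and the remaining per-step log-probabilities, obtained with log_softmax, are summed.

```python
import torch


def sequence_log_probs_sketch(logits, hyp, eos=None):
    """Hypothetical: logits (T, N, num_classes), hyp (T, N); returns (N,)."""
    num_classes = logits.shape[-1]
    valid = (hyp >= 0) & (hyp < num_classes)  # out-of-range tokens are padding
    if eos is not None:
        is_eos = (hyp == eos).long()
        # number of eos tokens strictly before each step; > 0 means "after the first eos"
        eos_before = is_eos.cumsum(0) - is_eos
        valid &= eos_before == 0
    log_probs = logits.log_softmax(-1)
    # gather each step's log-probability of hyp (clamped so padding indexes safely)
    tok = hyp.clamp(0, num_classes - 1).unsqueeze(-1)
    tok_log_probs = log_probs.gather(-1, tok).squeeze(-1)  # (T, N)
    return tok_log_probs.masked_fill(~valid, 0.0).sum(0)


logits = torch.zeros(4, 2, 4)  # uniform distributions over 4 classes
hyp = torch.tensor([[1, 3], [2, -1], [0, -1], [3, -1]])  # column 1 is padded
log_probs = sequence_log_probs_sketch(logits, hyp, eos=0)
```

In the first column, the eos at step 2 is counted and the token after it is ignored, giving \(3 \log(1/4)\); the second column has one valid token, giving \(\log(1/4)\).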

Parameters:
  • dim (int) – The sequence dimension of logits.

  • eos (Optional[int]) – If set, specifies the end-of-sequence token index in the last dimension of logits.

Call Parameters:
  • logits (torch.Tensor or torch.nn.utils.rnn.PackedSequence) – A tensor of shape (A*, T, B*, num_classes) containing the unnormalized log-probabilities over token types, where T enumerates the time/step (dim-th) dimension. Alternatively, logits may be a packed sequence. In this case, eos is ignored.

  • hyp (torch.Tensor) – A long tensor of shape (A*, T, B*). The token sequences over time. Any values of hyp not in [0, num_classes) will be considered padding and ignored.

Returns:

log_probs (torch.Tensor) – A tensor of shape (A*, B*) containing the log probabilities of the sequences in hyp.

Notes

PackedSequence instances with enforce_sorted=False first sort sequences by length. The sort is not guaranteed to be deterministic if some entries have equal length. To avoid the possibility that logits and hyp are sorted differently, we require hyp to always be a torch.Tensor.

PyTorch < 1.8.0 cannot infer whether logits is a torch.nn.utils.rnn.PackedSequence in scripting mode.

Features

class pydrobert.torch.modules.ChunkBySlices(mode='constant', value=0.0)

Chunk input using slices, padding where necessary

Parameters:
  • mode (Literal['constant', 'reflect', 'replicate']) – How to pad slices that go beyond the sequence lengths. See PadVariable for more information on the modes.

  • value (float) – The value to pad with when mode == 'constant'.

Call Parameters:
  • x (torch.Tensor) – A tensor of shape (N, T, *) where N is the batch index and T is the sequence index.

  • slices (torch.Tensor) – A long tensor of shape (N, 2) containing pairs start, end, where start and end are the start (inclusive) and end (exclusive) indices, respectively. Any slices exceeding segment boundaries will be padded according to the mode specified.

  • lens (torch.Tensor, optional) – An optional long tensor of shape (N,) specifying the sequence lengths. Only the values in the range x[n, :lens[n]] are considered part of the sequence of batch element n. If unspecified, all sequences of x are assumed to be of length T.

Returns:

  • chunked (torch.Tensor) – A tensor of shape (N, T', *) of chunks of x. Besides T', the shape of chunked matches that of x.

  • chunked_lens (torch.Tensor) – A long tensor of shape (N,) with the same interpretation as lens, but for chunked instead.

Warning

Negative indices in slices in Python are usually interpreted as an offset left from the end of the sequence. Here, however, negative indices indicate an offset left from the start of the sequence. Those values will be interpreted as padding and be added to the chunk.
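The following sketch reproduces constant padding for a single sequence to show how a negative start adds padding on the left instead of wrapping around. chunk_constant_sketch is hypothetical and covers only mode == 'constant'.

```python
import torch


def chunk_constant_sketch(x, length, start, end, value=0.0):
    """Hypothetical: chunk x[start:end] of one sequence x of shape (T, *), padding
    with `value` wherever the slice falls outside [0, length)."""
    out = x.new_full((end - start,) + tuple(x.shape[1:]), value)
    lo, hi = max(start, 0), min(end, length)
    if lo < hi:
        out[lo - start:hi - start] = x[lo:hi]  # copy the in-range part of the slice
    return out


x = torch.arange(1.0, 6.0)  # a sequence of length 5: [1, 2, 3, 4, 5]
chunk = chunk_constant_sketch(x, 5, -2, 3)
print(chunk.tolist())  # [0.0, 0.0, 1.0, 2.0, 3.0]
```

Values of x at or beyond length are treated as padding too, so chunk_constant_sketch(x, 4, 3, 6) keeps only x[3] and pads the rest.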

See also

PadVariable

For more details on how padding works.

SliceSpectData

Can be used to determine slices for SpectDataSet features. In this case, x = x[sources] and lens = lens[sources] should be passed to this module (using the return value sources from SliceSpectData).

ChunkTokenSequencesBySlices

A similar purpose, but specifically for token sequences from a SpectDataSet.

class pydrobert.torch.modules.ChunkTokenSequencesBySlices(partial=False, retain=False)

Chunk token sequences with segments in slices

Parameters:
  • partial (bool) – If True, a segment of refs whose interval partially overlaps with the slice will be included in chunked. Otherwise, segments in refs must fully overlap with slices (i.e. be contained within them).

  • retain (bool) – If True, tokens kept from refs will retain their original boundary values. Otherwise, boundaries will become relative to the start frame of slices.

Call Parameters:
  • refs (torch.Tensor) – A long tensor of shape (N, R, 3) containing triples tok, start, end, where tok is the token id, start is the start frame (inclusive) of the segment, and end is its end frame (exclusive). A negative start or end is treated as a missing boundary and will automatically exclude the triple from the chunk. refs may also be a 2-dimensional long tensor (N, R) of tokens, excluding segment boundaries. However, the return values will then always be empty.

  • slices (torch.Tensor) – A long tensor of shape (N, 2) containing pairs start, end, where start and end are the start (inclusive) and end (exclusive) indices, respectively.

  • ref_lens (torch.Tensor, optional) – An optional long tensor of shape (N,) specifying the token sequence lengths. Only the values in the range refs[n, :ref_lens[n]] are considered part of the sequence of batch element n. If unspecified, all token sequences of refs are assumed to be of length R.

Returns:

  • chunked (torch.Tensor) – A long tensor of shape (N, R', 3) of the chunked token sequences.

  • chunked_lens (torch.Tensor) – A long tensor of shape (N,) with the same interpretation as ref_lens, but for chunked instead.

Warning

Negative indices in slices in Python are usually interpreted as an offset left from the end of the sequence. In slices, however, negative indices indicate an offset left from the start of the sequence. In refs, negative indices indicate a missing boundary and are thrown out. Negative indices in slices can impact the returned segment boundaries in chunked.
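A pure-Python sketch of the selection rules for one batch element, with refs given as (tok, start, end) triples. chunk_tokens_sketch is hypothetical. It does not clip the boundaries of partially overlapping segments when retain is False, so they may fall outside the slice; how the module handles that case is not specified here.

```python
def chunk_tokens_sketch(refs, slice_start, slice_end, partial=False, retain=False):
    """Hypothetical single-element version; refs is a list of (tok, start, end)."""
    chunked = []
    for tok, start, end in refs:
        if start < 0 or end < 0:
            continue  # a missing boundary excludes the triple
        if partial:
            keep = start < slice_end and end > slice_start  # any overlap at all
        else:
            keep = start >= slice_start and end <= slice_end  # fully contained
        if keep:
            if not retain:  # make boundaries relative to the start of the slice
                start, end = start - slice_start, end - slice_start
            chunked.append((tok, start, end))
    return chunked


refs = [(1, 0, 3), (2, 3, 6), (3, 6, 9), (4, -1, 5)]
print(chunk_tokens_sketch(refs, 2, 7))  # [(2, 1, 4)]
```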

See also

SliceSpectData

Can be used to determine appropriate slices. In this case, refs = refs[sources] and ref_lens = ref_lens[sources] should be passed to this module (using the return value sources from SliceSpectData).

ChunkBySlices

A similar purpose, but for input with an explicit dimension for slicing, such as feats or alis from SpectDataSet.

class pydrobert.torch.modules.DenseImageWarp(indexing='hw', mode='bilinear', padding_mode='border')

Warp an input image with per-pixel flow vectors

This reproduces the functionality of TensorFlow’s dense_image_warp, except that image is in NCHW order instead of NHWC order. It wraps torch.nn.functional.grid_sample().

Parameters:
  • indexing (Literal['hw', 'wh']) – If indexing is "hw", flow[..., 0] = h is the height index and flow[..., 1] = w is the width index. If "wh", flow[..., 0] = w and flow[..., 1] = h. The default in TF is "hw", whereas torch’s grid_sample uses "wh".

  • mode (Literal['bilinear', 'nearest']) – The method of interpolation. Either use bilinear interpolation or the nearest pixel value. The TF default is "bilinear"

  • padding_mode (Literal['border', 'zero', 'reflection']) – Controls how points outside of the image boundaries are interpreted. "border": copy points at the border of the image. "zero": use zero-valued pixels. "reflection": reflect pixels into the image starting from the boundaries.

Call Parameters:
  • image (torch.Tensor) – A tensor of shape (N, C, H, W) where N is the batch dimension, C is the channel dimension, H is the height dimension, and W is the width dimension.

  • flow (torch.Tensor) – A tensor of shape (N, H, W, 2).

Returns:

warped (torch.Tensor) –

A warped image of shape (N, C, H, W) such that:

warped[n, c, h, w] = image[n, c, h - flow[n, h, w, 0], w - flow[n, h, w, 1]]

If the reference indices h - ... and w - ... are not integers, the value is interpolated from the neighboring pixel values.
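A sketch of how the formula above can be expressed with torch.nn.functional.grid_sample() for "hw" indexing. The helper is hypothetical: it builds the sampling locations h - flow[..., 0] and w - flow[..., 1], swaps them into grid_sample's (x, y) = (w, h) order, and normalizes them to [-1, 1] with align_corners=True. The module's own handling of these details may differ.

```python
import torch
import torch.nn.functional as F


def dense_image_warp_sketch(image, flow, mode="bilinear", padding_mode="border"):
    """Hypothetical: image (N, C, H, W), flow (N, H, W, 2) with "hw" indexing."""
    N, C, H, W = image.shape
    hh, ww = torch.meshgrid(
        torch.arange(H, dtype=image.dtype),
        torch.arange(W, dtype=image.dtype),
        indexing="ij",
    )
    src_h = hh - flow[..., 0]  # h - flow[n, h, w, 0]
    src_w = ww - flow[..., 1]  # w - flow[n, h, w, 1]
    # grid_sample wants (x, y) = (w, h) in [-1, 1]; with align_corners=True,
    # -1 and 1 are the centres of the first and last pixels
    grid = torch.stack(
        [2 * src_w / max(W - 1, 1) - 1, 2 * src_h / max(H - 1, 1) - 1], -1
    )
    return F.grid_sample(
        image, grid, mode=mode, padding_mode=padding_mode, align_corners=True
    )


image = torch.arange(12.0).view(1, 1, 3, 4)
flow = torch.zeros(1, 3, 4, 2)
flow[..., 1] = 1.0  # every pixel samples from one column to its left
warped = dense_image_warp_sketch(image, flow)
```

With this flow, each column takes the value of the column to its left, and the first column is filled from the border.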

Warning

flow is not an optical flow. Please consult the TF documentation for more details.

class pydrobert.torch.modules.FeatureDeltas(dim=-1, time_dim=-2, concatenate=True, order=2, width=2, pad_mode='replicate', value=0.0)

Compute deltas of features

Let \(x\) be some input tensor whose time_dim-th dimension represents the evolution of features over time. Denote that dimension with indices \(t\) and the dimension of the order of the deltas with \(u\). The \(0\)-th order deltas are just \(x\) itself; higher-order deltas are calculated recursively as

\[x[t, u] = \sum_{w=-width}^{width} x[t + w, u - 1] \frac{w}{\sum_{w'} w'^2}.\]

Deltas can be seen as rolling averages: first-order deltas are akin to first-order derivatives, second-order deltas to second-order derivatives, and so on.
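A sketch of the recursion for a (T, F) input with time on dimension 0, replicate padding, and concatenation along the last dimension, so that each block of F columns holds one order. deltas_sketch and its column ordering are assumptions of this sketch. It computes the weighted sum with torch.nn.functional.conv1d(), which performs a cross-correlation and therefore applies the weights \(w / \sum_{w'} w'^2\) in the same orientation as the equation.

```python
import torch
import torch.nn.functional as F


def deltas_sketch(x, order=2, width=2):
    """Hypothetical: x of shape (T, F), time on dim 0; returns (T, F * (order + 1))."""
    w = torch.arange(-width, width + 1, dtype=x.dtype)
    kernel = w / (w ** 2).sum()  # w / sum_{w'} w'^2
    feats = [x]  # 0-th order deltas are x itself
    for _ in range(order):
        prev = feats[-1].T.unsqueeze(1)  # (F, 1, T): conv1d layout with one channel
        prev = F.pad(prev, (width, width), mode="replicate")
        # conv1d cross-correlates: out[t] = sum_w kernel[w] * prev[t + w]
        feats.append(F.conv1d(prev, kernel.view(1, 1, -1)).squeeze(1).T)
    return torch.cat(feats, -1)  # concatenate the orders along the feature dimension


x = torch.arange(10.0).unsqueeze(-1)  # (T=10, F=1): a linear ramp
d = deltas_sketch(x, order=1)  # (10, 2): [x, first-order deltas]
```

On a linear ramp, the first-order deltas are exactly 1 away from the edges, consistent with their interpretation as a derivative.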

Parameters:
  • dim (int) – The dimension along which resulting deltas will be stored.

  • time_dim (int) – The dimension along which deltas are calculated.

  • concatenate (bool) – If True, delta orders are merged into a single axis with the previous occupants of the dimension dim via concatenation. Otherwise, a new dimension is stacked into the location dim.

  • order (int) – The non-negative maximum order of deltas.

  • width (int) – Controls the width of the averaging window.

  • pad_mode (Literal['replicate', 'constant', 'reflect', 'circular']) – How to pad edges to ensure the same size output. See torch.nn.functional.pad() for more details.

  • value (float) – The value used in constant padding.

Call Parameters:

x (torch.Tensor)

Returns:

deltas (torch.Tensor) – Has the same shape as x except for one dimension. If concatenate is false, a new dimension is inserted into x at position dim of size order + 1. If concatenate is true, then the dim-th dimension of deltas is order + 1 times the length of that of x.

class pydrobert.torch.modules.MeanVarianceNormalization(dim=-1, mean=None, std=None, eps=1.1754943508222875e-38)[source]

Normalize features according to mean and variance statistics

Given input x, population mean mean, population standard deviation std, and some small value eps, mean-variance normalization for the i-th element of the dim-th dimension of x is defined as

y[..., i, ...] = (x[..., i, ...] - mean[i]) / max(std[i], eps).

The mean and std vectors can be acquired in three ways.

First, they may be passed directly to this module on initialization.

Second, if mean and std were not specified, they can be estimated from the (biased) sample statistics of x. This is the same as unit normalization.

Third, they may be estimated from multiple instances of x by accumulating sufficient statistics with the accumulate() method, then writing the biased estimates with the store() method.
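The third approach only needs running sufficient statistics per feature: a count, a sum, and a sum of squares. A minimal sketch, assuming dim=-1 and that every other dimension holds samples; the class RunningMVN and its attribute names are hypothetical and only mirror the accumulate()/store() workflow described above. The default eps of the module, 1.1754943508222875e-38, is torch.finfo(torch.float32).tiny.

```python
import torch


class RunningMVN:
    """Hypothetical sketch of the accumulate()/store() workflow for dim=-1."""

    def __init__(self, eps: float = torch.finfo(torch.float32).tiny):
        self.eps = eps  # same value as the module's default eps
        self.count = 0
        self.total = None  # running per-feature sum
        self.total_sq = None  # running per-feature sum of squares
        self.mean = self.std = None

    def accumulate(self, x: torch.Tensor) -> None:
        # every dimension but the last is treated as a sample dimension
        x = x.double().reshape(-1, x.shape[-1])
        if self.total is None:
            self.total = torch.zeros(x.shape[1], dtype=torch.float64)
            self.total_sq = torch.zeros(x.shape[1], dtype=torch.float64)
        self.count += x.shape[0]
        self.total += x.sum(0)
        self.total_sq += x.pow(2).sum(0)

    def store(self) -> None:
        # biased (population) estimates from the sufficient statistics
        self.mean = self.total / self.count
        var = self.total_sq / self.count - self.mean.pow(2)
        self.std = var.clamp_min(0.0).sqrt()  # guard against tiny negative round-off

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std.clamp_min(self.eps)
```

Accumulating in double precision keeps the sum-of-squares formula reasonably stable for moderately sized inputs; a production version might prefer Welford's update.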

Parameters:
  • dim (int) – The dimension to be normalized. All other dimensions are considered

  • mean (Optional[Tensor]) – If set, a vector representing the population mean. The same size as std, if specified.

  • std (Optional[Tensor]) – If set, a vector representing the population standard deviation. The same size as mean, if specified.

  • eps (float) – A small non-negative floating-point value which, if positive, guarantees that the division is by a nonzero value.

Call Parameters:

x (torch.Tensor) – A tensor whose dim-th dimension is the same size as mean and std. To be normalized.

Returns:

y (torch.Tensor) – The normalized tensor of the same shape as x.

Examples

>>> x = torch.arange(1000, dtype=torch.float).view(10, 10, 10)
>>> mean = x.flatten(0, 1).double().mean(0)
>>> std = x.flatten(0, 1).double().std(0, unbiased=False)
>>> y = MeanVarianceNormalization(-1, mean, std)(x)
>>> assert torch.allclose(y.flatten(0, 1).mean(0), torch.zeros(1))
>>> assert torch.allclose(y.flatten(0, 1).std(0, unbiased=False), torch.ones(1))
>>> mvn = MeanVarianceNormalization()
>>> y2 = mvn(x)
>>> assert torch.allclose(y, y2)
>>> for x_n in x:
...     mvn.accumulate(x_n)
>>> mvn.store()
>>> assert torch.allclose(mvn.mean, mean)
>>> assert torch.allclose(mvn.std, std)
>>> y2 = mvn(x)
>>> assert torch.allclose(y, y2)

class pydrobert.torch.modules.PadMaskedSequence(batch_first=False, padding_value=0.0)[source]

Select masked elements of tensor, then scatter into right-padded sequences

Parameters:
  • batch_first (bool) – If True, the first dimension of x is the batch dimension and the second is the sequence dimension. If False, the first dimension is the sequence dimension and the second is the batch dimension.

  • padding_value (float) – The value to right-pad the remaining elements with along the sequence dimension.

Call Parameters:
  • x (torch.Tensor) – The input tensor. At least two dimensional.

  • mask (torch.Tensor) – A boolean tensor whose True values indicate that the associated element(s) of x should be included in the sequence. Broadcasts with the first two dimensions of x.

Returns:

  • x_ (torch.Tensor) – A tensor of the same shape as x such that, supposing i indexes the j-th True element of mask for batch index n (shown here for batch_first=False):

    x_[j, n] = x[i, n]
    

    with the remaining values of x_ being padding_value.

  • lens (torch.Tensor) – A vector of the length of the batch dimension which counts the number of elements of x stored in x_ per batch element.
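A cumulative sum over the mask gives each selected element its slot in the padded output. The sketch below assumes batch_first=True and a mask that broadcasts to the first two dimensions of x; pad_masked_sequence_sketch is a hypothetical name, not the module's implementation.

```python
import torch


def pad_masked_sequence_sketch(x: torch.Tensor, mask: torch.Tensor, padding_value=0):
    """Hypothetical sketch for batch_first=True: x is (N, T, *), mask broadcasts to (N, T)."""
    mask = mask.expand(x.shape[:2])
    N, T = mask.shape
    lens = mask.sum(1)  # number of kept elements per batch element
    x_ = x.new_full(x.shape, padding_value)
    # slot of each selected element in its padded row: 0, 1, 2, ... in order of appearance
    slot = mask.long().cumsum(1) - 1
    batch = torch.arange(N).unsqueeze(1).expand(N, T)
    x_[batch[mask], slot[mask]] = x[mask]
    return x_, lens


x = torch.arange(100).view(10, 10)
x_, lens = pad_masked_sequence_sketch(x, (x % 3) == 0, -1)
```

With the same inputs as the example below, this yields the same x_ and lens.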

Examples

>>> x = torch.arange(100).view(10, 10)
>>> mask = (x % 3) == 0
>>> pad_masked_sequence = PadMaskedSequence(True, -1)
>>> x_, lens = pad_masked_sequence(x, mask)
>>> x_
tensor([[ 0,  3,  6,  9, -1, -1, -1, -1, -1, -1],
        [12, 15, 18, -1, -1, -1, -1, -1, -1, -1],
        [21, 24, 27, -1, -1, -1, -1, -1, -1, -1],
        [30, 33, 36, 39, -1, -1, -1, -1, -1, -1],
        [42, 45, 48, -1, -1, -1, -1, -1, -1, -1],
        [51, 54, 57, -1, -1, -1, -1, -1, -1, -1],
        [60, 63, 66, 69, -1, -1, -1, -1, -1, -1],
        [72, 75, 78, -1, -1, -1, -1, -1, -1, -1],
        [81, 84, 87, -1, -1, -1, -1, -1, -1, -1],
        [90, 93, 96, 99, -1, -1, -1, -1, -1, -1]])
>>> lens
tensor([4, 3, 3, 4, 3, 3, 4, 3, 3, 4])
>>> x = (x * 2).unsqueeze(2) + torch.arange(2)
>>> x_, lens = pad_masked_sequence(x, mask)
>>> x_[:1]
tensor([[[ 0,  1],
        [ 6,  7],
        [12, 13],
        [18, 19],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1]]])

class pydrobert.torch.modules.PadVariable(mode='constant', value=0.0)[source]

Pad variable-length input by a variable amount on each side

This module attempts to replicate the behaviour of torch.nn.functional.pad() on a tensor containing variable sequence lengths with variable amounts of padding.

Parameters:
  • mode (Literal['constant', 'reflect', 'replicate']) – How to pad the sequences. 'constant': fill the padding region with the value specified by value. 'reflect': padded values are reflections around the endpoints. For example, the first right-padded value of the n-th sequence would be x[n, lens[n] - 2], the second x[n, lens[n] - 3], and so on. 'replicate': padding duplicates the endpoints of each sequence. For example, the left-padded values of the n-th sequence would all be x[n, 0]; the right-padded values would be x[n, lens[n] - 1].

  • value (float) – The value to pad with when mode == 'constant'.

Call Parameters:
  • x (torch.Tensor) – A tensor of shape (N, T, *) where N is the batch dimension and T is the sequence dimension.

  • lens (torch.Tensor) – A long tensor of shape (N,) specifying the sequence lengths. Only the values in the range x[n, :lens[n]] are considered part of the sequence of batch element n.

  • pad (torch.Tensor) – A long tensor of shape (2, N) specifying how many elements to add at the start (pad[0]) and end (pad[1]) of each sequence.

Returns:

padded (torch.Tensor) –

A tensor of shape (N, T', *) such that, for a given batch index n:

padded[n, :pad[0, n]] = left padding
padded[n, pad[0,n]:pad[0,n] + lens[n]] = x[n, :lens[n]]
padded[n, pad[0,n] + lens[n]:pad[0,n] + lens[n] + pad[1, n]] = right padding

Raises:
  • NotImplementedError – If any value in pad[:, n] equals or exceeds lens[n] when mode == 'reflect'

  • RuntimeError – If any element in lens is less than 1 when mode == 'replicate'
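All three padding modes can be expressed as index arithmetic on each sequence. This is a loop-based sketch for clarity rather than speed; pad_variable_sketch is hypothetical, and, in line with the Raises section, its 'reflect' branch assumes every padding amount is smaller than the corresponding length.

```python
import torch


def pad_variable_sketch(x, lens, pad, mode="constant", value=0):
    """Hypothetical loop-based sketch; x is (N, T, *), lens is (N,), pad is (2, N)."""
    N = x.shape[0]
    total = lens + pad[0] + pad[1]
    out = x.new_full((N, int(total.max())) + tuple(x.shape[2:]), value)
    for n in range(N):
        L, left, right = int(lens[n]), int(pad[0, n]), int(pad[1, n])
        if mode == "constant":
            out[n, left:left + L] = x[n, :L]
            continue
        # position of every output element relative to the start of the sequence
        rel = torch.arange(-left, L + right)
        if mode == "replicate":
            idx = rel.clamp(0, L - 1)  # repeat the endpoints
        elif mode == "reflect":
            idx = rel.abs()  # left side: -k -> k
            # right side: L - 1 + k -> L - 1 - k (mirror around the last element)
            idx = torch.where(idx > L - 1, 2 * (L - 1) - idx, idx)
        else:
            raise ValueError(f"unknown mode {mode!r}")
        out[n, :left + L + right] = x[n, idx]
    return out


x = torch.arange(10).view(2, 5)
y = pad_variable_sketch(x, torch.tensor([3, 4]), torch.arange(4).view(2, 2), "reflect")
print(y[1])  # tensor([6, 5, 6, 7, 8, 7, 6, 5])
```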

Examples

>>> x = torch.arange(10).view(2, 5)
>>> x
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
>>> lens = torch.tensor([3, 4])
>>> pad = torch.arange(4).view(2, 2)
>>> pad.t()  # [[0_left, 0_right], [1_left, 1_right]]
tensor([[0, 2],
        [1, 3]])
>>> y = PadVariable()(x, lens, pad)  # constant w/ value 0
>>> y[0, :3 + 0 + 2]
tensor([0, 1, 2, 0, 0])
>>> y[1, :4 + 1 + 3]
tensor([0, 5, 6, 7, 8, 0, 0, 0])
>>> y = PadVariable('reflect')(x, lens, pad)
>>> y[0, :3 + 0 + 2]
tensor([0, 1, 2, 1, 0])
>>> y[1, :4 + 1 + 3]
tensor([6, 5, 6, 7, 8, 7, 6, 5])
>>> y = PadVariable('replicate')(x, lens, pad)
>>> y[0, :3 + 0 + 2]
tensor([0, 1, 2, 2, 2])
>>> y[1, :4 + 1 + 3]
tensor([5, 5, 6, 7, 8, 8, 8, 8])

class pydrobert.torch.modules.PolyharmonicSpline(order, regularization_weight=0.0, full_matrix=True)[source]

Guess values at query points using a learned polyharmonic spline

A spline estimates a function f : points -> values from a fixed number of training points/knots and the values of f at those points. It does so by solving a system of linear equations such that the values at the knots match the given values (subject to some additional constraints depending on the spline).

This module is based on the interpolate_spline function from TensorFlow, which implements a polyharmonic spline. For technical details, consult the TF documentation.

Parameters:
  • order (int) – Order of the spline (> 0). 1 = linear. 2 = thin plate spline.

  • regularization_weight (float) – Weight placed on the regularization term. See TF for more info.

  • full_matrix (bool) – Whether to solve linear equations via a full concatenated matrix or a block decomposition. Setting to True better matches TF and appears to slightly improve numerical accuracy at the cost of twice the run time and more memory usage.

Call Parameters:
  • train_points (torch.Tensor) – A tensor of shape (N, T, I) representing the training points/knots for N different functions. N is the batch dimension, T is the number of training points, and I is the size of the vector input to f.

  • train_values (torch.Tensor) – A tensor of shape (N, T, O) of f evaluated on train_points. O is the size of the output vector of f.

  • query_points (torch.Tensor) – A tensor of shape (N, Q, I) representing the points you wish to have estimates for.

Returns:

query_values (torch.Tensor) – A tensor of shape (N, Q, O) consisting of the values estimated for query_points.

Raises:

RuntimeError – This module can raise a RuntimeError when no unique spline can be estimated. In general, the spline will require at least I+1 non-degenerate (linearly independent) points. See the Wikipedia entry on splines for more info.
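For readers unfamiliar with polyharmonic splines, here is a minimal unbatched sketch of the linear system involved. It assumes the usual radial basis phi(r) = r^k for odd order k and r^k log r for even k, an affine term, and a single full (not block-decomposed) system; polyharmonic_sketch and phi are hypothetical names, and the TF documentation remains the reference for the exact formulation the module follows.

```python
import torch


def phi(d: torch.Tensor, order: int) -> torch.Tensor:
    """Radial basis of a polyharmonic spline of the given order, applied to distances d."""
    if order % 2 == 1:
        return d.pow(order)
    # even orders use d^k log(d); clamping inside the log keeps d = 0 at exactly 0
    return d.pow(order) * torch.log(d.clamp_min(1e-10))


def polyharmonic_sketch(train_points, train_values, query_points, order, regularization_weight=0.0):
    """Unbatched sketch: train_points (T, I), train_values (T, O), query_points (Q, I)."""
    T, I = train_points.shape
    O = train_values.shape[1]
    opts = dict(dtype=train_points.dtype)
    A = phi(torch.cdist(train_points, train_points), order)
    A = A + regularization_weight * torch.eye(T, **opts)
    B = torch.cat([train_points, torch.ones(T, 1, **opts)], 1)  # affine part, (T, I + 1)
    lhs = torch.cat(
        [
            torch.cat([A, B], 1),
            torch.cat([B.t(), torch.zeros(I + 1, I + 1, **opts)], 1),
        ],
        0,
    )
    rhs = torch.cat([train_values, torch.zeros(I + 1, O, **opts)], 0)
    sol = torch.linalg.solve(lhs, rhs)  # errors out if the system is singular
    w, v = sol[:T], sol[T:]
    Aq = phi(torch.cdist(query_points, train_points), order)
    Bq = torch.cat([query_points, torch.ones(query_points.shape[0], 1, **opts)], 1)
    return Aq @ w + Bq @ v
```

With order=1 in one dimension this reduces to piecewise linear interpolation, so a linear function is reproduced exactly; with regularization_weight=0 the spline passes through every knot.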

class pydrobert.torch.modules.RandomShift(prop, mode='reflect', value=0.0)[source]

Pad to the left and right of each sequence by a random amount

This layer is intended for training models which are robust to small shifts in some variable-length sequence dimension (e.g. speech recognition). It pads each input sequence with some number of elements at its beginning and end. The number of elements is randomly chosen but bounded above by some proportion of the input length specified by the user.

The amount of padding is dictated by the parameter prop this layer is initialized with. A proportion is a non-negative float dictating the maximum ratio of the original sequence length which may be padded, exclusive. prop can be a pair left, right for separate ratios of padding at the beginning and end of a sequence, or just one float if the proportions are the same. For example, prop=0.5 of a sequence of length 10 could result in a sequence of length between 10 and 18 inclusive since each side of the sequence could be padded with 0-4 elements (0.5 * 10 = 5 is an exclusive bound).
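Sampling padding amounts under this exclusive proportional bound can be sketched as follows. The uniform distribution is an assumption (the text only fixes the bounds), and sample_shift is a hypothetical helper; the resulting (2, N) tensor has the same layout as the pad argument of PadVariable.

```python
from typing import Tuple, Union

import torch


def sample_shift(in_lens: torch.Tensor, prop: Union[float, Tuple[float, float]]):
    """Hypothetical sketch: draw per-sequence left/right padding amounts."""
    if isinstance(prop, (int, float)):
        prop = (prop, prop)
    lens = in_lens.to(torch.float64)
    # assumption: amounts are uniform over [0, prop * len); flooring keeps the bound exclusive
    left = (torch.rand_like(lens) * prop[0] * lens).floor().long()
    right = (torch.rand_like(lens) * prop[1] * lens).floor().long()
    pad = torch.stack([left, right])  # (2, N)
    return pad, in_lens + left + right


pad, out_lens = sample_shift(torch.tensor([10, 10, 10]), 0.5)
# each side gets 0-4 elements, so every out_lens value lies in [10, 18]
```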

Parameters:
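  • prop (Union[float, Tuple[float, float]]) – The maximum proportion (exclusive) of each sequence’s length to pad, either a single value for both sides or a pair (left, right).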

  • mode (Literal['reflect', 'constant', 'replicate']) – The method with which to pad the input sequence.

  • value (float) – The constant with which to pad the sequence if mode is set to 'constant'.

Call Parameters:
  • input (torch.Tensor) – A tensor of shape (N, T, *) where N is the batch dimension and T is the sequence dimension.

  • in_lens (torch.Tensor) – A long tensor of shape (N,) containing the lengths of the sequences in input. For batch element n, only the values input[n, :in_lens[n]] are valid.

Returns:

  • out (torch.Tensor) – A tensor of the same type as input of shape (N, T', *).

  • out_lens (torch.Tensor) – A tensor of shape (N,) containing the lengths of the sequences in out.

Raises:

NotImplementedError – On initialization if mode is 'reflect' and a value in prop exceeds 1.0. Reflection currently requires that the amount of padding not exceed the original sequence length.

Notes

Padding is only applied if this layer is in training mode. In evaluation mode, out, out_lens = input, in_lens.

See also

pydrobert.torch.util.pad_variable

For more details on the different types of padding. Note the default mode is different between this and the function.

class pydrobert.torch.modules.SliceSpectData(policy='fixed', window_type='symmetric', valid_only=True, lobe_size=0)[source]

Determine slices of feature chunks according to a variety of policies

This module helps to chunk pydrobert.data.SpectDataLoader data (or other similarly-structured tensors) into smaller units by returning slices of that data. The input to this module and the means of determining those slices varies according to the policy specified (see the notes below for more details). The return values can then be used to slice the data.

Parameters:
  • policy (Literal['fixed', 'ali', 'ref']) – Specifies how to slice the data. If 'fixed', extract windows of fixed length at fixed intervals. If 'ali', use changes in frame-level alignments to determine segment boundaries and slice along those. If 'ref', use token segmentations as slices. See below for more info.

  • window_type (Literal['symmetric', 'causal', 'future']) – How the window will be constructed around the “middle unit” in the policy. In general 'symmetric' adds lobes to either side of the middle unit, 'causal' to the left (towards 0), 'future' to the right.

  • valid_only (bool) – What to do when a would-be slice passes over the length of the data. If True, any such slices are thrown out. If False, do something dictated by the policy which may preserve the invalid boundaries.

  • lobe_size (int) – Specifies the size of a lobe in the slice’s window. When the policy is 'fixed' or 'ref', the unit of lobe_size is a single frame. When policy is 'ali', the unit of lobe_size is a whole segment.

Call Parameters:
  • input (torch.Tensor) – A tensor of shape (N, T, *), where N is the batch dimension and T is the (maximum) sequence dimension. When policy is 'fixed', input should be the batch-first feature tensor feats from a pydrobert.data.SpectDataLoader. When 'ali', input should be the batch-first alis tensor. When 'ref', input should be the batch-first refs tensor with segment info.

  • in_lens (torch.Tensor, optional) – A long tensor of shape (N,) specifying the lengths of sequences in input. For the n-th batch element, only the elements input[n, :in_lens[n]] are considered. If unspecified, all sequences are assumed to be of length T. For the 'fixed' and 'ali' policies, this is the feat_lens tensor from a pydrobert.data.SpectDataLoader. When 'ref', it is the ref_lens tensor.

  • other_lens (torch.Tensor, optional) – An additional long tensor of shape (N,) specifying some other lengths, depending on the policy. It is currently only used in the 'ref' policy and takes the value feat_lens from a pydrobert.data.SpectDataLoader.

Returns:

  • slices (torch.Tensor) – A long tensor of shape (M, 2) storing the slices of all batch elements. M is the total number of slices. slices[m, 0] is the m-th slice’s start index (inclusive), while slices[m, 1] is the m-th slice’s end index (exclusive).

  • sources (torch.Tensor) – A long tensor of shape (M,) where sources[m] is the batch index of the m-th slice.

See also

ChunkBySlices

Can be used to chunk input using the returned slices (after reordering that input with sources)

Notes

If policy is 'fixed', slices are extracted at fixed intervals (lobe_size + 1) along the length of the data. input is assumed to be the data in question, e.g. the feats tensor in a pydrobert.data.SpectDataLoader, in batch-first order (although any tensor which matches its first two dimensions will do). in_lens may be used to specify the actual lengths of the input sequences if they were padded to fit in the same batch. If window_type is 'symmetric', windows are of size 1 + 2 * lobe_size; otherwise, windows are of size 1 + lobe_size. When valid_only is True, slices start at index 0 and as many slices as fit fully within the sequences are returned. When valid_only is False, slices are kept if their “middle” index lies before the end of the sequence with lobes clamped within the sequence. The “middle” index for the symmetric window is at slice[0] + window_size // 2; for the causal window it’s the last index of the window, slice[1] - 1; for the future window it’s the first, slice[0]. When valid_only is False, the initial slice’s offsets differ as well: for the symmetric case, it’s (lobe_size + 1) // 2 - window_size // 2; for the causal case, it’s -lobe_size; and for the future case, it’s still 0. As an example, given a sequence of length 8, the following are the slices under different configurations of the 'fixed' policy with a lobe_size of 2:

[[0, 5], [3, 8]]          # symmetric, valid_only
[[0, 3], [3, 6]]          # not symmetric, valid_only
[[-1, 4], [2, 6], [5, 9]] # symmetric, not valid_only
[[-2, 1], [1, 4], [4, 7]] # causal, not valid_only
[[0, 3], [3, 6], [6, 9]]  # future, not valid_only
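The valid_only=True case of the 'fixed' policy reduces to a strided range over window starts. A pure-Python sketch, where fixed_slices is hypothetical and lens stands in for in_lens, reproduces the two valid_only rows of the example above:

```python
from typing import List, Sequence, Tuple


def fixed_slices(
    lens: Sequence[int], lobe_size: int, window_type: str = "symmetric"
) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical sketch of the 'fixed' policy with valid_only=True."""
    window = 1 + (2 if window_type == "symmetric" else 1) * lobe_size
    shift = lobe_size + 1
    slices, sources = [], []
    for n, length in enumerate(lens):
        # keep every window that fits entirely inside the sequence
        for start in range(0, length - window + 1, shift):
            slices.append([start, start + window])
            sources.append(n)
    return slices, sources


print(fixed_slices([8], 2))  # ([[0, 5], [3, 8]], [0, 0])
```

The sources list plays the role of the sources return value: it records which batch element each slice came from.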

If policy is 'ali', slices are extracted from the partition of the sequence induced by per-frame alignments. input is assumed to be the alignments in question, i.e. the batch-first alis tensor in a pydrobert.data.SpectDataLoader. in_lens may be used to specify the actual lengths of the input sequences if they were padded to fit in the same batch. The segments are induced by alis as follows: a segment starts at index t whenever t == 0 or alis[n, t - 1] != alis[n, t]. Slice m is built from segment m by starting with the segment boundaries and possibly extending the start to the left (towards 0) or the end to the right (away from 0). If window_type is 'symmetric' or 'causal', the m-th segment’s start is set to the start of the (m - lobe_size)-th. If window_type is 'symmetric' or 'future', the segment’s end is set to the end of the (m + lobe_size)-th. Since there is a finite number of segments, sometimes either (m - lobe_size) or (m + lobe_size) will not exist. In that case and if valid_only is True, the slice is thrown out. If valid_only is False, the furthest segment from m in the same direction which also exists will be used. For example, with input[n] = [1] * 4 + [2] * 3 + [1] + [5] * 2, the following are the slices under different configurations of the 'ali' policy with a lobe_size of 1:

[[0, 8], [4, 10]]                   # symmetric, valid_only
[[0, 7], [4, 8], [7, 10]]           # not symmetric, valid_only
[[0, 7], [0, 8], [4, 10], [7, 10]]  # symmetric, not valid_only
[[0, 4], [0, 7], [4, 8], [7, 10]]   # causal, not valid_only
[[0, 7], [4, 8], [7, 10], [8, 10]]  # future, not valid_only
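The 'ali' rules can likewise be sketched for a single unpadded alignment sequence (no in_lens). ali_slices is a hypothetical helper; run on the example sequence with lobe_size=1, it reproduces each row listed above.

```python
from typing import List, Sequence


def ali_slices(
    ali: Sequence[int], lobe_size: int, window_type: str = "symmetric", valid_only: bool = True
) -> List[List[int]]:
    """Hypothetical single-sequence sketch of the 'ali' policy."""
    # a segment starts wherever the alignment value changes
    starts = [t for t in range(len(ali)) if t == 0 or ali[t - 1] != ali[t]]
    ends = starts[1:] + [len(ali)]
    M = len(starts)
    left = lobe_size if window_type in ("symmetric", "causal") else 0
    right = lobe_size if window_type in ("symmetric", "future") else 0
    slices = []
    for m in range(M):
        lo, hi = m - left, m + right
        if valid_only and (lo < 0 or hi >= M):
            continue  # a required neighbouring segment does not exist
        # otherwise fall back to the furthest existing segment in each direction
        slices.append([starts[max(lo, 0)], ends[min(hi, M - 1)]])
    return slices


ali = [1] * 4 + [2] * 3 + [1] + [5] * 2
print(ali_slices(ali, 1))  # [[0, 8], [4, 10]]
```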

Finally, if policy is 'ref', slices are extracted from a transcription’s segment boundaries. input is assumed to be the token sequences in question, i.e. the batch-first refs tensor in a pydrobert.data.SpectDataLoader. input should be 3-dimensional with the third dimension of size 3: input[..., 0] the token sequence (ignored), input[..., 1] the segment starts (in frames), and input[..., 2] their ends. in_lens may be specified to give the length of the token sequences (i.e. ref_lens). In addition, the lengths of the sequences input is segmenting (in frames) may be passed via other_lens (i.e. feat_lens). The slices are built off the available segments. If window_type is 'causal', lobe_size is subtracted from all segment starts; if 'future', lobe_size is added to all ends; if 'symmetric', both are applied. A segment may be discarded in a few ways: if either the start or end frame is less than 0 (indicating missing segment information); if in_lens is set and the token segment is indexed past that length (input[n, t] for any t >= in_lens[n]); if the starting frame of a segment (after padding) matches or exceeds the ending frame after padding (no empty or invalid slices); if valid_only is True and the padded start begins before index 0 or the padded end ends after other_lens; and if valid_only is False and the padded start begins after other_lens or the padded end ends at or before 0. For example, with input[n] = [[1, 0, 0], [2, 2, 3], [3, -1, 1], [4, 0, -1], [5, 3, 5], [6, 4, 4]], in_lens[n] = 5, other_lens[n] = 6, and lobe_size of 2, the following are the slices under different configurations of the 'ref' policy:

[[0, 5]]                  # symmetric, valid_only
[[0, 3], [1, 5]]          # causal, valid_only
[[0, 2], [2, 5]]          # future, valid_only
[[-2, 2], [0, 5], [1, 7]] # symmetric, not valid_only
[[0, 3], [1, 5]]          # causal, not valid_only
[[0, 2], [2, 5], [3, 7]]  # future, not valid_only
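A single-sequence sketch of the 'ref' discarding rules follows, with in_len and other_len standing in for in_lens[n] and other_lens[n]. ref_slices is hypothetical, and treating a padded start equal to other_lens as lying "after" it is an assumption; with the example above it reproduces all six listed configurations.

```python
from typing import List, Sequence


def ref_slices(
    ref: Sequence[Sequence[int]],
    in_len: int,
    other_len: int,
    lobe_size: int,
    window_type: str = "symmetric",
    valid_only: bool = True,
) -> List[List[int]]:
    """Hypothetical single-sequence sketch of the 'ref' policy; ref[t] = (token, start, end)."""
    slices = []
    for t, (_, start, end) in enumerate(ref):
        if t >= in_len or start < 0 or end < 0:
            continue  # past the token sequence, or missing segment info
        if window_type in ("symmetric", "causal"):
            start -= lobe_size
        if window_type in ("symmetric", "future"):
            end += lobe_size
        if start >= end:
            continue  # empty or inverted slice
        if valid_only and (start < 0 or end > other_len):
            continue  # pokes outside the frame sequence
        if not valid_only and (start >= other_len or end <= 0):
            continue  # lies entirely outside the frame sequence (assumed boundary handling)
        slices.append([start, end])
    return slices


ref = [[1, 0, 0], [2, 2, 3], [3, -1, 1], [4, 0, -1], [5, 3, 5], [6, 4, 4]]
print(ref_slices(ref, 5, 6, 2))  # [[0, 5]]
```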

class pydrobert.torch.modules.SparseImageWarp(indexing='hw', field_interpolation_order=2, field_regularization_weight=0.0, field_full_matrix=True, pinned_boundary_points=0, dense_interpolation_mode='bilinear', dense_padding_mode='border', include_flow=True)[source]

Warp an image by specifying mappings between few control points

This module mirrors the behaviour of TensorFlow’s sparse_image_warp, except image is in NCHW order instead of NHWC order. For more details, please consult their documentation.

Parameters:
  • indexing (Literal['hw', 'wh']) – If indexing is "hw", source_points[n, m, 0] and dest_points[n, m, 0] index the height dimension in image and warped, respectively, and source_points[n, m, 1] and dest_points[n, m, 1] the width dimension. If indexing is "wh", the width dimension is the 0-index and height the 1.

  • field_interpolation_order (int) – The order of the polyharmonic spline used to interpolate the rest of the points from the control points. See polyharmonic_spline() for more info.

  • field_regularization_weight (float) – The regularization weight of the polyharmonic spline used to interpolate the rest of the points from the control points. See polyharmonic_spline() for more info.

  • field_full_matrix (bool) – Determines the method of calculating the polyharmonic spline used to interpolate the rest of the points from the control points. See polyharmonic_spline() for more info.

  • pinned_boundary_points (int) – Dictates whether and how many points along the boundary of image are mapped identically to points in warped. This keeps the boundary of the image from being pulled into the interior of warped. When 0, no points are added. When 1, four points are added, one in each corner of the image. When k > 1, one point in each corner of the image is added, then k - 1 equidistant points along each of the four edges, totaling 4 * k points.

  • dense_interpolation_mode (Literal['bilinear', 'nearest']) – The method with which partial indices in the derived mapping are interpolated. See dense_image_warp() for more info.

  • dense_padding_mode (Literal['border', 'zero', 'reflection']) – What to do when points in the derived mapping fall outside of the boundaries. See dense_image_warp() for more info.

  • include_flow (bool) – If True, include the flow field flow interpolated from the control points in the return value.
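One way to place the 4 * k pinned boundary points is on an evenly spaced (k + 1) by (k + 1) lattice spanning the image, keeping only the lattice points on its edges. This spacing is an assumption about what “equidistant” means for pinned_boundary_points (it appears consistent with TF’s sparse_image_warp, but check its source before relying on it), and boundary_points is a hypothetical helper.

```python
from typing import List, Tuple


def boundary_points(height: int, width: int, k: int) -> List[Tuple[float, float]]:
    """Hypothetical sketch: (h, w) coordinates of 4 * k points pinned along the image boundary."""
    if k == 0:
        return []
    # a (k + 1) x (k + 1) lattice spanning the image; keep only the points on its edges
    hs = [i * (height - 1) / k for i in range(k + 1)]
    ws = [j * (width - 1) / k for j in range(k + 1)]
    return [
        (h, w)
        for i, h in enumerate(hs)
        for j, w in enumerate(ws)
        if i in (0, k) or j in (0, k)
    ]


print(boundary_points(5, 9, 1))  # [(0.0, 0.0), (0.0, 8.0), (4.0, 0.0), (4.0, 8.0)]
```

For k = 1 only the four corners survive; each increment of k adds one more point per edge.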

Call Parameters:
  • image (torch.Tensor) – A source image of shape (N, C, H, W) where N is the batch dimension, C the channel dimension, H the image height, and W the image width.

  • source_points, dest_points (torch.Tensor) – Tensors of shape (N, M, 2), where M is the number of control points.

Returns:

warped (torch.Tensor or tuple of torch.Tensor) – If include_flow is False, warped is a warped image of shape (N, C, H, W). The point source_points[n, m, :] in image will be mapped to dest_points[n, m, :] in warped. If include_flow is True, warped is a pair of tensors (warped_, flow) where warped_ has the same definition as warped when include_flow is False, and flow is a tensor of shape (N, H, W, 2). flow[n, h, w, :] is the flow for coordinates h, w in whatever order was specified by indexing. See DenseImageWarp for more details about flow.

Warning

When this module is scripted, its return type will be typing.Any. This reflects the fact that either warped is returned on its own (a tensor) or both warped_ and flow (a tuple). Use torch.jit.isinstance() for type refinement in subsequent scripting. Tracing will infer the correct type.

class pydrobert.torch.modules.SpecAugment(max_time_warp=80.0, max_freq_warp=0.0, max_time_mask=100, max_freq_mask=27, max_time_mask_proportion=0.04, num_time_mask=20, num_time_mask_proportion=0.04, num_freq_mask=2, interpolation_order=1)[source]

Perform warping/masking of time/frequency dimensions of filter bank features

SpecAugment [park2019] (and later [park2020]) is a series of data transformations for training data augmentation of time-frequency features such as Mel-scaled triangular filter bank coefficients.

Default parameter values are from [park2020].

Parameters:
  • max_time_warp (float) – A non-negative float specifying the maximum number of frames the chosen random frame can be shifted left or right by in step 1. Setting to 0 disables step 1.

  • max_freq_warp (float) – A non-negative float specifying the maximum number of coefficients the chosen random frequency coefficient index will be shifted up or down by in step 2. Setting to 0 disables step 2.

  • max_time_mask (int) – A non-negative integer specifying an absolute upper bound on the number of sequential frames in time that can be masked out by a single mask. The minimum of this upper bound and that from max_time_mask_proportion specifies the actual maximum. Setting this, max_time_mask_proportion, num_time_mask, or num_time_mask_proportion to 0 disables step 3.

  • max_freq_mask (int) – A non-negative integer specifying the maximum number of sequential coefficients in frequency that can be masked out by a single mask. Setting this or num_freq_mask to 0 disables step 4.

  • max_time_mask_proportion (float) – A value in the range \([0, 1]\) specifying a relative upper bound on the number of sequential frames in time that can be masked out by a single mask. For batch element n, the upper bound is int(max_time_mask_proportion * length[n]). The minimum of this upper bound and that from max_time_mask specifies the actual maximum. Setting this, max_time_mask, num_time_mask, or num_time_mask_proportion to 0 disables step 3.

  • num_time_mask (int) – A non-negative integer specifying an absolute upper bound on the number of random masks in time per batch element to create. Setting this, num_time_mask_proportion, max_time_mask, or max_time_mask_proportion to 0 disables step 3. Drawn i.i.d. and may overlap.

  • num_time_mask_proportion (float) – A value in the range \([0, 1]\) specifying a relative upper bound on the number of time masks per element in the batch to create. For batch element n, the upper bound is int(num_time_mask_proportion * length[n]). The minimum of this upper bound and that from num_time_mask specifies the actual maximum. Setting this, num_time_mask, max_time_mask, or max_time_mask_proportion to 0 disables step 3. Drawn i.i.d. and may overlap.

  • num_freq_mask (int) – The total number of random masks in frequency per batch element to create. Setting this or max_freq_mask to 0 disables step 4. Drawn i.i.d. and may overlap.

  • interpolation_order (int) – Controls order of interpolation of warping. 1 = linear (default for [park2020]). 2 = thin plate (default for [park2019]). Higher orders are possible at increased computational cost.

Call Parameters:
  • feats (torch.Tensor) – A tensor of shape (N, T, F) where N is the batch dimension, T is the time (frames) dimension, and F is the frequency (coefficients per frame) dimension. The original feature tensor.

  • lengths (Optional[torch.Tensor], optional) – A long tensor of shape (N,) specifying the actual number of frames before right-padding per batch element. That is, for batch index n, only feats[n, :lengths[n]] are valid.

Returns:

new_feats (torch.Tensor) – The warped feats of shape (N, T, F) with some or all of the following operations performed in order independently per batch index:

  1. Choose a random frame along the time dimension. Warp feats such that feats[n, 0] and feats[n, lengths[n] - 1] are fixed, but that random frame gets mapped to a random new location a few frames to the left or right.

  2. Do the same for the frequency dimension.

  3. Mask out (zero) one or more random-width ranges of frames in a random location along the time dimension.

  4. Do the same for the frequency dimension.

The original SpecAugment implementation only performs steps 1, 3, and 4; step 2 is a trivial extension.
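Step 3 can be sketched for a single batch element as follows. The width bound and mask count follow the parameter descriptions above and use the module’s default values; drawing exactly the maximum number of masks and drawing widths uniformly from [0, max_width] are assumptions, and time_mask_sketch is a hypothetical helper rather than the module’s code.

```python
import torch


def time_mask_sketch(
    feats,
    length,
    max_time_mask=100,
    max_time_mask_proportion=0.04,
    num_time_mask=20,
    num_time_mask_proportion=0.04,
):
    """Hypothetical sketch of step 3 for one batch element; feats is (T, F)."""
    max_width = min(max_time_mask, int(max_time_mask_proportion * length))
    num_masks = min(num_time_mask, int(num_time_mask_proportion * length))
    out = feats.clone()
    if max_width == 0 or num_masks == 0:
        return out  # step 3 is disabled for this element
    for _ in range(num_masks):  # assumption: always draw the maximum number of masks
        width = int(torch.randint(0, max_width + 1, ()))
        start = int(torch.randint(0, length - width + 1, ()))
        out[start:start + width] = 0  # masks are i.i.d. and may overlap
    return out


masked = time_mask_sketch(torch.ones(1000, 80), 1000)  # at most 20 masks of width <= 40
```

Because masks only start within the first length frames and never extend past them, right-padding beyond length is left untouched.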

Notes

SpecAugment is only applied in training mode; in eval mode, new_feats == feats.

There are a few differences between this implementation of warping and those you might find online or described in the source paper [park2019]. These require some knowledge of what’s happening under the hood and are unlikely to change the way you use this module. We assume we’re warping in time, though the following applies to frequency warping as well.

First, the warp parameters are real- rather than integer-valued. You can set max_time_warp or max_freq_warp to 0.5 if you’d like. The shift value drawn between [0, max_time_warp] is also real-valued. Since the underlying warp relies on interpolation between partial indices anyway (the vast majority of tensor values will be the result of interpolation), there is no preference for integer-valued parameters from a computational standpoint. Further, real-valued warp parameters allow for a virtually infinite number of warps instead of just a few.

Finally, time warping is implemented by determining the transformation in one dimension (time) and broadcasting it across the other (frequency), rather than performing a two-dimensional warp. This is not in line with [park2019], but is with [park2020]. I have confirmed with the first author that the slight warping of frequency that occurred due to the 2D warp was unintentional.

class pydrobert.torch.modules.Warp1DGrid(max_length=None, interpolation_order=1)[source]

Interpolate grid values for a dimension of a grid_sample

This module determines a grid along a single dimension of a signal, image, volume, whatever.

Parameters:
  • max_length (Optional[int]) – A maximum length to which the grid will be padded. If unspecified, it will be taken to be lengths.max().ceil(). If grid is being plugged into grid_sample(), ensure max_length matches the size of the dimension of the image being warped.

  • interpolation_order (int) – The degree of the spline used to interpolate the grid.

Call Parameters:
  • src (torch.Tensor) – A tensor of shape (N,) containing the source points.

  • flow (torch.Tensor) – A tensor of shape (N,) containing the corresponding flow fields for src.

  • lengths (torch.Tensor) – A tensor of shape (N,) specifying the number of valid indices along the dimension in question.

Returns:

grid (torch.Tensor) – A tensor of shape (N, max_length) which provides coordinates for one dimension of the grid passed to torch.nn.functional.grid_sample().

Notes

The return value grid assumes align_corners has been set to False in grid_sample().

The values in grid depend on the value of max_length. grid does not contain absolute pixel indices, instead mapping the range [0, max_length - 1] to the real values in [-1, 1]. Therefore, unless max_length is set to some fixed value, the n-th batch element grid[n] can differ according to the remaining values in lengths. However, the n-th batched image passed to grid_sample() should still be warped in a way (roughly) agnostic to surrounding batched images.
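To see how normalized grid coordinates relate to pixel indices, the sketch below converts fractional indices to grid_sample() coordinates under align_corners=False and samples a 1D signal stored as a one-row image. index_to_grid is a hypothetical helper, and this shows only the coordinate convention, not how Warp1DGrid fits its spline.

```python
import torch
import torch.nn.functional as F


def index_to_grid(idx: torch.Tensor, max_length: int) -> torch.Tensor:
    """Map (fractional) pixel indices to grid_sample coordinates, assuming align_corners=False."""
    # with align_corners=False, the centre of pixel i sits at (2 * i + 1) / max_length - 1
    return (2 * idx + 1) / max_length - 1


L = 6
signal = torch.arange(L, dtype=torch.float).view(1, 1, 1, L)  # (N, C, H=1, W=L)
src = torch.arange(L, dtype=torch.float) - 0.5  # sample half a pixel to the left
x = index_to_grid(src, L)
# grid is (N, H_out, W_out, 2) and its last dimension is ordered (x, y) = (width, height)
grid = torch.stack([x, torch.zeros_like(x)], dim=-1).view(1, 1, L, 2)
out = F.grid_sample(signal, grid, mode="bilinear", padding_mode="border", align_corners=False)
# out[0, 0, 0] is approximately [0.0, 0.5, 1.5, 2.5, 3.5, 4.5]; index -0.5 clamps to the border
```

Because the conversion divides by max_length, the same pixel index produces different grid values for different max_length, which is the dependence described in the note above.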

Language Models

class pydrobert.torch.modules.ExtractableSequentialLanguageModel(vocab_size)[source]

A SequentialLanguageModel whose prev values can be reordered on the batch idx

SequentialLanguageModel calls are on batched histories of paths hist. A SequentialLanguageModel which is also an ExtractableSequentialLanguageModel promises that, were we to rearrange and/or choose only some of those batch elements in hist to continue computations with, we can call the model’s extract_by_src() method to rearrange/extract the relevant values in prev or next_ in the same way.

class pydrobert.torch.modules.MixableSequentialLanguageModel(vocab_size)[source]

An ExtractableSequentialLanguageModel whose prev values can be mixed

In addition to the functionality of ExtractableSequentialLanguageModel, a MixableSequentialLanguageModel can also account for transformations from pairs of histories hist_a and hist_b into one new_hist such that each path in the latter is either from hist_a or hist_b. mix_by_mask() accomplishes this for the dictionaries prev and in_next.
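A sketch of the kind of mixing described above, assuming a boolean mask of shape (N,) that is True where a path comes from hist_a and False where it comes from hist_b, and assuming the batch is dim 0 of every state tensor. mix_state_by_mask is a hypothetical name; the actual mix_by_mask() is implemented per model and may use different conventions.

```python
from typing import Dict

import torch


def mix_state_by_mask(
    state_a: Dict[str, torch.Tensor],
    state_b: Dict[str, torch.Tensor],
    mask: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: take batch element n from state_a if mask[n], else state_b."""
    mixed = {}
    for key, a in state_a.items():
        # reshape mask (N,) to (N, 1, ..., 1) so it broadcasts over trailing dims
        m = mask.view(-1, *([1] * (a.dim() - 1)))
        mixed[key] = torch.where(m, a, state_b[key])
    return mixed
```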

class pydrobert.torch.modules.SequentialLanguageModel(vocab_size)[source]

A language model whose sequence probability is built sequentially

A language model provides the (log-)probability of a sequence of tokens. A sequential language model assumes that the probability distribution can be factored into a product of probabilities of the current token given the prior sequence, i.e. for token sequence \(\{w_s\}\)

\[P(w) = \prod_{s=1}^S P(w_s | w_{s - 1}, w_{s - 2}, \ldots, w_1)\]

This definition includes statistical language models, such as n-grams, where the probability of the current token is based only on a fixed-length history, as well as recurrent neural language models [mikolov2010].

Parameters:

vocab_size (int) – The vocabulary size. Controls the size of the final output dimension, as well as what values of hist are considered in-vocabulary

Call Parameters:
  • hist (torch.Tensor) – A long tensor of shape (S, N) where S is the sequence dimension and N is the batch dimension. hist[:, n] is the n-th token prefix \((w^{(n)}_0, w^{(n)}_1, \ldots, w^{(n)}_{S-1})\).

  • prev (Dict[str, torch.Tensor], optional) – A dictionary of tensors which represents some additional state information which can be used in the computation. It may contain static input (e.g. a tensor of encoder output in neural machine translation) and/or dynamic input from prior calls to the LM (e.g. the previous hidden state in an RNN-based language model).

  • idx (Optional[Union[int, torch.Tensor]], optional) – If specified, either a single integer or a long tensor of shape (N,) giving the indices of the tokens over which to return distributions. See the return value below.

Returns:

log_probs (torch.Tensor or tuple of torch.Tensor) – The return value changes depending on whether idx was specified.

If idx was not specified, the distributions over the next token over all prefixes in hist are returned. log_probs is a tensor of shape (S + 1, N, vocab_size) where each log_probs[s, n, v] equals \(\log P(w^{(n)}_{s} = v | w^{(n)}_{s - 1}, \ldots)\). That is, each distribution over types conditioned on each prefix of tokens (:0, :1, :2, etc.) is returned.

If idx was specified, only the distributions over the tokens at those indices are returned. log_probs is a pair of tensors log_probs_idx, next_. log_probs_idx is of shape (N, vocab_size) and log_probs_idx[n, v] equals \(\log P(w^{(n)}_{idx[n]} = v | w^{(n)}_{idx[n]-1}, \ldots)\). That is, the distributions over the next type conditioned on token prefixes up to and excluding s = idx. next_ is a dictionary of tensors representing the updated state of the language model after computing these log probabilities, assuming prev represented the state at idx - 1.
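To make the idx=None return shape concrete, consider a toy bigram model: a statistical language model whose history is truncated to one token. Row 0 of the output is P(w_0) and row s is P(w_s | w_{s-1}) = P(w_s | hist[s - 1]), giving S + 1 rows for a length-S prefix. bigram_full_log_probs is a hypothetical helper written for this sketch, not a pydrobert API, and it ignores prev and next_ entirely.

```python
import torch


def bigram_full_log_probs(
    hist: torch.Tensor, init_log_probs: torch.Tensor, bigram_log_probs: torch.Tensor
) -> torch.Tensor:
    """Toy bigram LM mimicking the idx=None return shape (hypothetical helper).

    hist: long tensor (S, N); init_log_probs: (V,) holding log P(w_0);
    bigram_log_probs: (V, V) whose row u holds log P(w_s | w_{s-1} = u).
    Returns a tensor of shape (S + 1, N, V).
    """
    S, N = hist.shape
    V = init_log_probs.numel()
    # s = 0: nothing to condition on yet
    first = init_log_probs.expand(1, N, V)
    # s >= 1: condition on the previous token hist[s - 1]
    rest = bigram_log_probs[hist]  # (S, N, V)
    return torch.cat([first, rest], dim=0)
```

With an integer-tensor idx, the analogue of log_probs_idx is full[idx, torch.arange(N)], a tensor of shape (N, V).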

Notes

When this module is scripted, its return type will be typing.Any. This reflects the fact that either log_probs is returned on its own (a tensor) or both log_probs_idx and next_ are returned (a tuple). Use torch.jit.isinstance() for type refinement in subsequent scripting. Tracing will infer the correct type. Alternatively, one can use the methods update_input(), calc_idx_log_probs(), and calc_full_log_probs() to avoid ambiguity in the return type altogether.

This module has changed considerably since version 0.3.0. The primary changes are a) replacing the boolean switch full with idx; b) adding the prev argument for shared computations; c) removing the eos, sos, and oov attributes; and d) replacing the more general signature of hist, (S, *), with (S, N). For a), idx is strictly more powerful than full: full=True is replicated by setting idx=None, and full=False by setting idx=-1. This added functionality is intended to facilitate CTC decoding, where prefixes stored in hist may be of different lengths. b) generalizes LMs by allowing additional input while also speeding up iterative computations. eos and sos were removed because they did not generalize well. oov was removed because the user probably has to handle OOVs on her own when computing the loss.

See also

Language Modelling and Decoding

For a tutorial on how to build and use a language model.

class pydrobert.torch.modules.ExtractableShallowFusionLanguageModel(first, second, beta=0, first_prefix='first.', second_prefix='second.')[source]

ShallowFusionLanguageModel which is also an ExtractableSequentialLanguageModel

Both first and second must be ExtractableSequentialLanguageModel instances.

See also

ShallowFusionLanguageModel

For a description of shallow fusion and its parameters. ShallowFusionLanguageModel does not require first and second to be extractable, but it is not extractable itself.

MixableShallowFusionLanguageModel

A mixable subclass of this class. Applicable only if first and second are both MixableSequentialLanguageModel instances.

class pydrobert.torch.modules.LookupLanguageModel(vocab_size, sos, prob_dicts=None, destructive=False, logger=None, *, prob_list=None)[source]

Construct a backoff n-gram model from a fixed lookup table

An instance of this model will search for a stored log-probability of the current token given a fixed-length history in a lookup table. If it can’t find it, it backs off to a shorter length history and incurs a penalty:

\[\begin{split}Pr(w_t|w_{t-1},\ldots,w_{t-(N-1)}) = \begin{cases} Entry(w_{t-(N-1)}, w_{t-(N-1)+1}, \ldots, w_t) & \text{if } Entry(w_{t-(N-1)}, \ldots) > 0 \\ Backoff(w_{t-(N-1)}, \ldots, w_{t-1}) Pr(w_t|w_{t-1},\ldots,w_{t-(N-1)+1}) & \text{else} \end{cases}\end{split}\]

Missing entries are assumed to have value 0 and missing backoff penalties are assumed to have value 1.

Parameters:
  • sos (int) – The start-of-sequence token. Any prefix with fewer tokens than the maximum n-gram order minus 1 will be left-padded up to that length with this token.

  • prob_dicts (Optional[List[Dict[Union[signedinteger, Tuple[signedinteger, ...]], floating]]]) – A list of dictionaries whose entry at index i corresponds to a table of (i+1)-gram log-probabilities. Keys must all be ids, not strings. Unigram keys are just ids; for n > 1, keys are tuples of ids with the latest word last. Values in the highest-order n-gram dictionary (last in prob_dicts) are the log-probabilities of the keys. Values in lower-order dictionaries are pairs of log-probability and log-backoff penalty. If prob_dicts is not specified, a unigram model with a uniform prior will be built.

  • destructive (bool) – If True, allows initialization to modify prob_dicts directly instead of making a fresh copy. Doing so can help reduce memory pressure.

  • logger (Optional[Logger]) – If specified, this logger will be used to report on the progress initializing this module.
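The backoff rule above can be read directly off the prob_dicts layout. The function below is a minimal sketch in log space (so a missing backoff penalty of 1 becomes 0, and a missing unigram becomes negative infinity); backoff_log_prob is a hypothetical helper that walks plain dictionaries, not the module's reverse-trie implementation.

```python
import math


def backoff_log_prob(prob_dicts, history, token):
    """Log P(token | history) from prob_dicts-style tables (hypothetical helper).

    prob_dicts[i] holds (i+1)-gram entries. Unigram keys are bare ids, longer
    keys are tuples with the latest token last. The highest order maps keys to
    log-probs; lower orders map keys to (log_prob, log_backoff) pairs.
    """
    order = len(prob_dicts)
    # only the last (order - 1) tokens of the history can matter
    history = tuple(history)[-(order - 1):] if order > 1 else ()
    log_penalty = 0.0
    while True:
        n = len(history) + 1
        key = history + (token,) if history else token
        entry = prob_dicts[n - 1].get(key)
        if entry is not None:
            log_prob = entry if n == order else entry[0]
            return log_penalty + log_prob
        if not history:
            return -math.inf  # missing unigram: probability 0
        # back off: multiply by Backoff(history), i.e. add its log-penalty
        ctx_key = history if len(history) > 1 else history[0]
        ctx_entry = prob_dicts[n - 2].get(ctx_key)
        if ctx_entry is not None:
            log_penalty += ctx_entry[1]  # a missing penalty is log(1) = 0
        history = history[1:]  # drop the oldest token
```

For a bigram table where P(1 | 0) = 0.9 is stored but (0, 0) is not, P(0 | 0) falls back to Backoff(0) * P(0).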

Warning

This class differs considerably from its 0.3.0 version. prob_list was renamed to prob_dicts; prob_list is deprecated. sos is no longer optional. pad_sos_to_n was removed as an argument (implicitly true now). eos and oov were also removed as part of updates to SequentialLanguageModel. Finally, the underlying buffers of this model have changed in structure and name, invalidating any old saved state dictionaries.

JIT scripting is possible with this module, but not tracing.

Notes

Initializing an instance from prob_dicts is expensive: prob_dicts is converted to a reverse trie (something like [heafield2011]) so that it takes up less space in memory, and the conversion can take some time.

Rather than re-initializing repeatedly, it is recommended you save and load this module’s state dict. load_state_dict() has been overridden to support loading different table sizes, avoiding the need for an accurate prob_dicts on initialization:

>>> # first time
>>> lm = LookupLanguageModel(vocab_size, sos, prob_dicts)  # slow
>>> state_dict = lm.state_dict()
>>> # save state dict, quit, startup, then reload state dict
>>> lm = LookupLanguageModel(vocab_size, sos)  # fast!
>>> lm.load_state_dict(state_dict)

See also

SequentialLanguageModel

A general description of language models, including call parameters

pydrobert.util.parse_arpa_lm

How to read a pretrained table of n-gram probabilities into prob_dicts. The parameter token2id should be specified to ensure id-based keys.

class pydrobert.torch.modules.MixableShallowFusionLanguageModel(first, second, beta=0, first_prefix='first.', second_prefix='second.')[source]

ShallowFusionLanguageModel which is also a MixableSequentialLanguageModel

Both first and second must be MixableSequentialLanguageModel instances.

See also

ShallowFusionLanguageModel

For a description of shallow fusion and its parameters. ShallowFusionLanguageModel does not require first and second to be mixable, but it is not mixable itself.

ExtractableShallowFusionLanguageModel

An extractable superclass of this class. Applicable if first and second are both ExtractableSequentialLanguageModel instances.

class pydrobert.torch.modules.ShallowFusionLanguageModel(first, second, beta=0.0, first_prefix='first.', second_prefix='second.')[source]

Language model combining two language models with shallow fusion

Shallow fusion [gulcehre2015] combines the predictions of two language models by taking the weighted sum of their log probabilities:

\[\log S(y_t=v|...) = \log P_{first}(y_t=v|...) + \beta \log P_{second}(y_t = v|...)\]

The resulting value \(\log S(y_t=v|\ldots)\) is not technically a log probability.
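The weighted sum is easy to state on log-probability tensors. The function below is a minimal sketch of the formula only (shallow_fusion is a hypothetical name; the module additionally threads prev/next_ state for both models), and it deliberately does not renormalize, which is why the result is not a distribution.

```python
import torch


def shallow_fusion(
    log_probs_first: torch.Tensor, log_probs_second: torch.Tensor, beta: float
) -> torch.Tensor:
    """log S = log P_first + beta * log P_second, elementwise over the vocabulary.

    Both inputs are log-probabilities of the same shape (..., V). The result is
    not renormalized, so exp(result) need not sum to 1 over V.
    """
    return log_probs_first + beta * log_probs_second
```

For two uniform distributions over four tokens and beta = 0.5, exp(log S) sums to 0.5 over the vocabulary.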

Parameters:
  • first (SequentialLanguageModel) – The first language model

  • second (SequentialLanguageModel) – The second language model, whose log probabilities are multiplied by beta

  • beta (float) – The value \(\beta\)

  • first_prefix (str) – Elements of the state dict for first will have first_prefix prepended to their keys

  • second_prefix (str) – Like first_prefix, but for second

Warning

This class does not (and cannot) support JIT.

Notes

If you intend to perform shallow fusion between CTC logits and an external language model, you will not be able to do so via this class. CTC operates on an extended vocabulary while an external language model does not. Fortunately, CTCPrefixSearch has built-in support for shallow fusion. Consult that class for more information.

See also

MixableShallowFusionLanguageModel

A mixable subclass of this class. Applicable only if first and second are both MixableSequentialLanguageModel instances.

ExtractableShallowFusionLanguageModel

An extractable subclass of this class. Applicable only if first and second are both ExtractableSequentialLanguageModel instances.

Reinforcement Learning

class pydrobert.torch.modules.GumbelOneHotCategoricalRebarControlVariate(func, start_temp=0.1, start_eta=1.0)[source]

REBAR control variate for GumbelOneHotCategorical relaxation

REBAR [tucker2017] is a special case of the RELAX estimator [grathwohl2017] with a control variate that passes a temperature-based transformation of the relaxed sample to the function \(f\) the expectation is being taken over. That is:

\[c_{\lambda,\eta}(z) = \eta f(\sigma(z / \lambda))\]

For the GumbelOneHotCategorical distribution, \(\sigma\) is the softmax function.
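The control variate for the categorical case can be sketched in a few lines, with the temperature kept in log space as in the Variables listed below. rebar_control_variate is a hypothetical helper illustrating the formula, not the module's code; func must accept and return tensors shaped like z.

```python
import math

import torch


def rebar_control_variate(func, z, log_temp, eta):
    """c(z) = eta * f(softmax(z / lambda)) with lambda = exp(log_temp).

    z has shape (*, V); the softmax is over the last (category) dimension.
    """
    temp = log_temp.exp()  # lambda stays positive because it is stored as a log
    return eta * func((z / temp).softmax(-1))
```

At a low temperature such as 0.1, softmax(z / lambda) is already close to one-hot, and gradients still reach both log_temp and eta.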

Parameters:
  • func – The function \(f\). Must be able to accept relaxed samples.

  • start_temp (float) – The temperature the \(\lambda\) parameter is initialized to.

  • start_eta (float) – The coefficient the \(\eta\) parameter is initialized to.

Variables:
  • log_temp – A scalar initialized to log(start_temp).

  • eta – A scalar initialized to start_eta.

Call Parameters:

z (torch.Tensor) – A tensor of shape (*, V) representing the relaxed sample.

Returns:

z_temp (torch.Tensor) – A tensor of the same shape as z storing the value \(c_{\lambda,\eta}(z)\).

Warning

This control variate can be traced but not scripted. Note that pydrobert.torch.estimators.RelaxEstimator can be neither traced nor scripted.

See also

pydrobert.torch.estimators.RelaxEstimator

For where to use this control variate.

class pydrobert.torch.modules.LogisticBernoulliRebarControlVariate(func, start_temp=0.1, start_eta=1.0)[source]

REBAR control variate for LogisticBernoulli relaxation

REBAR [tucker2017] is a special case of the RELAX estimator [grathwohl2017] with a control variate that passes a temperature-based transformation of the relaxed sample to the function \(f\) the expectation is being taken over. That is:

\[c_{\lambda,\eta}(z) = \eta f(\sigma(z / \lambda))\]

For the LogisticBernoulli distribution, \(\sigma\) is the sigmoid function.
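A module-shaped sketch of the Bernoulli case, mirroring the constructor arguments and the log_temp and eta variables listed below. SigmoidRebarSketch is a hypothetical class written for illustration; it is not LogisticBernoulliRebarControlVariate itself.

```python
import math

import torch


class SigmoidRebarSketch(torch.nn.Module):
    """Hypothetical module: c(z) = eta * f(sigmoid(z / exp(log_temp)))."""

    def __init__(self, func, start_temp=0.1, start_eta=1.0):
        super().__init__()
        self.func = func
        # learn the temperature in log space so it stays positive
        self.log_temp = torch.nn.Parameter(torch.tensor(math.log(start_temp)))
        self.eta = torch.nn.Parameter(torch.tensor(float(start_eta)))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z has arbitrary shape (*); sigmoid is applied elementwise
        return self.eta * self.func(torch.sigmoid(z / self.log_temp.exp()))
```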

Parameters:
  • func – The function \(f\). Must be able to accept relaxed samples.

  • start_temp (float) – The temperature the \(\lambda\) parameter is initialized to.

  • start_eta (float) – The coefficient the \(\eta\) parameter is initialized to.

Variables:
  • log_temp – A scalar initialized to log(start_temp).

  • eta – A scalar initialized to start_eta.

Call Parameters:

z (torch.Tensor) – A tensor of shape (*) representing the relaxed sample.

Returns:

z_temp (torch.Tensor) – A tensor of the same shape as z storing the value \(c_{\lambda,\eta}(z)\).

Warning

This control variate can be traced but not scripted. Note that pydrobert.torch.estimators.RelaxEstimator can be neither traced nor scripted.

See also

pydrobert.torch.estimators.RelaxEstimator

For where to use this control variate.

class pydrobert.torch.modules.TimeDistributedReturn(gamma, batch_first)[source]

Accumulate future local rewards at every time step

In reinforcement learning, the return is defined as the sum of discounted future rewards. This function calculates the return for a given time step \(t\) as

\[R_t = \sum_{t' \geq t} \gamma^{t' - t} r_{t'}\]

Where \(r_{t'}\) gives the (local) reward at time \(t'\) and \(\gamma\) is the discount factor. \(\gamma \in [0, 1)\) implies convergence, but this is not enforced here.

Parameters:
  • gamma (float) – The discount factor \(\gamma\).

  • batch_first (bool) – If True, r and R are of shape (N, T) instead of (T, N).

Call Parameters:

r (torch.Tensor) – A tensor of shape (T, N) of local rewards, where T is the sequence size and N is the batch size. The local rewards \(r\).

Returns:

R (torch.Tensor) – A tensor of shape (T, N) of the time-distributed rewards.
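The sum above satisfies the backward recursion R_t = r_t + gamma * R_{t+1}, which gives a single pass from the last time step to the first. time_distributed_return is a hypothetical function sketching that computation; it is not the module's implementation.

```python
import torch


def time_distributed_return(
    r: torch.Tensor, gamma: float, batch_first: bool = False
) -> torch.Tensor:
    """Discounted returns R_t = sum_{t' >= t} gamma^(t' - t) r_{t'}.

    r has shape (T, N), or (N, T) if batch_first; the output matches r.
    """
    if batch_first:
        r = r.t()
    T = r.shape[0]
    R = torch.empty_like(r)
    running = torch.zeros_like(r[0])
    for t in range(T - 1, -1, -1):
        # R_t = r_t + gamma * R_{t+1}
        running = r[t] + gamma * running
        R[t] = running
    return R.t() if batch_first else R
```

With three unit rewards and gamma = 0.5, the returns are 1.75, 1.5, and 1.0.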

String Matching

class pydrobert.torch.modules.EditDistance(eos=None, include_eos=False, norm=False, batch_first=False, ins_cost=1.0, del_cost=1.0, sub_cost=1.0, warn=True)[source]

Compute an edit distance over a batch of references and hypotheses

An Edit Distance quantifies how dissimilar two token sequences are as the total cost of transforming a reference sequence into a hypothesis sequence. There are three operations that can be performed, each with an associated cost: adding an extra token to the reference, removing a token from the reference, or swapping a token in the reference with a token in the hypothesis.
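For a single unbatched pair of sequences, the weighted distance can be computed with the standard Wagner-Fischer dynamic program. edit_distance below is a hypothetical, pure-Python sketch without eos handling or normalization; it is not the module's batched implementation.

```python
def edit_distance(ref, hyp, ins_cost=1.0, del_cost=1.0, sub_cost=1.0):
    """Minimum total cost of transforming ref into hyp."""
    R, H = len(ref), len(hyp)
    # d[i][j]: min cost to turn ref[:i] into hyp[:j]
    d = [[0.0] * (H + 1) for _ in range(R + 1)]
    for i in range(1, R + 1):
        d[i][0] = i * del_cost  # remove every token of ref[:i]
    for j in range(1, H + 1):
        d[0][j] = j * ins_cost  # add every token of hyp[:j]
    for i in range(1, R + 1):
        for j in range(1, H + 1):
            sub = 0.0 if ref[i - 1] == hyp[j - 1] else sub_cost
            d[i][j] = min(
                d[i - 1][j] + del_cost,  # remove ref[i - 1]
                d[i][j - 1] + ins_cost,  # add hyp[j - 1]
                d[i - 1][j - 1] + sub,  # keep or swap
            )
    return d[R][H]
```

With sub_cost = 3, swapping one token costs more than a deletion plus an insertion, so the distance between [1] and [2] is 2 rather than 3.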

Parameters:
  • eos (Optional[int]) – A special token in ref and hyp whose first occurrence in each sequence of the batch indicates the end of a transcript. This allows for variable-length transcripts in the batch.

  • include_eos (bool) – Whether to include the first instance of eos found in both ref and hyp as valid tokens to be computed as part of the rate. This is useful when gauging if a model is learning to emit the eos properly, but is not usually included in an evaluation. Only the first eos per transcript is included.

  • norm (bool) – If True, will normalize the distance by the number of tokens in the reference sequence (making the returned value a divergence).

  • batch_first (bool) – If True, the first two dimensions of ref and hyp are transposed from those given below.

  • ins_cost (float) – The cost of adding an extra token to a sequence in ref.

  • del_cost (float) – The cost of removing a token from a sequence in ref.

  • sub_cost (float) – The cost of swapping a token from ref with one from hyp.

  • warn (bool, optional) –

    Whether to display warnings on irregularities. Currently, this can happen in three ways:

    1. If True and ins_cost, del_cost, or sub_cost is not 1, a warning about a difference in computations will be raised. See the warning in ErrorRate for more info.

    2. If True and norm is True, will warn when a reference transcription has zero length

    3. If eos is set and include_eos is True, will warn when a transcript does not include an eos symbol.

Call Parameters:
  • ref (torch.Tensor) – A long tensor of shape (R, N) where R is the reference sequence dimension and N is the batch dimension. Stores the reference (gold-standard) sequences.

  • hyp (torch.Tensor) – A long tensor of shape (H, N) where H is the hypothesis sequence dimension. Stores the hypothesis (machine-generated) sequences.

Returns:

ed (torch.Tensor) – A tensor of shape (N,) of the edit distances.

Notes

This module returns identical values (modulo a bug fix) to error_rate() up to v0.3.0 (though the default of norm has changed to False). For more details on the distinction between this module and the new ErrorRate(), please see that module’s documentation.

class pydrobert.torch.modules.ErrorRate(eos=None, include_eos=False, norm=True, batch_first=False, ins_cost=1.0, del_cost=1.0, sub_cost=1.0, warn=True)[source]

Calculate error rates over a batch of references and hypotheses

An error rate is the total number of insertions, deletions, and substitutions between a reference (gold-standard) and hypothesis (generated) transcription, normalized by the number of elements in a reference. Consult the Wikipedia article on the Levenshtein distance for more information.
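With unit costs, the error rate is the Levenshtein distance divided by the reference length. The sketch below also shows how the eos and include_eos parameters could truncate each transcript before counting. error_rate and truncate_at_eos are hypothetical helpers for a single pair; they do not reproduce the module's warnings or its handling of zero-length references (which would divide by zero here).

```python
def truncate_at_eos(seq, eos, include_eos=False):
    """Cut seq at its first eos, optionally keeping that eos."""
    seq = list(seq)
    if eos is not None and eos in seq:
        end = seq.index(eos) + (1 if include_eos else 0)
        seq = seq[:end]
    return seq


def error_rate(ref, hyp, eos=None, include_eos=False):
    """(insertions + deletions + substitutions) / len(ref), unit costs."""
    ref = truncate_at_eos(ref, eos, include_eos)
    hyp = truncate_at_eos(hyp, eos, include_eos)
    # single-row Levenshtein distance: prev[j] = distance(ref[:i], hyp[:j])
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (r != h)))
        prev = cur
    return prev[-1] / len(ref)
```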

Parameters:
  • eos (Optional[int]) – A special token in ref and hyp whose first occurrence in each sequence of the batch indicates the end of a transcript. This allows for variable-length transcripts in the batch.

  • include_eos (bool) – Whether to include the first instance of eos found in both ref and hyp as valid tokens to be computed as part of the rate. This is useful when gauging if a model is learning to emit the eos properly, but is not usually included in an evaluation. Only the first eos per transcript is included.

  • norm (bool) – If True, will normalize the distance by the number of tokens in the reference sequence (making the returned value a divergence).

  • batch_first (bool) – If True, the first two dimensions of ref and hyp are transposed from those given below.

  • ins_cost (float) – The cost of adding an extra token to a sequence in ref.

  • del_cost (float) – The cost of removing a token from a sequence in ref.

  • sub_cost (float) – The cost of swapping a token from ref with one from hyp.

  • warn (bool, optional) –

    Whether to display warnings on irregularities. Currently, this can happen in three ways:

    1. If True and ins_cost, del_cost, or sub_cost is not 1, a warning about a difference in computations will be raised. See the below warning for more info.

    2. If True and norm is True, will warn when a reference transcription has zero length

    3. If eos is set and include_eos is True, will warn when a transcript does not include an eos symbol.

Call Parameters:
  • ref (torch.Tensor) – A long tensor of shape (R, N) where R is the reference sequence dimension and N is the batch dimension. Stores the reference (gold-standard) sequences.

  • hyp (torch.Tensor) – A long tensor of shape (H, N) where H is the hypothesis sequence dimension. Stores the hypothesis (machine-generated) sequences.

Returns:

ed (torch.Tensor) – A tensor of shape (N,) of the error rates.

Warning

Up to and including v0.3.0, error_rate() computed a normalized edit distance instead of an error rate. The former can be considered the total weighted cost of insertions, deletions, and substitutions (as per ins_cost, del_cost, and sub_cost), whereas the latter is the total number of mistakes. The old behaviour of returning the cost is now in edit_distance() and EditDistance (though norm is False by default). For speech recognition evaluation, this module or error_rate() is the one to use. However, if you are using the default costs, ins_cost == del_cost == sub_cost == 1, there should be no numerical difference between the two.

class pydrobert.torch.modules.FillAfterEndOfSequence(eos, dim=0, fill=None)[source]

Fill after the first end-of-sequence token with a value

Many Natural Language Processing tasks involve variable-length sequences ending with special “end-of-sequence” (eos) tokens. This module finds the first instance of eos and pads everything after that along the dim dimension with the value of fill.
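One way to locate "everything after the first eos" without a loop is a cumulative sum: a position lies after the first eos exactly when at least one eos occurs strictly before it. fill_after_eos_sketch is a hypothetical function illustrating that idea; it assumes value (if given) has at least the shape of tokens so the mask broadcasts onto it, and it is not the module's implementation.

```python
import torch


def fill_after_eos_sketch(tokens, eos, dim=0, fill=None, value=None):
    """Replace every element after the first eos along dim with fill."""
    if fill is None:
        fill = eos
    if value is None:
        value = tokens
    is_eos = (tokens == eos).long()
    # number of eos tokens seen strictly before each position
    seen_before = is_eos.cumsum(dim) - is_eos
    after_first = seen_before > 0
    # masked_fill broadcasts the mask against value's shape
    return value.masked_fill(after_first, fill)
```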

Parameters:
  • eos (int) – The id of the end-of-sequence token.

  • dim (int) – The sequence dimension of tokens.

  • fill (Optional[float]) – The value to fill with. If unset, set to eos.

Call Parameters:
  • tokens (torch.Tensor) – The token sequences. Of arbitrary shape, but must have dimension dim.

  • value (Optional[torch.Tensor], optional) – value may be optionally specified as a tensor other than tokens to fill. It must broadcast with tokens if specified. Otherwise value will be assumed to be tokens.

Returns:

out (torch.Tensor) – A tensor matching tokens (or value broadcast with tokens, if value was specified), except that every element after the first instance of eos in tokens is set to fill.

Examples

>>> T = 10
>>> tokens = torch.arange(T)
>>> tokens
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> fill_after_eos = FillAfterEndOfSequence(eos=T // 2, fill=-1)
>>> out = fill_after_eos(tokens)
>>> out
tensor([ 0,  1,  2,  3,  4,  5, -1, -1, -1, -1])
>>> logits = torch.eye(T)
>>> logits
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
>>> out = fill_after_eos(tokens.unsqueeze(1), logits)
>>> out
tensor([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.],
        [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
        [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
        [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
        [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]])

class pydrobert.torch.modules.HardOptimalCompletionDistillationLoss(eos=None, include_eos=True, batch_first=False, ins_cost=1.0, del_cost=1.0, sub_cost=1.0, weight=None, reduction='mean', ignore_index=-100)[source]

A categorical loss function over optimal next tokens

Optimal Completion Distillation (OCD) [sabour2018] tries to minimize the train/test discrepancy in transcriptions by allowing seq2seq models to generate whatever sequences they want, then assigning a per-step loss according to whichever next tokens would set the model on a path that minimizes the edit distance in the future.

In its “hard” version, the version used in the paper, the OCD loss function is simply a categorical cross-entropy loss of each hypothesis token’s distribution versus those optimal next tokens, averaged over the number of optimal next tokens:

\[loss(logits_h) = -\frac{1}{|S_h|} \sum_{s_h \in S_h} \log Pr(s_h|logits_h)\]

Where \(s_h \in S_h\) are tokens from the set of optimal next tokens given \(hyp_{\leq h}\) and ref. The loss is decoupled from an exact prefix of ref, meaning that hyp can be longer or shorter than ref.
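Given a boolean mask marking the optimal next tokens S_h at each step, the per-step loss in the formula above is an average of negative log-probabilities over the masked entries. In this sketch the mask is simply an input (the module derives it internally from ref and hyp), and hard_ocd_loss is a hypothetical name; eos handling, ignore_index, weight, and reduction are all omitted, and every step is assumed to have at least one optimal token.

```python
import torch


def hard_ocd_loss(logits: torch.Tensor, optimal_mask: torch.Tensor) -> torch.Tensor:
    """Per-step hard OCD loss.

    logits: (H, N, V) unnormalized log-probabilities.
    optimal_mask: bool (H, N, V), True for tokens in S_h.
    Returns a tensor of shape (H, N).
    """
    log_probs = logits.log_softmax(-1)
    mask = optimal_mask.to(log_probs.dtype)
    # -(1 / |S_h|) * sum over s_h in S_h of log Pr(s_h | logits_h)
    return -(log_probs * mask).sum(-1) / mask.sum(-1)
```

With uniform logits over four tokens and two optimal tokens, the loss is log 4, the same as the cross-entropy of either token alone.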

Parameters:
  • eos (Optional[int]) – A special token in ref and hyp whose first occurrence in each sequence of the batch indicates the end of a transcript. This allows for variable-length transcripts in the batch.

  • include_eos (bool) – Whether to include the first instance of eos found in both ref and hyp as valid tokens to be computed as part of the rate. This is useful when gauging if a model is learning to emit the eos properly, but is not usually included in an evaluation. Only the first eos per transcript is included.

  • batch_first (bool) – If True, the first two dimensions of ref, hyp, and the return value are transposed from those given below.

  • ins_cost (float) – The cost of adding an extra token to a sequence in ref.

  • del_cost (float) – The cost of removing a token from a sequence in ref.

  • sub_cost (float) – The cost of swapping a token from ref with one from hyp.

  • reduction (Literal['mean', 'sum', 'none']) – Specifies the reduction to be applied to the output. 'none': no reduction will be applied. 'sum': the output will be summed. 'mean': the output will be averaged.

  • ignore_index (int) – Specify a target value that is ignored and does not contribute to the input gradient. Should not be set to eos when include_eos is True.

Call Parameters:
  • logits (torch.Tensor) – A tensor of shape (H, N, V) where H is the hypothesis sequence dimension, N is the batch dimension, and V is the vocabulary size. Stores the unnormalized log-probabilities over the next token of each prefix (except the last) within hyp.

  • ref (torch.Tensor) – A long tensor of shape (R, N) where R is the reference sequence dimension. Stores the reference (gold-standard) sequences.

  • hyp (torch.Tensor) – A long tensor of shape (H, N). Stores the hypothesis (machine-generated) sequences.

Returns:

loss (torch.Tensor) – The loss. If reduction is 'sum' or 'mean', it is a scalar value. Otherwise of shape (H, N).

See also

pydrobert.torch.util.optimal_completion

Used to determine the optimal next token set \(S\)

pydrobert.torch.util.random_walk_advance

For producing a random hyp based on logits if the underlying model producing logits is auto-regressive. Also provides an example of sampling non-auto-regressive models.

class pydrobert.torch.modules.MinimumErrorRateLoss(eos=None, include_eos=True, sub_avg=True, batch_first=False, norm=True, ins_cost=1.0, del_cost=1.0, sub_cost=1.0, reduction='mean')[source]

Error rate expectation normalized over some number of transcripts

Proposed in [prabhavalkar2018] though similar ideas had been explored previously. Given a subset of all possible token sequences and their associated probability mass over that population, this loss calculates the probability mass normalized over the subset, then calculates the expected error rate over that normalized distribution. That is, given some sequences \(s \in S \subseteq P\), the loss for a given reference transcription \(s^*\) is

\[\mathcal{L}(S, s^*) = \sum_{s \in S} \frac{Pr(s) ER(s, s^*)}{\sum_{s' \in S} Pr(s')}\]

This is an exact expectation over \(S\) but not over \(P\). The larger the mass covered by \(S\), the closer the expectation is to the population expectation, especially for an n-best list (though the estimate would be biased).
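Renormalizing Pr(s) over the subset is a softmax over the sample dimension of the sequence log-probabilities, so the expectation above reduces to a weighted sum. mer_sketch is a hypothetical function that takes precomputed error rates and returns one value per batch element; the module itself computes the error rates from ref and hyp and applies its own reduction.

```python
import torch


def mer_sketch(log_probs: torch.Tensor, er: torch.Tensor, sub_avg: bool = True) -> torch.Tensor:
    """Expected error rate under the distribution renormalized over the M samples.

    log_probs: (N, M) sequence log-probabilities; er: (N, M) error rates.
    Returns a tensor of shape (N,).
    """
    weights = log_probs.softmax(-1)  # Pr(s) / sum_{s'} Pr(s')
    if sub_avg:
        # subtract the per-batch-element average error rate
        er = er - er.mean(-1, keepdim=True)
    return (weights * er).sum(-1)
```

Because of the softmax, adding a constant to every log_probs[n] leaves the result unchanged, which matches the normalization over \(S\).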

Parameters:
  • eos (Optional[int]) – A special token in ref and hyp whose first occurrence in each sequence of the batch indicates the end of a transcript. This allows for variable-length transcripts in the batch.

  • include_eos (bool) – Whether to include the first instance of eos found in both ref and hyp as valid tokens to be computed as part of the rate. This is useful when gauging if a model is learning to emit the eos properly, but is not usually included in an evaluation. Only the first eos per transcript is included.

  • sub_avg (bool) – Whether to subtract the average error rate from each pathwise error rate.

  • batch_first (bool) – If True, the first two dimensions of ref, hyp, and the return value are transposed from those given below.

  • norm (bool) – If True, will normalize the distance by the number of tokens in the reference sequence (making the returned value a divergence).

  • ins_cost (float) – The cost of adding an extra token to a sequence in ref.

  • del_cost (float) – The cost of removing a token from a sequence in ref.

  • sub_cost (float) – The cost of swapping a token from ref with one from hyp.

  • reduction (Literal['mean', 'none', 'sum']) – Specifies the reduction to be applied to the output. 'none': no reduction will be applied. 'sum': the output will be summed. 'mean': the output will be averaged.

Call Parameters:
  • log_probs (torch.Tensor) – A tensor of shape (N, M) where N is the batch size and M is the number of samples providing the log joint probabilities of every sample path.

  • ref (torch.Tensor) – A tensor of either of shape (R, N) or (R, N, M) where R is the maximum reference length containing the reference (gold-standard) transcriptions. Whether ref is 2D or 3D changes how the loss is calculated.

  • hyp (torch.Tensor) – A long tensor of shape (H, N, M) where H is the maximum hypothesis size containing the hypothesis (machine-generated) transcriptions.

  • warn (bool, optional) – Whether to display warnings on irregularities. Currently, this can happen in three ways:

    1. If True and ins_cost, del_cost, or sub_cost is not 1, a warning about a difference in computations will be raised. See the below warning for more info.

    2. If True and norm is True, will warn when a reference transcription has zero length

    3. If eos is set and include_eos is True, will warn when a transcript does not include an eos symbol.

Returns:

loss (torch.Tensor) – The loss. If reduction is 'sum' or 'mean', it is a scalar value. Otherwise of shape (N,M). If ref is 2D, the loss for sample m of batch element n is

\[loss_{n, m} = SoftMax(log\_probs_n)_m \left[ER(hyp_{n, m}, ref_n) - \mu_n\right]\]

where \(\mu_n\) is the average error rate of the M hypotheses in batch element n, and the softmax is taken over those M hypotheses. \(\mu_n\) is dropped if sub_avg is False. If ref is 3D, each hypothesis is compared against a unique reference:

\[loss_{n, m} = SoftMax(log\_probs_n)_m \left[ER(hyp_{n, m}, ref_{n, m}) - \mu_n\right]\]

Notes

A previous version of this module incorporated a Maximum Likelihood Estimate (MLE) into the loss as in [prabhavalkar2018], which required logits instead of log_probs. This was overly complicated, given the user can easily incorporate the additional loss term herself by using torch.nn.CrossEntropyLoss. The example below shows how to recreate this.

Examples

Assume here that logits is the output of some neural network, and that hyp has somehow been produced from that (e.g. a beam search or random walk). We combine this loss function with a cross-entropy/MLE term to approximately recreate [prabhavalkar2018].

>>> from pydrobert.torch.util import sequence_log_probs
>>> steps, batch_size, num_classes, eos, padding = 30, 20, 10, 0, -1
>>> samples, lmb = 10, .01
>>> logits = torch.randn(
...     steps, batch_size, samples, num_classes, requires_grad=True)
>>> hyp = torch.randint(num_classes, (steps, batch_size, samples))  # (H, N, M)
>>> ref_lens = torch.randint(1, steps + 1, (batch_size,))
>>> ref_lens[0] = steps
>>> ref = torch.nn.utils.rnn.pad_sequence(
...     [torch.randint(1, num_classes, (int(x),)) for x in ref_lens],
...     padding_value=padding,
... )
>>> ref[ref_lens - 1, torch.arange(batch_size)] = eos
>>> ref = ref.unsqueeze(2).repeat(1, 1, samples)  # (R, N, M)
>>> mer = MinimumErrorRateLoss(eos=eos)
>>> mle = torch.nn.CrossEntropyLoss(ignore_index=padding)
>>> log_probs = sequence_log_probs(logits, hyp, eos=eos)  # (N, M)
>>> l = mer(log_probs, ref, hyp)
>>> l = l + lmb * mle(logits.view(-1, num_classes), ref.flatten())
>>> l.backward()

See also

pydrobert.torch.util.beam_search_advance

For getting an n-best list into hyp and some log_probs.

pydrobert.torch.util.random_walk_advance

For getting a random sample into hyp

pydrobert.torch.util.sequence_log_probs

For converting token log probs (or logits) to sequence log probs

class pydrobert.torch.modules.OptimalCompletion(eos=None, include_eos=True, batch_first=False, ins_cost=1.0, del_cost=1.0, sub_cost=1.0, padding=-100, exclude_last=False, warn=True)

Return the optimal next tokens for each prefix of a hypothesis under minimum edit distance

Parameters:
  • eos (Optional[int]) – A special token in ref and hyp whose first occurrence in each sequence of the batch marks the end of a transcript. This allows for variable-length transcripts in the batch.

  • include_eos (bool) – Whether to include the first instance of eos found in both ref and hyp as valid tokens to be computed as part of the rate. This is useful when gauging if a model is learning to emit the eos properly, but is not usually included in an evaluation. Only the first eos per transcript is included.

  • batch_first (bool) – If True, the first two dimensions of ref, hyp, and the return value are transposed relative to the shapes given under Call Parameters and Returns.

  • ins_cost (float) – The cost of adding an extra token to a sequence in ref.

  • del_cost (float) – The cost of removing a token from a sequence in ref.

  • sub_cost (float) – The cost of swapping a token from ref with one from hyp.

  • padding (int) – The value to right-pad unequal-length sequences with. Defaults to pydrobert.torch.config.INDEX_PAD_VALUE.

  • exclude_last (bool) – If True, the final prefix, consisting of the entire transcript, is excluded from the return value, which will then be of shape (H, N, U).

  • warn (bool, optional) –

    Whether to display warnings on irregularities. Currently, this can happen in three ways:

    1. If True and ins_cost, del_cost, or sub_cost is not 1, a warning about a difference in computations will be raised. See the below warning for more info.

    2. If True and norm is True, will warn when a reference transcription has zero length

    3. If eos is set and include_eos is True, will warn when a transcript does not include an eos symbol.

Call Parameters:
  • ref (torch.Tensor) – A long tensor of shape (R, N) where R is the reference sequence dimension and N is the batch dimension. Stores the reference (gold-standard) sequences.

  • hyp (torch.Tensor) – A long tensor of shape (H, N) where H is the hypothesis sequence dimension. Stores the hypothesis (machine-generated) sequences.

Returns:

optimals (torch.Tensor) – A long tensor of shape (H + 1, N, U), where U <= R, holding the unique tokens that could be appended to each prefix of the hypothesis such that some remaining suffix concatenated to the prefix would yield a minimal edit distance. See below for an example.
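The eos and include_eos parameters determine which tokens of each sequence take part in the computation. For one unbatched sequence, that rule might look like the minimal sketch below, assuming tokens after the first eos are simply ignored. truncate_at_eos is a hypothetical helper, not part of pydrobert.torch.

```python
from typing import List, Optional, Sequence


def truncate_at_eos(seq: Sequence[int], eos: Optional[int] = None,
                    include_eos: bool = True) -> List[int]:
    """Hypothetical helper: keep one transcript's tokens up to its first eos."""
    seq = list(seq)
    if eos is None or eos not in seq:
        # no end marker: the whole sequence counts (with eos set and
        # include_eos True, the module's warn option reports this case)
        return seq
    end = seq.index(eos)  # only the first occurrence marks the end
    return seq[: end + 1] if include_eos else seq[:end]


print(truncate_at_eos([5, 3, 0, 7, 0], eos=0))  # [5, 3, 0]
```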

Examples

Consider the reference text “foot” and the hypothesis text “bot”. The matrix below holds the edit distances between their prefixes, with rows indexing prefixes of “bot” and columns indexing prefixes of “foot”:

\ _ f o o t
_ 0 1 2 3 4
b 1 1 2 3 4
o 2 2 1 2 3
t 3 3 2 2 2

If prefix_len == 0, the prefix is “” and “f” (from the suffix “foot”) is the only next token that would not increase the edit distance beyond that of the prefix (0). If prefix_len == 1, the prefix is “b”. To reach the minimum edit distance for “b”, one either treats “b” as an insertion or as a substitution for “f”, yielding the suffixes “foot” and “oot”, respectively. Thus the next token could be “f” or “o”. For the prefix “bo”, the minimum edit distance is achieved by first treating “b” as a substitution for “f”, then matching “o” with “o”, leaving the suffix “ot” and the single optimal next token “o”. Finally, for prefix_len == 3 and the prefix “bot”, several alignments achieve the minimum edit distance of 2, leaving one of the suffixes “ot”, “t”, or “”. The empty suffix requires no more tokens, so any further token would increase the edit distance. Thus the optimal next tokens are “o” or “t”.
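The procedure just described can be sketched in plain Python for a single unbatched pair: build the table above, then for each row (prefix of the hypothesis) collect ref[j] for every column j that attains the row minimum. This is a minimal sketch, not the library's vectorized implementation. The helper names edit_matrix and optimal_completions are hypothetical, and the cost conventions are assumed from the parameter descriptions: ins_cost charges a hyp token with no ref counterpart, and del_cost charges a ref token with no hyp counterpart. With non-unit costs, the module's own results may differ (see the warn parameter).

```python
import math
from typing import List, Sequence


def edit_matrix(ref: Sequence, hyp: Sequence, ins_cost: float = 1.0,
                del_cost: float = 1.0, sub_cost: float = 1.0) -> List[List[float]]:
    """d[i][j] is the edit distance between hyp[:i] and ref[:j]."""
    R, H = len(ref), len(hyp)
    d = [[0.0] * (R + 1) for _ in range(H + 1)]
    for j in range(1, R + 1):
        d[0][j] = d[0][j - 1] + del_cost  # drop ref tokens to match an empty prefix
    for i in range(1, H + 1):
        d[i][0] = d[i - 1][0] + ins_cost  # hyp tokens with no ref counterpart
        for j in range(1, R + 1):
            match = 0.0 if hyp[i - 1] == ref[j - 1] else sub_cost
            d[i][j] = min(
                d[i - 1][j] + ins_cost,   # hyp[i - 1] is an insertion
                d[i][j - 1] + del_cost,   # ref[j - 1] is deleted
                d[i - 1][j - 1] + match,  # match or substitution
            )
    return d


def optimal_completions(ref: Sequence, hyp: Sequence, **costs: float) -> List[list]:
    """Optimal next tokens for each prefix hyp[:i], i = 0, ..., len(hyp)."""
    result = []
    for row in edit_matrix(ref, hyp, **costs):
        best = min(row)
        nexts = []
        for j, cost in enumerate(row):
            # a column j attaining the row minimum leaves the suffix ref[j:];
            # emitting its first token next matches it at no extra cost
            if j < len(ref) and math.isclose(cost, best) and ref[j] not in nexts:
                nexts.append(ref[j])
        result.append(nexts)
    return result


for prefix_len, tokens in enumerate(optimal_completions("foot", "bot")):
    print(f"prefix={'bot'[:prefix_len]}: {','.join(tokens)}")
```

Run as is, the loop prints the same four prefix/token lines as the doctest below. The module additionally handles batching, eos, and padding.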

Passing “foot” and “bot” to this module yields the following optimal next tokens for each prefix:

>>> import torch
>>> from pydrobert.torch.modules import OptimalCompletion
>>> ref_text, hyp_text = "foot", "bot"
>>> ref = torch.tensor([ord(c) for c in ref_text]).unsqueeze(1)
>>> hyp = torch.tensor([ord(c) for c in hyp_text]).unsqueeze(1)
>>> optimal = OptimalCompletion()(ref, hyp).squeeze(1)
>>> for prefix_len, o_for_pr in enumerate(optimal):
...     # drop padding (negative) entries, keeping only real tokens
...     o_for_pr = o_for_pr.masked_select(o_for_pr.ge(0)).tolist()
...     print('prefix={}: {}'.format(
...         hyp_text[:prefix_len], ','.join([chr(i) for i in o_for_pr])))
prefix=: f
prefix=b: f,o
prefix=bo: o
prefix=bot: o,t

See also

pydrobert.torch.modules.HardOptimalCompletionDistillationLoss

A loss function that uses these optimal completions to train a model.

class pydrobert.torch.modules.PrefixEditDistances(eos=None, include_eos=True, norm=False, batch_first=False, ins_cost=1.0, del_cost=1.0, sub_cost=1.0, padding=-100, exclude_last=False, warn=True)

Compute the edit distance between ref and each prefix of hyp

Parameters:
  • eos (Optional[int]) – A special token in ref and hyp whose first occurrence in each sequence of the batch marks the end of a transcript. This allows for variable-length transcripts in the batch.

  • include_eos (bool) – Whether to include the first instance of eos found in both ref and hyp as valid tokens to be computed as part of the rate. This is useful when gauging if a model is learning to emit the eos properly, but is not usually included in an evaluation. Only the first eos per transcript is included.

  • norm (bool) – If True, will normalize the distance by the number of tokens in the reference sequence (making the returned value a divergence).

  • batch_first (bool) – If True, the first two dimensions of ref, hyp, and the return value are transposed relative to the shapes given under Call Parameters and Returns.

  • ins_cost (float) – The cost of adding an extra token to a sequence in ref.

  • del_cost (float) – The cost of removing a token from a sequence in ref.

  • sub_cost (float) – The cost of swapping a token from ref with one from hyp.

  • padding (int) – The value to right-pad unequal-length sequences with. Defaults to pydrobert.torch.config.INDEX_PAD_VALUE.

  • exclude_last (bool) – If True, the final prefix, consisting of the entire transcript, is excluded from the return value, which will then be of shape (H, N).

  • warn (bool, optional) –

    Whether to display warnings on irregularities. Currently, this can happen in three ways:

    1. If True and ins_cost, del_cost, or sub_cost is not 1, a warning about a difference in computations will be raised. See the below warning for more info.

    2. If True and norm is True, will warn when a reference transcription has zero length

    3. If eos is set and include_eos is True, will warn when a transcript does not include an eos symbol.


Call Parameters:
  • ref (torch.Tensor) – A long tensor of shape (R, N) where R is the reference sequence dimension and N is the batch dimension. Stores the reference (gold-standard) sequences.

  • hyp (torch.Tensor) – A long tensor of shape (H, N) where H is the hypothesis sequence dimension. Stores the hypothesis (machine-generated) sequences.

Returns:

prefix_eds (torch.Tensor) – A tensor of shape (H + 1, N) of the edit distances for each prefix of each hypothesis, starting from the empty prefix.
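For a single unbatched pair, entry i of the return value is the distance between the full ref and the first i tokens of hyp, i.e. the last column of the edit-distance table in the OptimalCompletion example. The minimal pure-Python sketch below ignores eos, padding, and batching. The name prefix_edit_distances is hypothetical, and raising ValueError for an empty reference under norm is a choice of this sketch (the module instead warns, per the warn parameter). Costs follow the parameter descriptions: insertions add hyp tokens, deletions remove ref tokens.

```python
from typing import List, Sequence


def prefix_edit_distances(ref: Sequence, hyp: Sequence, norm: bool = False,
                          ins_cost: float = 1.0, del_cost: float = 1.0,
                          sub_cost: float = 1.0) -> List[float]:
    """Hypothetical helper: edit distance between ref and each prefix of hyp."""
    R = len(ref)
    if norm and R == 0:
        raise ValueError("norm=True needs a non-empty reference in this sketch")
    # row[j]: distance between ref[:j] and the current hyp prefix (empty at first)
    row = [j * del_cost for j in range(R + 1)]
    out = [row[R]]
    for i, h in enumerate(hyp, start=1):
        new = [i * ins_cost] + [0.0] * R
        for j in range(1, R + 1):
            new[j] = min(
                row[j] + ins_cost,      # h is an insertion
                new[j - 1] + del_cost,  # ref[j - 1] is deleted
                row[j - 1] + (0.0 if h == ref[j - 1] else sub_cost),
            )
        row = new
        out.append(row[R])  # distance between the full ref and hyp[:i]
    return [x / R for x in out] if norm else out


print(prefix_edit_distances("foot", "bot"))  # [4.0, 4.0, 3.0, 2.0]
```

Only one row of the table is kept at a time, since each prefix's distance depends only on the previous row.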

Notes

Up to v0.3.0, prefix_error_rates() (and PrefixErrorRates) returned the same values as this module (modulo a bug fix), although this module's norm defaults to False. For more details on how this module differs from the current prefix_error_rates(), please consult the documentation of ErrorRate.

class pydrobert.torch.modules.PrefixErrorRates(eos=None, include_eos=True, norm=True, batch_first=False, ins_cost=1.0, del_cost=1.0, sub_cost=1.0, padding=-100, exclude_last=False, warn=True)

Compute the error rate between ref and each prefix of hyp

Parameters:
  • eos (Optional[int]) – A special token in ref and hyp whose first occurrence in each sequence of the batch marks the end of a transcript. This allows for variable-length transcripts in the batch.

  • include_eos (bool) – Whether to include the first instance of eos found in both ref and hyp as valid tokens to be computed as part of the rate. This is useful when gauging if a model is learning to emit the eos properly, but is not usually included in an evaluation. Only the first eos per transcript is included.

  • norm (bool) – If True, will normalize the distance by the number of tokens in the reference sequence (making the returned value a divergence).

  • batch_first (bool) – If True, the first two dimensions of ref, hyp, and the return value are transposed relative to the shapes given under Call Parameters and Returns.

  • ins_cost (float) – The cost of adding an extra token to a sequence in ref.

  • del_cost (float) – The cost of removing a token from a sequence in ref.

  • sub_cost (float) – The cost of swapping a token from ref with one from hyp.

  • padding (int) – The value to right-pad unequal-length sequences with. Defaults to pydrobert.torch.config.INDEX_PAD_VALUE.

  • exclude_last (bool) – If True, the final prefix, consisting of the entire transcript, is excluded from the return value, which will then be of shape (H, N).

  • warn (bool, optional) –

    Whether to display warnings on irregularities. Currently, this can happen in three ways:

    1. If True and ins_cost, del_cost, or sub_cost is not 1, a warning about a difference in computations will be raised. See the below warning for more info.

    2. If True and norm is True, will warn when a reference transcription has zero length

    3. If eos is set and include_eos is True, will warn when a transcript does not include an eos symbol.

Call Parameters:
  • ref (torch.Tensor) – A long tensor of shape (R, N) where R is the reference sequence dimension and N is the batch dimension. Stores the reference (gold-standard) sequences.

  • hyp (torch.Tensor) – A long tensor of shape (H, N) where H is the hypothesis sequence dimension. Stores the hypothesis (machine-generated) sequences.

Returns:

prefix_ers (torch.Tensor) – A tensor of shape (H + 1, N) containing the error rates for each prefix of each hypothesis, starting from the empty prefix.

Warning

The values returned by prefix_error_rates() (and thus this module) changed after v0.3.0. The old behaviour can be found in PrefixEditDistances (though with norm defaulting to False). Consult the warning in ErrorRate for more info.