# Copyright 2023 Sean Robertson
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing_extensions import Literal
import warnings
from typing import Any, Dict, Optional, Tuple, Union, overload, List
from logging import Logger
import torch
import numpy as np
from . import argcheck
from ._compat import script
from ._wrappers import proxy
# Prefer the third-party ``sortedcontainers.SortedList`` when it is installed; otherwise
# fall back to a plain ``list`` kept in order with ``bisect.insort_left``. Either way,
# callers construct ``SortedList()`` and insert with ``insort_left(container, item)``.
try:
    from sortedcontainers import SortedList  # type: ignore

    def insort_left(sl: SortedList, x):
        # SortedList keeps itself ordered, so insertion is just ``add``
        sl.add(x)

except ImportError:
    from bisect import insort_left

    SortedList = list
# Type of the n-gram tables handed to LookupLanguageModel: one dictionary per order,
# keyed by a single id or by a tuple of ids.
# NOTE(review): the value annotation only mentions a floating scalar, but lower-order
# tables are read as (log-prob, log-backoff) pairs in ``_build_trie`` -- confirm whether
# the alias should be widened.
ProbDicts = List[
    Dict[Union[np.signedinteger, Tuple[np.signedinteger, ...]], np.floating]
]
[docs]
class SequentialLanguageModel(torch.nn.Module, metaclass=abc.ABCMeta):
r"""A language model whose sequence probability is built sequentially
A language model provides the (log-)probability of a sequence of tokens. A
sequential language model assumes that the probability distribution can be factored
into a product of probabilities of the current token given the prior sequence, i.e.
for token sequence :math:`\{w_s\}`
.. math::
P(w) = \prod_{s=1}^S P(w_s | w_{s - 1}, w_{s - 2}, \ldots w_1)
This definition includes statistical language models, such as n-grams, where the
probability of the current token is based only on a fixed-length history, as well as
recurrent neural language models [mikolov2010]_.
Parameters
----------
vocab_size
The vocabulary size. Controls the size of the final output dimension,
as well as what values of `hist` are considered in-vocabulary
Call Parameters
---------------
hist : torch.Tensor
A long tensor of shape ``(S, N)`` where ``S`` is the sequence dimension and
``N`` is the batch dimension. ``hist[:, n]`` is the n-th token prefix
:math:`(w^{(n)}_0, w^{(n)}_1, \ldots, w^{(n)}_{S-1})`.
prev : Dict[str, torch.Tensor], optional
A dictionary of tensors which represents some additional state information which
can be used in the computation. It may contain static input (e.g. a tensor of
encoder output in neural machine translation) and/or dynamic input from prior
calls to the LM (e.g. the previous hidden state in an RNN-based language model).
idx : Optional[Union[int, torch.Tensor]], optional
If specified, it is either a single integer or a long tensor of shape ``(N,)``
specifying the indices of the tokens with which to return a distribution over.
See the return value below.
Returns
-------
log_probs : torch.Tensor or tuple of torch.Tensor
The return value changes depending on whether `idx` was specified.
If `idx` was not specified, the distributions over the next token over all
prefixes in `hist` are returned. `log_probs` is a tensor of shape ``(S + 1, N,
vocab_size)`` where each ``log_probs[s, n, v]`` equals :math:`\log P(w^{(n)}_{s}
= v | w^{(n)}_{s - 1}, \ldots)`. That is, each distribution over types
conditioned on each prefix of tokens (``:0``, ``:1``, ``:2``, etc.) is returned.
If `idx` was specified, the distributions over only the token at those indices
are returned. `log_probs` is a pair of tensors ``log_probs_idx, next_``.
`log_probs_idx` is of shape ``(N, vocab_size)`` and ``log_probs[n, v]`` equals
:math:`\log P(w^{(n)}_{idx[n]} = v | w^{(n)}_{idx[n]-1}, \ldots)`. That is, the
distributions over the next type conditioned on token prefixes up to and
excluding ``s = idx``. `next_` is a dictionary of tensors representing the
updated state of the language model after computing these log probabilities,
assuming `prev` represented the state at ``idx - 1``.
Notes
-----
When this module is scripted, its return type will be :class:`typing.Any`. This
reflects the fact that either `log_probs` is returned on its own (a tensor) or both
`log_probs` and `prev` (a tuple). Use :func:`torch.jit.isinstance` for type
refinement in subsequent scripting. Tracing will infer the correct type.
Alternatively, one can use the methods :func:`update_input`,
:func:`calc_idx_log_probs`, and :func:`calc_full_log_probs` to avoid ambiguity in
the return type altogether.
This module has changed considerably since version 0.3.0. The primary changes are a)
to replace the boolean switch `full` with `idx`; b) the inclusion of the `prev`
argument for shared computations; c) the removal of `eos`, `sos`, and `oov`
attributes; and d) replacing the more general signature of `hist`, ``(S, *)``, with
``(S, N)``. The former is strictly more powerful: the functionality of ``full=True``
is replicated by setting ``idx=None`` and ``full=False`` by setting ``idx=-1``. The
added functionality is intended to facilitate CTC decoding where prefixes stored in
`hist` may be of different lengths. b) generalizes LMs by allowing additional input
while also speeding up iterative computations. The removal of the `eos` and `sos`
was due to a lack of generalizability. `oov` was removed because the user probably
has to handle OOVs on her own when computing the loss.
See Also
--------
:ref:`lm`
For a tutorial on how to build and use a language model.
"""
# attribute names declared as constants for TorchScript when the module is scripted
__constants__ = ("vocab_size",)
# the vocabulary size; validated and assigned in __init__
vocab_size: int
def __init__(self, vocab_size: int):
    """Validate and record the vocabulary size

    Parameters
    ----------
    vocab_size
        The vocabulary size. It is passed through ``argcheck.is_posi`` before the
        parent initializer runs, and the checked value is stored on the module.
    """
    checked_vocab_size = argcheck.is_posi(vocab_size, "vocab_size")
    super().__init__()
    self.vocab_size = checked_vocab_size
@torch.jit.export
def update_input(
    self, prev: Dict[str, torch.Tensor], hist: torch.Tensor
) -> Dict[str, torch.Tensor]:
    """Prepare the state dictionary handed to the language model

    The base implementation performs no work and hands `prev` back untouched.
    Subclasses may override it to add whatever entries they need before any log
    probabilities are computed.

    Parameters
    ----------
    prev
        The initial `prev` dictionary passed prior to calculating any log
        probabilities.
    hist
        The initial `hist` tensor passed prior to calculating any log probabilities.

    Returns
    -------
    prev_ : Dict[str, torch.Tensor]
        The updated `prev`, populated with any additional information necessary to
        calculating log probabilities.

    Warnings
    --------
    This method should be robust to repeated calls prior to computing log
    probabilities. That is, the result of ``update_input(prev, hist)`` should
    be the same as ``update_input(update_input(prev, hist), hist)``.
    """
    return prev
def extra_repr(self) -> str:
    """Return the extra text for the module's repr: the vocabulary size"""
    return f"vocab_size={self.vocab_size}"
@abc.abstractmethod
def calc_idx_log_probs(
    self,
    hist: torch.Tensor,
    prev: Dict[str, torch.Tensor],
    idx: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Compute the distribution over types for the prefix ending before idx

    This is the abstract hook behind :func:`forward` when `idx` is given; see the
    class description for the meaning of the parameters and return values.
    :func:`forward` always hands over `idx` as a tensor, either a scalar (shape
    ``(,)``) or of shape ``(N,)``, whose values have already been normalized into
    ``[0, hist.size(0)]``. `prev` has also been passed through :func:`update_input`
    before this method is called.

    Parameters
    ----------
    hist
    prev
    idx

    Returns
    -------
    log_probs_idx : torch.Tensor
    next_ : Dict[str, torch.Tensor]

    Raises
    ------
    NotImplementedError
        Always, in this base class; subclasses must override.
    """
    raise NotImplementedError()
@torch.jit.export
def calc_full_log_probs(
    self, hist: torch.Tensor, prev: Dict[str, torch.Tensor]
) -> torch.Tensor:
    """Compute the distributions over types for every prefix of hist

    This is what :func:`forward` dispatches to when `idx` is not given; see the class
    description for the meaning of the parameters and return value. The default
    implementation calls :func:`calc_idx_log_probs` once for each index from ``0`` to
    ``hist.size(0)`` inclusive, threading the returned state into the next call, and
    stacks the results along a new leading dimension.

    Parameters
    ----------
    hist
    prev

    Returns
    -------
    log_probs : torch.Tensor
    """
    num_prefixes = hist.size(0) + 1
    per_prefix = []
    for prefix_idx in torch.arange(num_prefixes, device=hist.device):
        step_log_probs, prev = self.calc_idx_log_probs(hist, prev, prefix_idx)
        per_prefix.append(step_log_probs)
    return torch.stack(per_prefix, 0)
@overload
def forward(
    self,
    hist: torch.Tensor,
    prev: Dict[str, torch.Tensor] = dict(),
    *,
    idx: Union[int, torch.Tensor],
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    ...


@overload
def forward(
    self,
    hist: torch.Tensor,
    prev: Dict[str, torch.Tensor] = dict(),
    idx: Literal[None] = None,
) -> torch.Tensor:
    ...


@overload
def forward(
    self,
    hist: torch.Tensor,
    prev: Dict[str, torch.Tensor] = dict(),
    idx: Optional[Union[int, torch.Tensor]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
    ...


def forward(
    self,
    hist: torch.Tensor,
    prev: Optional[Dict[str, torch.Tensor]] = None,
    idx: Optional[Any] = None,
) -> Any:
    """Validate the inputs and dispatch to the full or indexed computation

    `hist` must be two-dimensional, ``(S, N)``. When `idx` is given it may be an
    integer or a tensor with either one element or ``N`` elements; its values must
    lie in ``[-S - 1, S]`` and negative values are wrapped modulo ``S + 1``. `prev`
    defaults to an empty dictionary and is passed through :func:`update_input`
    before either :func:`calc_full_log_probs` (no `idx`) or
    :func:`calc_idx_log_probs` (with `idx`) is called.

    Raises
    ------
    RuntimeError
        If `hist` is not two-dimensional, or `idx` is empty, of the wrong length, or
        out of range.
    """
    if prev is None:
        prev = dict()
    if hist.dim() != 2:
        raise RuntimeError("hist must be 2 dimensional")
    num_steps, batch_size = hist.size(0), hist.size(1)
    # placeholder so the name is defined on every path (needed for scripting); it is
    # only read when idx is given
    idx_tensor = torch.empty(0)
    if idx is not None:
        if isinstance(idx, int):
            idx_tensor = torch.as_tensor(idx, dtype=torch.long, device=hist.device)
        elif isinstance(idx, torch.Tensor):
            idx_tensor = idx
        if not idx_tensor.numel():
            raise RuntimeError("idx_ must be at least one element")
        if idx_tensor.dim() == 1:
            given = idx_tensor.size(0)
            if given == 1:
                # a one-element vector is treated like a scalar index
                idx_tensor = idx_tensor.squeeze(0)
            elif given != batch_size:
                raise RuntimeError(
                    f"Expected dim 0 of idx_ to be of size {batch_size}, got {given}"
                )
        too_large = idx_tensor > num_steps
        too_small = idx_tensor < -num_steps - 1
        if (too_small | too_large).any():
            raise RuntimeError(
                f"All values in idx_ must be between ({-num_steps - 1}, {num_steps})"
            )
        # wrap negative indices python-style over the S + 1 valid positions
        modulus = num_steps + 1
        idx_tensor = (idx_tensor + modulus) % modulus
    prev = self.update_input(prev, hist)
    if idx is not None:
        return self.calc_idx_log_probs(hist, prev, idx_tensor)
    else:
        return self.calc_full_log_probs(hist, prev)
[docs]
class MixableSequentialLanguageModel(
    ExtractableSequentialLanguageModel, metaclass=abc.ABCMeta
):
    """An ExtractableSequentialLanguageModel whose prev values can be mixed

    On top of everything :class:`ExtractableSequentialLanguageModel` offers, a
    :class:`MixableSequentialLanguageModel` can handle the case where two batches of
    histories, `hist_a` and `hist_b`, are merged into a single `new_hist` in which
    every path comes from either `hist_a` or `hist_b`. :func:`mix_by_mask` performs
    the matching merge on the state dictionaries.
    """

    @abc.abstractmethod
    def mix_by_mask(
        self,
        prev_true: Dict[str, torch.Tensor],
        prev_false: Dict[str, torch.Tensor],
        mask: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Populate a new prev by picking values from either of two others

        Suppose there are three batched path history tensors `hist_true`,
        `hist_false`, and `hist_new`, each of shape ``(S, N)`` (padded as needed when
        sequences differ in length), such that ``hist_new[:, n] = hist_true[:, n]``
        whenever ``mask[n] == True`` and ``hist_new[:, n] = hist_false[:, n]``
        otherwise. This method should perform the same selection between `prev_true`
        and `prev_false` to produce `prev_new`.

        Parameters
        ----------
        prev_true
            The input/output dictionary for the true branch of `mask`
        prev_false
            The input/output dictionary for the false branch of `mask`
        mask
            A boolean tensor of shape ``(N,)``

        Returns
        -------
        prev_new : dict

        Examples
        --------
        Continuing with the LSTM example from
        :class:`ExtractableSequentialLanguageModel`, the hidden states and cell
        states of the LSTM should always be the same size regardless of the remaining
        history, making the implementation trivial:

        >>> return {
        ...     'hidden_state': torch.where(
        ...         mask.unsqueeze(1),
        ...         prev_true['hidden_state'],
        ...         prev_false['hidden_state']),
        ...     'cell_state': torch.where(
        ...         mask.unsqueeze(1),
        ...         prev_true['cell_state'],
        ...         prev_false['cell_state']),
        ... }
        """
        raise NotImplementedError()
@script
def _lookup_calc_idx_log_probs(
    hist: torch.Tensor,
    hidx: torch.Tensor,
    offsets: torch.Tensor,
    ids: torch.Tensor,
    logps: torch.Tensor,
    logbs: torch.Tensor,
    sos: int,
    V: int,
    N: int,
    G: int,
    S: int,
) -> torch.Tensor:
    # Backoff n-gram lookup over the flattened reverse trie. For each of the
    # B = hist.size(1) batch elements, returns the log-probability of each of the V
    # types given the (at most N - 1)-token context ending just before position hidx.
    # The result has shape (B, V).
    #
    # hist:    token history; dim 0 is the sequence, dim 1 the batch
    # hidx:    prefix length(s) to condition on; must have 1 or B elements
    # offsets, ids, logps, logbs: the trie buffers (lengths are asserted below)
    # sos:     start-of-sequence id, used to left-pad short contexts
    # V:       vocabulary size
    # N:       maximum n-gram order
    # G:       number of highest-order n-gram nodes (logps has O + G entries)
    # S:       maximum number of direct descendants of any (non-root) node
    #
    # see commented description in LookupLanguageModel for more info on the structure
    # of offsets, ids, logps, and logbs
    #
    # Follow two paths: one the full n-gram and the other its (n-1)-gram prefix -- the n
    # and p paths respectively. If the p path completes, take the last logp. If the p
    # path fails, take the last logp it didn't fail at and start accumulating logbs from
    # the last on b path.
    #
    # Example: N = 4, N-gram = A B C D (note: unigrams always exist)
    #
    #   match: D -x
    #   back:  C -> B -> A          = P(D)B(C)B(B, C)B(A, B, C)
    #
    #   match: D -> C -x
    #   back:  C -> B -> A          = P(D|C)B(B, C)B(A, B, C)
    #
    #   match: D -> C -> B -x
    #   back:  C -> B -> A          = P(D|B,C)B(A, B, C)
    #
    #   match: D -> C -> B -> A
    #   back:  C -> B -> A          = P(D|A, B, C)
    #
    B: int = hist.size(1)
    # M: number of (batch element, candidate type) pairs; O: number of offsets
    M, O = B * V, offsets.numel()
    # shift is 1 when sos lies outside [0, V); such an sos is looked up under id V
    shift = 0 if (0 <= sos < V) else 1
    # number of level-1 slots: the unigrams plus (only when N > 1) one dummy node
    U = V + shift + (1 % N)
    # expected lengths of ids (I) and logps (P)
    I, P = O + G - U, O + G
    device = hist.device
    assert (ids.numel(), logps.numel(), logbs.numel()) == (I, P, O)
    if hidx.numel() == 0:
        raise RuntimeError("idx cannot be empty")
    last_logps = logps[:V]
    if N == 1:
        # a unigram model doesn't rely on hist at all, so we bypass all the crap below
        return last_logps.expand(B, V)
    hidx_min = int(hidx.min().item())
    # how many context tokens the shortest prefix is missing
    rem = (N - 1) - hidx_min
    if rem > 0:
        # N.B. Some models require padding to the full context width with SOSes; others
        # don't. In the latter case, padding should be harmless: the b path will always
        # hit said padding before the p path, yielding backoffs of 0
        hist = torch.cat(
            [torch.full((rem, B), sos, dtype=torch.long, device=device), hist]
        )
        hidx, hidx_min, rem = hidx + rem, hidx_min + rem, 0
    if hidx.numel() == 1:
        # hidx_min is hidx
        # rem <= 0 here, so this keeps the N - 1 tokens right before hidx
        hist = hist[-rem:hidx_min]
    else:
        # per-element window: select positions in (hidx - N, hidx), i.e. N - 1 tokens
        range_ = torch.arange(hist.size(0), device=device)
        mask = (hidx.unsqueeze(1) - N < range_) & (hidx.unsqueeze(1) > range_)
        hist = hist.T.masked_select(mask).view(B, N - 1).T
    assert hist.shape == (N - 1, B), (N - 1, B, hist.shape)
    if shift:
        # an out-of-vocabulary sos is remapped to id V before matching against ids
        hist = hist.masked_fill(hist.eq(sos), V)
    hist = hist.to(ids.dtype)
    vrange = torch.arange(V + 1, device=device, dtype=torch.long)
    hidx = torch.as_tensor(hidx, dtype=torch.long, device=device).expand(B)
    srange = vrange[:S]
    # desc holds the current trie node of each path: the first M entries are the n
    # paths (one per batch element and candidate type), the last B the p paths
    desc = torch.cat([vrange[:V].repeat(B), hist[-1].long()])  # (M + B,)
    last_logps = last_logps.repeat(B)  # (M,)
    last_backoffs = logbs[desc[M:]].repeat_interleave(V)  # (M,)
    found = torch.ones(M + B, device=device, dtype=torch.bool)
    for n in range(1, N):
        # the next label each path must match: n paths look one token further back
        # than they already matched; p paths are one step ahead (capped at the oldest)
        hist_n = torch.cat([hist[-n].repeat_interleave(V), hist[-min(n + 1, N - 1)]])
        desc_starts = offsets[desc].long() + desc  # (M + B,)
        desc_ends = offsets[desc + 1].long() + desc + 1  # (M + B,)
        # there can't be more than S direct descendants per node
        pos_desc = desc_starts.unsqueeze(1) + srange  # (M + B, S)
        extend_mask = desc_ends.unsqueeze(1) > pos_desc
        # ids has no entries for the first U nodes, hence the - U
        ids_ = ids[pos_desc.clamp_max(P - 1) - U]  # (M + B, S)
        extend_mask = extend_mask & (hist_n.unsqueeze(1) == ids_)
        found = extend_mask.any(1) & found  # (M + B,)
        # NOTE(review): the sum presumably relies on at most one child matching per
        # row (sibling ids being unique) -- confirm against the trie invariants
        desc = torch.where(found, pos_desc.masked_fill(~extend_mask, 0).sum(1), desc)
        logps_desc = logps[desc[:M]]
        if n == N - 1:
            # p paths can extend no further than N - 1 tokens: no new backoff
            cur_backoffs = torch.zeros_like(last_backoffs)
        else:
            cur_backoffs = (
                logbs[desc[M:].clamp_max(O - 1)]
                .masked_fill(~found[M:], 0.0)
                .repeat_interleave(V)
            )
        # Following Heafield's thesis, an infinite lprob indicates that this node is
        # invalid, but it has children which could be valid. In this case, we treat the
        # node as a backoff, but still treat the node as 'found'. That way, a later
        # found node could overwrite the logp value.
        clobber_logp = torch.isfinite(logps_desc) & found[:M]
        cur_logps = torch.where(
            clobber_logp, logps_desc, last_logps + cur_backoffs + last_backoffs
        )
        last_backoffs = cur_backoffs.masked_fill(~clobber_logp, 0.0)
        # only batch elements whose prefix is at least n tokens long take the update
        last_logps = torch.where(
            (hidx >= n).repeat_interleave(V), cur_logps, last_logps
        )
    return last_logps.view(B, V)
[docs]
class LookupLanguageModel(MixableSequentialLanguageModel):
r"""Construct a backoff n-gram model from a fixed lookup table
An instance of this model will search for a stored log-probability of the current
token given a fixed-length history in a lookup table. If it can't find it, it backs
off to a shorter length history and incurs a penalty:
.. math::
Pr(w_t|w_{t-1},\ldots,w_{t-(N-1)}) = \begin{cases}
Entry(w_{t-(N-1)}, w_{t-(N-1)+1}, \ldots, w_t)
& \text{if } Entry(w_{t-(N-1)}, \ldots) > 0 \\
Backoff(w_{t-(N-1)}, \ldots, w_{t-1}) Pr(w_t|w_{t-1},\ldots,w_{t-(N-1)+1}) &
\text{else}
\end{cases}
Missing entries are assumed to have value 0 and missing backoff penalties are
assumed to have value 1.
Parameters
----------
vocab_size
sos
The start of sequence token. If specified, any prefix with fewer tokens than the
maximum order of n-grams minus 1 will be prepended up to that length with this
token.
prob_dicts
A list of dictionaries whose entry at index ``i`` corresponds to a table of
``i+1``-gram log-probabilities. Keys must all be ids, not strings. Unigram keys
are just ids; for n > 1 keys are tuples of ids with the latest word last. Values
in the dictionary of the highest order n-gram dictionaries (last in
`prob_dicts`) are the log-probabilities of the keys. Lower order dictionaries'
values are pairs of log-probability and log-backoff penalty. If `prob_dicts` is
not specified, a unigram model with a uniform prior will be built.
destructive
If :obj:`True`, allows initialization to modify `prob_dicts` directly instead of
making a fresh copy. Doing so can help reduce memory pressure.
logger
If specified, this logger will be used to report on the progress initializing
this module.
Warnings
--------
This class differs considerably from its `0.3.0` version. `prob_list` was renamed to
`prob_dicts`; `prob_list` is deprecated. `sos` became no longer optional.
`pad_sos_to_n` was removed as an argument (implicitly true now). `eos` and `oov`
were also removed as part of updates to :obj:`SequentialLanguageModel`. Finally, the
underlying buffers of this model have changed in structure and name, invalidating
any old saved state dictionaries.
JIT scripting is possible with this module, but not tracing.
Notes
-----
Initializing an instance from a `prob_dicts` is expensive. `prob_dicts` is
converted to a reverse trie (something like [heafield2011]_) so that it takes up
less space in memory, which can take some time.
Rather than re-initializing repeatedly, it is recommended you save and load this
module's state dict. :func:`load_state_dict` has been overridden to support loading
different table sizes, avoiding the need for an accurate `prob_dicts` on
initialization:
>>> # first time
>>> lm = LookupLanguageModel(vocab_size, sos, prob_dicts) # slow
>>> state_dict = lm.state_dict()
>>> # save state dict, quit, startup, then reload state dict
>>> lm = LookupLanguageModel(vocab_size, sos) # fast!
>>> lm.load_state_dict(state_dict)
See Also
--------
SequentialLanguageModel
A general description of language models, including call parameters
pydrobert.util.parse_arpa_lm
How to read a pretrained table of n-gram probabilities into `prob_dicts`. The
parameter `token2id` should be specified to ensure id-based keys.
"""
# attribute names declared as constants for TorchScript when the module is scripted
__constants__ = (
    "vocab_size",
    "sos",
    "max_ngram",
    "max_ngram_nodes",
    "max_direct_descendants",
)
# the start-of-sequence id passed at construction
sos: int
# the highest n-gram order stored (1 for a unigram-only model)
max_ngram: int
# the number of nodes of the highest order
max_ngram_nodes: int
# the largest number of direct descendants of any non-root trie node
max_direct_descendants: int
# we follow [heafield2011] and earlier systems by constructing a reverse trie. E.g.
# if we have 3-grams for ('A', 'B', 'C'), ('A', 'B', 'D'), and ('B', 'B', 'C'), then
# part of our trie will be something like
#
# root -> {C -> B -> {A, B}}, {D -> B -> A}
#
# Unigram probabilities and unigram backoffs for "C" and "D" will be stored with the
# direct descendants of the root, probs/backoffs for the bigrams ('B', 'C') and
# ('B', 'D') with the level-2 nodes, and the trigram probs/backoffs with level 3.
#
# The label, log probability, and backoffs (for lower-order nodes) are stored in
# flat buffers "ids", "logps", and "logbs" respectively, with values accessible by
# index. Indices satisfy the following invariants:
#
# 1. The direct descendants of the root are the entire vocabulary, sorted by id
# 2. All nodes of level n occur before (in index) those of level n + 1
# 3. Direct descendants of a node are sorted by id
# 4. If two nodes i and j are on the same level and i < j, then all the direct
# descendants of i (and thus all descendants of i) occur before those of j
#
# Unigram nodes always occupy the first `vocab_size` indices. When the maximal order
# of the trie is > 1, the trie may be navigated with the buffer `offsets`. The value
# ``offsets[idx]`` contains the offset from the node indexed at `idx` to the first
# of its direct descendants, inclusive, i.e. ``child_idx = idx + offsets[idx]``.
# Since its sibling's descendants start at ``idx + 1 + offsets[idx
# + 1]``, that is also the exclusive upper bound of the direct descendants of `idx`.
# To make computations convenient, the final indexed node of each level is a dummy
# node which points to the dummy node in the subsequent level, ensuring
# ``offsets[idx + 1]`` exists for all real internal nodes.
#
# Any non-existent prefix of an n-gram will be assigned 0 probability, ensuring the
# number of unique prefixes of order n always matches the number of n-grams. offsets
# thus has the structure:
#
# offsets = [1-gram offsets + dummy] + [2-gram offsets + dummy] + ...
# [(max_order-1)-gram offsets + dummy] with a
# combined length of
#
# (# offsets) = max_order - 1 + sum_n^{(N-1)} (# n-grams)
#
# We don't need offsets for the max-order n-grams because they are leaves.
#
# "logb" has the same structure, but contains backoff log probabilities instead of
# offsets:
#
# logb = [1-gram backoffs + dummy] + [2-gram backoffs + dummy] + ...
# [(max_order-1)-gram backoffs + dummy]
#
# "logp" contains the log probabilities, including the maximal order:
#
# logp = [1-gram log-probs + dummy] + [2-gram log-probs + dummy] + ...
# [max_order-gram log-probs]
#
# with combined length of
#
# (# logps) = (# offsets) + (# max_order n-grams)
#
# no need to keep the final dummy.
#
# Finally, ids contains the labels of the nodes. It has structure
#
# ids = [2-gram ids + dummy] + [3-gram ids + dummy] + ... [max_order-gram ids]
#
# with combined length of
#
# (# ids) = (# logps) - (# 1-grams) - 1
#
# 1-gram ids can be easily inferred by invariant 1. This implies ``ids[idx]`` is
# actually the label of the node at ``idx + (vocab_size + 1)`` otherwise.
@overload
def __init__(
    self,
    vocab_size: int,
    sos: int,
    prob_dicts: Optional[ProbDicts] = None,
    destructive: bool = False,
    logger: Optional[Logger] = None,
):
    ...


def __init__(
    self,
    vocab_size: int,
    sos: int,
    prob_dicts: Optional[ProbDicts] = None,
    destructive: bool = False,
    logger: Optional[Logger] = None,
    *,
    prob_list: Optional[ProbDicts] = None,
):
    """Build the lookup buffers, either uniform or from prob_dicts

    See the class description for the public parameters. `prob_list` is the
    deprecated keyword-only alias of `prob_dicts`: passing it alone issues a
    :class:`DeprecationWarning`; passing both raises a :class:`ValueError`.

    Without `prob_dicts` a uniform unigram model is created. Otherwise the trie
    buffers are produced by :func:`_build_trie`. In both cases the buffers are
    registered as ``logps``, ``logbs``, ``ids``, and ``offsets``.
    """
    sos = argcheck.is_int(sos, "sos")
    destructive = argcheck.is_bool(destructive)
    if prob_list is not None:
        if prob_dicts is not None:
            raise ValueError(
                "prob_list and prob_dicts cannot be specified simultaneously"
            )
        warnings.warn("prob_list has been renamed to prob_dicts", DeprecationWarning)
        prob_dicts = prob_list
    super().__init__(vocab_size)
    self.sos = sos
    if prob_dicts is not None:
        self.max_ngram = len(prob_dicts)
        self.max_ngram_nodes = -1  # changed by build_trie
        logps, logbs, ids, offsets = self._build_trie(prob_dicts, destructive, logger)
        self.max_direct_descendants = self._infer_max_direct_descendants(offsets)
    else:
        if logger is not None:
            logger.info("prob_dicts is empty; initializing uniform model")
        num_unigrams = self.shift + vocab_size
        # every entry is -log(vocab_size)
        logps = -torch.full((num_unigrams,), vocab_size, dtype=torch.float).log()
        logbs = torch.tensor([], dtype=torch.float)
        # a unigram model has no internal nodes: ids and offsets share one empty tensor
        ids = offsets = torch.tensor([], dtype=torch.uint8)
        self.max_ngram = 1
        self.max_ngram_nodes = num_unigrams
        self.max_direct_descendants = 0
    for buffer_name, buffer in (
        ("logps", logps),
        ("logbs", logbs),
        ("ids", ids),
        ("offsets", offsets),
    ):
        self.register_buffer(buffer_name, buffer)
def extra_repr(self) -> str:
    """Extend the parent's repr text with the maximum n-gram order and sos id"""
    base = super(LookupLanguageModel, self).extra_repr()
    return f"{base}, max_ngram={self.max_ngram}, sos={self.sos}"
@torch.jit.export
def extract_by_src(
    self, prev: Dict[str, torch.Tensor], src: torch.Tensor
) -> Dict[str, torch.Tensor]:
    """Return a fresh, empty state dictionary

    Neither `prev` nor `src` is consulted: this implementation always returns an
    empty dictionary.
    """
    empty_state: Dict[str, torch.Tensor] = {}
    return empty_state
@torch.jit.export
def mix_by_mask(
    self,
    prev_true: Dict[str, torch.Tensor],
    prev_false: Dict[str, torch.Tensor],
    mask: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Return a fresh, empty state dictionary

    None of the arguments is consulted: this implementation always returns an empty
    dictionary.
    """
    empty_state: Dict[str, torch.Tensor] = {}
    return empty_state
@property
def shift(self) -> int:
    """1 if `sos` lies outside ``[0, vocab_size)`` and 0 otherwise"""
    if 0 <= self.sos < self.vocab_size:
        return 0
    return 1
def calc_idx_log_probs(
    self,
    hist: torch.Tensor,
    prev: Dict[str, torch.Tensor],
    idx: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Look up the distribution over types for the prefix ending before idx

    The work is delegated to :func:`_lookup_calc_idx_log_probs`, which receives the
    trie buffers and the module's constants. `prev` is not used in the computation
    and is returned unchanged as the next state.
    """
    log_probs_idx = _lookup_calc_idx_log_probs(
        hist,
        idx,
        self.offsets,
        self.ids,
        self.logps,
        self.logbs,
        self.sos,
        self.vocab_size,
        self.max_ngram,
        self.max_ngram_nodes,
        self.max_direct_descendants,
    )
    return log_probs_idx, prev
@torch.jit.export
def calc_full_log_probs(
    self,
    hist: torch.Tensor,
    prev: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Compute the distributions for all prefixes, one position at a time

    Equivalent to :func:`calc_full_log_probs_chunked` with a chunk size of 1.
    """
    return self.calc_full_log_probs_chunked(hist, prev, 1)
@torch.jit.export
def calc_full_log_probs_chunked(
    self,
    hist: torch.Tensor,
    prev: Dict[str, torch.Tensor],
    chunk_size: int = 1,
) -> torch.Tensor:
    """Computes full log probabilities in chunks

    This method has the same interpretation and return value as
    :func:`calc_full_log_probs`, but with an additional optional argument
    `chunk_size` to control the number of distributions over tokens to compute
    simultaneously.

    Because the distribution over the current token does not depend on any prior
    state, it is possible to compute all token distributions simultaneously. While
    faster, it is also much more memory-intensive to do so (especially so for large
    vocabularies). `chunk_size` provides a lever for this trade-off. Note that
    the computation of token distributions is always parallelized across the
    batch dimension.

    Parameters
    ----------
    hist
        A tensor of shape ``(T, B)``.
    prev
        Passed as-is to every :func:`calc_idx_log_probs` call.
    chunk_size
        The maximum number of sequence positions handled per
        :func:`calc_idx_log_probs` call once a full-width context is available.
        Must be at least 1.

    Returns
    -------
    log_probs : torch.Tensor
        A tensor whose first dimension is of size ``T + 1``.

    Raises
    ------
    RuntimeError
        If `chunk_size` is less than 1.
    """
    T, B = hist.shape
    N, V = self.max_ngram, self.vocab_size
    # Nm1: the context width actually available, at most N - 1 tokens
    Nm1, device = min(T, N - 1), hist.device
    # as_strided below assumes the row-major (T, B) layout
    hist = hist.contiguous()
    hist_ = idx_ = log_probs_ = hist  # for torchscript
    if chunk_size < 1:
        raise RuntimeError(f"expected chunk_size to be positive; got {chunk_size}")
    # empty seed of shape (0, B, V) for the final concatenation
    log_probs = [torch.empty(0, B, V, device=device)]
    # positions with fewer than Nm1 tokens of history are handled one at a time
    for idx_ in torch.arange(Nm1, device=hist.device):
        log_probs_ = self.calc_idx_log_probs(hist[:idx_], prev, idx_)[0]
        log_probs.append(log_probs_.unsqueeze(0))
    if Nm1 < T + 1:
        idx_ = torch.tensor(Nm1, dtype=torch.long, device=device)
        for t in range(Nm1, T + 1, chunk_size):
            # number of positions in this chunk (the last chunk may be short)
            T_rest = min(chunk_size, T + 1 - t)
            # sliding windows folded into the batch dimension:
            # hist_[k, r * B + b] == hist[t - Nm1 + k + r, b], so column r * B + b
            # holds the Nm1 tokens preceding position t + r of batch element b
            hist_ = hist.as_strided((Nm1, T_rest * B), (B, 1), B * (t - Nm1))
            log_probs_ = self.calc_idx_log_probs(hist_, prev, idx_)[0]
            # unfold the (T_rest * B, V) result back to (T_rest, B, V)
            log_probs_ = log_probs_.view(T_rest, B, V)
            log_probs.append(log_probs_)
    log_probs_ = torch.cat(log_probs, 0)
    assert log_probs_.size(0) == T + 1, (log_probs_.size(0), T + 1)
    return log_probs_
def load_state_dict(self, state_dict: dict, **kwargs) -> None:
    """Load trie buffers of any size, re-deriving the module's constants

    Before deferring to the parent implementation, this method checks that the four
    buffers are present and re-derives `max_ngram`, `max_ngram_nodes`, and
    `max_direct_descendants` from the incoming ``offsets``. It then replaces the
    module's buffers with uninitialized tensors shaped like the incoming ones, on
    the module's current devices, so that differently-sized tables can be loaded.

    Parameters
    ----------
    state_dict
        Must contain the keys ``"offsets"``, ``"ids"``, ``"logps"``, and ``"logbs"``.
    **kwargs
        Forwarded to the parent's ``load_state_dict``.

    Raises
    ------
    RuntimeError
        If a buffer is missing or the buffers are inconsistent with the module's
        `vocab_size` and `sos`.
    """
    error_prefix = f"Error(s) in loading state_dict for {self.__class__.__name__}:\n"
    missing_keys = {"offsets", "ids", "logps", "logbs"} - set(state_dict)
    if missing_keys:
        joined = '", "'.join(missing_keys)
        raise RuntimeError(f'Missing key(s) in state_dict: "{joined}".')
    offsets, ids = state_dict["offsets"], state_dict["ids"]
    logps, logbs = state_dict["logps"], state_dict["logbs"]
    if not (ids.numel() and offsets.numel()):
        # n == 1
        if len(offsets) != len(ids):
            raise RuntimeError(error_prefix + "Incompatible trie buffers")
        expected = self.vocab_size + self.shift
        if len(logps) != expected:
            raise RuntimeError(
                error_prefix
                + f"Expected {expected} unigram probabilities, got {len(logps)} "
                "(vocab_size and sos must be correct!)"
            )
        self.max_ngram_nodes = self.vocab_size + self.shift
        self.max_ngram = 1
    else:
        # n > 1
        num_level1 = self.vocab_size + self.shift + 1  # unigrams + dummy
        if len(offsets) < num_level1:
            raise RuntimeError(
                error_prefix
                + f"Expected {num_level1 - 1} unigram probabilities, "
                f"got {len(offsets) - 1} "
                "(vocab_size and sos must be correct!)"
            )
        num_offsets, _num_ids, _num_logps = len(offsets), len(ids), len(logps)
        self.max_ngram = 1
        self.max_ngram_nodes = dummy_ptr = num_level1 - 1
        # hop from each level's dummy node to the next level's; every hop is one more
        # n-gram order, and the hop length is that level's node count plus one
        while dummy_ptr < len(offsets):
            hop = offsets[dummy_ptr].item()
            if hop <= 0:
                raise RuntimeError(
                    error_prefix + "buffer contains unexpected value (are you sure "
                    "you've set vocab_size and sos correctly?)"
                )
            dummy_ptr += hop
            self.max_ngram_nodes = hop - 1
            self.max_ngram += 1
        if dummy_ptr != num_offsets + self.max_ngram_nodes:
            raise RuntimeError(error_prefix + "Unexpected buffer length")
    self.max_direct_descendants = self._infer_max_direct_descendants(offsets)
    # resize
    self.offsets = torch.empty_like(offsets, device=self.offsets.device)
    self.ids = torch.empty_like(ids, device=self.ids.device)
    self.logps = torch.empty_like(logps, device=self.logps.device)
    self.logbs = torch.empty_like(logbs, device=self.logbs.device)
    return super(LookupLanguageModel, self).load_state_dict(state_dict, **kwargs)
@torch.jit.unused
def _build_trie(
    self, prob_dicts: ProbDicts, destructive: bool, logger: Optional[Logger]
):
    """Convert n-gram dictionaries into the flat reverse-trie buffers

    Reads ``self.vocab_size``, ``self.sos``, ``self.shift``, and ``self.max_ngram``
    (which must already equal ``len(prob_dicts)``) and sets ``self.max_ngram_nodes``
    to the number of entries of the highest order.

    Parameters
    ----------
    prob_dicts
        One dictionary per n-gram order, lowest order first. Missing unigrams and
        missing suffixes of higher-order n-grams are filled in with a
        log-probability of ``-inf`` (and a log-backoff of ``0.0``).
    destructive
        If false, each dictionary is shallow-copied first. If true, the passed list
        and dictionaries are modified and emptied as the trie is built.
    logger
        If given, progress is reported through ``logger.info``.

    Returns
    -------
    logps, logbs, ids, offsets : torch.Tensor
        The trie buffers described in the comments of the class body.

    Raises
    ------
    ValueError
        If `prob_dicts` is empty, its last dictionary is empty, or a key is of the
        wrong length or contains ids outside the vocabulary.
    """
    if logger is None:
        print_ = lambda x: None
    else:
        print_ = logger.info
    if not len(prob_dicts):
        raise ValueError("prob_dicts must contain at least unigrams")
    if not destructive:
        print_("destructive not passed; copying prob_dicts")
        prob_dicts = [prob_dict.copy() for prob_dict in prob_dicts]
    total_entries, nan, inf = 0, float("nan"), float("inf")
    unigrams = set(range(self.vocab_size))
    if self.shift:
        unigrams.add(self.sos)
    # validate from the highest order down so that suffixes added to the next lower
    # order are themselves validated (and counted) on a later iteration
    for n in range(self.max_ngram - 1, -1, -1):
        print_(f"checking prob_dict of order {n + 1}")
        prob_dict = prob_dicts[n]
        is_last = n == self.max_ngram - 1
        if is_last and not prob_dict:
            raise ValueError("Final element in prob_dicts must not be empty")
        if not n:
            keys = set(prob_dict.keys())
            if keys - unigrams:
                raise ValueError(
                    "Unexpected unigrams in prob_dicts: {} (are these "
                    "ids?)".format(keys - unigrams)
                )
            # the highest order stores bare log-probs; lower orders store pairs
            if is_last:
                dummy_value = -inf
            else:
                dummy_value = -inf, 0.0
            prob_dict.update((key, dummy_value) for key in unigrams - keys)
        else:
            for seq in prob_dict:
                if len(seq) != n + 1:
                    # fixed: the key and the order were passed in swapped positions
                    raise ValueError(
                        "Key {0} in {1}-gram is not a sequence of length "
                        "{1}".format(seq, n + 1)
                    )
                if set(seq) - unigrams:
                    raise ValueError(
                        "Unexpected tokens in {}-gram in prob_dicts: {} ("
                        "are these ids?)"
                        "".format(n + 1, set(seq) - unigrams)
                    )
                # the suffix is this node's parent in the reverse trie, so it must
                # exist one order down
                suffix = seq[1:]
                if len(suffix) == 1:
                    suffix = suffix[0]
                if suffix not in prob_dicts[n - 1]:
                    print_(
                        f"{suffix} is not an entry in order {n} prob_dict but is a "
                        f"suffix of {seq}. Adding (-inf, 0.0)"
                    )
                    prob_dicts[n - 1][suffix] = -inf, 0.0
        total_entries += len(prob_dict)
        if is_last:
            self.max_ngram_nodes = len(prob_dict)
    if self.shift:
        # an out-of-vocabulary sos is stored under the id just past the vocabulary
        print_(f"mapping sos={self.sos} -> {self.vocab_size}")
        prob_dicts[0][self.vocab_size] = prob_dicts[0].pop(self.sos)
        for n in range(1, self.max_ngram):
            sos_keys = []
            prob_dict = prob_dicts[n]
            # collect first: the dictionary cannot be re-keyed while iterating it
            for key in prob_dict.keys():
                if self.sos in key:
                    sos_keys.append(key)
            while len(sos_keys):
                key = sos_keys.pop()
                key_ = tuple(self.vocab_size if k == self.sos else k for k in key)
                prob_dict[key_] = prob_dict.pop(key)
    N, G, V = self.max_ngram, self.max_ngram_nodes, self.vocab_size
    # U = # unigrams + dummy, O = # offsets or # logbs
    # I = # ids, P = # logps
    U, O = V + self.shift + (1 % N), total_entries - G + (N - 1)
    I, P = O + G - U, O + G
    # number of real unigram nodes (U without the dummy, which exists only if N > 1)
    num_unigrams = V + self.shift
    if N > 1:
        # what's the maximum possible offset? It's the maximal possible distance
        # between a parent and child, or an n-gram and an (n-1)-gram. Let the former
        # have S nodes in the level, the latter T nodes. Let a, b, and c correspond
        # to offsets of distinct paths through the trie and x be the dummy offset.
        # The longest offset in offsets is produced as a value of b like this:
        #
        #   abcccc...cxaaaa...bx
        #
        # i.e. there are a lot of branches of a in (n+1) but only one parent, and
        # there are a lot of branches of c in n but no descendants. The hop from b
        # to x is of size S - 1, and the hop from x to the next b is of size T, so
        # the worst potential hop is S + T - 1
        max_potential_offset = max(
            len(prob_dicts[n]) + len(prob_dicts[n - 1]) - 1 for n in range(1, N)
        )
    else:
        max_potential_offset = 0  # no descendants
    offset_type = offset_type_ = int  # for type checker
    offset_imax = float("inf")
    # pick the narrowest integer type able to hold the largest possible offset
    for offset_type, offset_type_ in (
        (torch.uint8, np.uint8),
        (torch.int16, np.int16),
        (torch.int32, np.int32),
        (torch.int64, np.int64),
    ):
        offset_imax = torch.iinfo(offset_type).max
        if offset_imax >= max_potential_offset:
            break
    if torch.iinfo(offset_type).max < max_potential_offset:
        # should not happen
        raise ValueError("too many childen")
    for id_type in (torch.uint8, torch.int16, torch.int32, torch.int64):
        if torch.iinfo(id_type).max >= U:
            break
    if torch.iinfo(id_type).max < U:
        # should never happen in a practical situation
        raise ValueError("vocab too large")
    print_("Allocating 1-grams")
    offsets = torch.zeros(O, dtype=offset_type)
    ids = torch.zeros(I, dtype=id_type)
    logps = torch.zeros(P, dtype=torch.float)
    logbs = torch.zeros(O, dtype=torch.float)
    prob_dict = prob_dicts.pop(0)
    unigram_values = [prob_dict[x] for x in range(num_unigrams)]
    last_start, allocated = 0, num_unigrams
    if N == 1:
        logps.copy_(torch.tensor(unigram_values, dtype=torch.float))
    else:
        logps[:allocated].copy_(
            torch.tensor([x[0] for x in unigram_values], dtype=torch.float)
        )
        logbs[:allocated].copy_(
            torch.tensor([x[1] for x in unigram_values], dtype=torch.float)
        )
    del unigram_values
    # parents maps a (reversed) key of the previous level to the position of its node
    # relative to last_start. It is only needed when there are higher orders; for a
    # unigram-only model the offset type is the narrowest one, so casting unigram
    # ids of 256 or more to it would overflow.
    if N > 1:
        parents = dict(((x,), offset_type_(x)) for x in range(U - 1))
    else:
        parents = dict()
    N = 2
    while prob_dicts:
        prob_dict = prob_dicts.pop(0)
        # the dummy node closing the previous level sits at `start`
        start = allocated
        offsets[allocated] = len(prob_dict) + 1
        logps[allocated] = logbs[allocated] = nan
        allocated += 1
        children = dict()
        print_(f"Sorting {N}-grams")
        # keys are reversed (latest token first) so that sorting groups the nodes by
        # trie path; the source dictionary is drained as the sorted list fills
        prob_list = SortedList()
        while prob_dict:
            key, value = prob_dict.popitem()
            insort_left(prob_list, (key[::-1], value))
        print_(f"Allocating {N}-grams")
        while prob_list:
            key, value = prob_list.pop(0)
            assert 0 <= (allocated - start) <= offset_imax
            children[key] = offset_type_(allocated - start)
            ids[allocated - U] = int(key[-1])
            if prob_dicts:
                logps[allocated] = float(value[0])
                logbs[allocated] = float(value[1])
            else:
                # the highest order stores log-probabilities only
                logps[allocated] = float(value)
            prefix = key[:-1]
            # convert to a python int before adding: arithmetic between a
            # fixed-width numpy scalar and a python int is performed in the scalar's
            # width under NumPy 2 and can overflow
            parent = int(parents[prefix]) + last_start
            # point this parent, and any childless nodes right before it, at the
            # newly allocated child
            while parent >= 0 and not offsets[parent]:
                offsets[parent] = allocated - parent
                parent -= 1
            allocated += 1
        # childless nodes at the end of the previous level get an empty range ending
        # where the dummy's range begins; stop at index 1 so that offsets[-1] is
        # never touched
        for i in range(start, 0, -1):
            if offsets[i - 1]:
                break
            offsets[i - 1] = offsets[i] + 1
        parents.clear()
        parents, last_start = children, start
        N += 1
    # see if we can shrink the offset size
    if len(offsets):
        max_offset = offsets.max().item()
        for offset_type in (torch.uint8, torch.int16, torch.int32, torch.int64):
            if torch.iinfo(offset_type).max >= max_offset:
                break
        print_(f"Updating offset dtype to {offset_type}")
        offsets = offsets.to(offset_type)
    return logps, logbs, ids, offsets
@torch.jit.unused
def _infer_max_direct_descendants(
    self, offsets: Optional[torch.Tensor] = None
) -> int:
    """Scan an offsets table for the widest adjacent-offset gap

    The table is walked block by block. The first block covers the indices below
    ``vocab_size + 1`` (plus one more when `sos` lies outside ``[0, vocab_size)``);
    each later block starts at some index ``i`` and spans ``int(offsets[i])``
    entries. Within every block the quantity ``offsets[k + 1] + 1 - offsets[k]``
    is computed for adjacent entries and the overall maximum is returned.

    NOTE(review): judging by the method name, that quantity is presumably the
    number of direct children of node ``k`` in the lookup trie (the root being
    excluded) -- confirm against the code which builds `offsets`.

    Parameters
    ----------
    offsets
        The table to scan. Defaults to ``self.offsets`` when :obj:`None`.

    Returns
    -------
    int
        The maximum found, or ``0`` if the table is empty.
    """
    table = self.offsets if offsets is None else offsets
    num_entries = table.numel()
    if num_entries == 0:
        return 0
    sos_in_vocab = 0 <= self.sos < self.vocab_size
    # the root is excluded from the count; the first block ends here
    first_block_end = self.vocab_size + (0 if sos_in_vocab else 1) + 1
    assert 0 < first_block_end <= num_entries
    widest = (
        table[1:first_block_end] + 1 - table[: first_block_end - 1]
    ).max()
    assert widest >= 0
    cursor = first_block_end
    while cursor < num_entries:
        # the entry at the head of a block stores that block's length
        block_end = cursor + int(table[cursor])
        gaps = table[cursor + 1 : block_end] + 1 - table[cursor : block_end - 1]
        widest = torch.max(widest, gaps.max())
        cursor = block_end
    assert widest < first_block_end, (widest, first_block_end)
    return int(widest.item())
# NOTE(review): `proxy` comes from the project's `._wrappers` module; presumably
# this re-exposes the signature/documentation of SequentialLanguageModel.forward
# on __call__ -- confirm against `proxy`'s definition.
__call__ = proxy(SequentialLanguageModel.forward)
[docs]
class ShallowFusionLanguageModel(SequentialLanguageModel):
    r"""Language model combining two language models with shallow fusion

    Shallow fusion [gulcehre2015]_ combines the predictions of two language models
    by taking the weighted sum of their log probabilities:

    .. math::

        \log S(y_t=v|...) = \log P_{first}(y_t=v|...) +
                                \beta \log P_{second}(y_t = v|...)

    The resulting value :math:`\log S(y_t=v)` is not technically a probability.

    Parameters
    ----------
    first
        The first language model
    second
        The second language model, whose log probabilities multiply with `beta`
    beta
        The value :math:`\beta`
    first_prefix
        Elements of the state dict for `first` will have `first_prefix` prepended to
        their keys
    second_prefix
        Like `first_prefix`, but for `second`

    Raises
    ------
    ValueError
        If `first` and `second` have different vocabulary sizes, if either prefix
        is empty, or if the two prefixes are identical.

    Warnings
    --------
    This class does not (and cannot) support JIT.

    If one prefix starts with the other (e.g. ``"lm."`` and ``"lm.2."``), keys of
    the merged state dict cannot always be attributed to the right language model.
    A warning is issued on construction in that case; prefer prefixes which do not
    overlap.

    Notes
    -----
    If you intend to perform shallow fusion between CTC logits and an external language
    model, you will not be able to do so via this class. CTC operates on an extended
    vocabulary while an external language model does not. Fortunately,
    :class:`CTCPrefixSearch` has built-in support for shallow fusion. Consult that
    class for more information.

    See Also
    --------
    MixableShallowFusionLanguageModel
        A mixable subclass of this class. Applicable only if `first` and `second`
        are both :class:`MixableSequentialLanguageModel` instances.
    ExtractableShallowFusionLanguageModel
        An extractable subclass of this class. Applicable only if `first` and `second`
        are both :class:`ExtractableSequentialLanguageModel` instances.
    """

    __constants__ = "beta", "first_prefix", "second_prefix"

    first: SequentialLanguageModel
    second: SequentialLanguageModel
    beta: float
    first_prefix: str
    second_prefix: str

    def __init__(
        self,
        first: SequentialLanguageModel,
        second: SequentialLanguageModel,
        beta: float = 0.0,
        first_prefix: str = "first.",
        second_prefix: str = "second.",
    ):
        beta = argcheck.is_float(beta, "beta")
        first_prefix = argcheck.is_str(first_prefix, "first_prefix")
        second_prefix = argcheck.is_str(second_prefix, "second_prefix")
        if first.vocab_size != second.vocab_size:
            raise ValueError(
                f"first's vocab_size ({first.vocab_size}) differs from second's "
                f"vocab_size ({second.vocab_size})"
            )
        if not first_prefix or not second_prefix:
            raise ValueError("prefixes cannot be empty")
        if first_prefix == second_prefix:
            raise ValueError(f"first_prefix matches second_prefix ('{first_prefix}')")
        if first_prefix.startswith(second_prefix) or second_prefix.startswith(
            first_prefix
        ):
            # split_dicts routes keys with startswith(), checking first_prefix
            # first; overlapping prefixes can therefore send a key to the wrong
            # model. Warn rather than raise to stay backward-compatible.
            warnings.warn(
                f"first_prefix ('{first_prefix}') and second_prefix "
                f"('{second_prefix}') overlap: one starts with the other, so state "
                "keys may be attributed to the wrong language model"
            )
        super().__init__(first.vocab_size)
        self.first, self.second, self.beta = first, second, beta
        self.first_prefix, self.second_prefix = first_prefix, second_prefix

    def extra_repr(self) -> str:
        """Extend the parent's representation with fusion-specific settings"""
        return super().extra_repr() + (
            f", beta={self.beta}, first_prefix='{self.first_prefix}'"
            f", second_prefix='{self.second_prefix}'"
            f", first={self.first}, second={self.second}"
        )

    @torch.jit.export
    def split_dicts(
        self, prev: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Split state dicts into state dicts for first and second lms

        Each key of `prev` must start with `first_prefix` or `second_prefix`
        (checked in that order). The matching prefix is stripped from the key
        before it is stored in the corresponding output dict.

        Raises
        ------
        RuntimeError
            If a key starts with neither prefix.
        """
        prev_first: Dict[str, torch.Tensor] = {}
        prev_second: Dict[str, torch.Tensor] = {}
        for k, v in prev.items():
            if k.startswith(self.first_prefix):
                prev_first[k[len(self.first_prefix) :]] = v
            elif k.startswith(self.second_prefix):
                prev_second[k[len(self.second_prefix) :]] = v
            else:
                raise RuntimeError(
                    f"key '{k}' from prev does not start with first_prefix "
                    f"'{self.first_prefix}' nor second_prefix '{self.second_prefix}'"
                )
        return prev_first, prev_second

    @torch.jit.export
    def merge_dicts(
        self, prev_first: Dict[str, torch.Tensor], prev_second: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Merge state dicts from first and second lms into state dict

        Keys of `prev_first` are prepended with `first_prefix` and keys of
        `prev_second` with `second_prefix`; this is the inverse of
        :func:`split_dicts`.
        """
        prev: Dict[str, torch.Tensor] = {}
        prev.update((self.first_prefix + k, v) for (k, v) in prev_first.items())
        prev.update((self.second_prefix + k, v) for (k, v) in prev_second.items())
        return prev

    @torch.jit.export
    def update_input(
        self, prev: Dict[str, torch.Tensor], hist: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Delegate to ``update_input`` of both lms and merge the results"""
        prev_first, prev_second = self.split_dicts(prev)
        prev_first = self.first.update_input(prev_first, hist)
        prev_second = self.second.update_input(prev_second, hist)
        return self.merge_dicts(prev_first, prev_second)

    @torch.jit.export
    def calc_idx_log_probs(
        self, hist: torch.Tensor, prev: Dict[str, torch.Tensor], idx: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Fuse the ``calc_idx_log_probs`` results of both lms

        Returns ``first + beta * second`` of the two models' log probabilities
        alongside the merged updated states of both models.
        """
        prev_first, prev_second = self.split_dicts(prev)
        log_probs_first, cur_first = self.first.calc_idx_log_probs(
            hist, prev_first, idx
        )
        log_probs_second, cur_second = self.second.calc_idx_log_probs(
            hist, prev_second, idx
        )
        log_probs = log_probs_first + self.beta * log_probs_second
        cur = self.merge_dicts(cur_first, cur_second)
        return log_probs, cur

    @torch.jit.export
    def calc_full_log_probs(
        self, hist: torch.Tensor, prev: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Fuse the ``calc_full_log_probs`` results of both lms

        Returns ``first + beta * second`` of the two models' log probabilities.
        """
        prev_first, prev_second = self.split_dicts(prev)
        log_probs_first = self.first.calc_full_log_probs(hist, prev_first)
        log_probs_second = self.second.calc_full_log_probs(hist, prev_second)
        return log_probs_first + self.beta * log_probs_second
[docs]
class MixableShallowFusionLanguageModel(
    ExtractableShallowFusionLanguageModel, MixableSequentialLanguageModel
):
    """ShallowFusionLanguageModel which is also a MixableSequentialLanguageModel

    Both `first` and `second` must be :class:`MixableSequentialLanguageModel`
    instances.

    See Also
    --------
    ShallowFusionLanguageModel
        For a description of shallow fusion and parameters. `first` and `second` may
        not be mixable, but neither is :class:`ShallowFusionLanguageModel`.
    ExtractableShallowFusionLanguageModel
        An extractable superclass of this class. Applicable if `first` and `second`
        are both :class:`ExtractableSequentialLanguageModel` instances.
    """

    first: MixableSequentialLanguageModel
    second: MixableSequentialLanguageModel

    def __init__(
        self,
        first: MixableSequentialLanguageModel,
        second: MixableSequentialLanguageModel,
        beta: float = 0.0,
        first_prefix: str = "first.",
        second_prefix: str = "second.",
    ):
        # Only narrows the annotations of `first` and `second`; all arguments are
        # forwarded unchanged to the parent initializer.
        super().__init__(first, second, beta, first_prefix, second_prefix)

    @torch.jit.export
    def mix_by_mask(
        self,
        prev_true: Dict[str, torch.Tensor],
        prev_false: Dict[str, torch.Tensor],
        mask: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Mix the states of both lms by delegating to their ``mix_by_mask``

        `prev_true` and `prev_false` are each split into the states of `first` and
        `second`. Each model mixes its own pair of states with the same `mask`, and
        the two mixed states are merged back into a single state dict.
        """
        prev_first_true, prev_second_true = self.split_dicts(prev_true)
        prev_first_false, prev_second_false = self.split_dicts(prev_false)
        prev_first = self.first.mix_by_mask(prev_first_true, prev_first_false, mask)
        prev_second = self.second.mix_by_mask(prev_second_true, prev_second_false, mask)
        return self.merge_dicts(prev_first, prev_second)