# Copyright 2022 Sean Robertson
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Union, overload
import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from . import argcheck
from ._compat import script, trunc_divide
[docs]
@script
def simple_random_sampling_without_replacement(
total_count: torch.Tensor,
given_count: torch.Tensor,
out_size: Optional[int] = None,
) -> torch.Tensor:
"""Draw a binary vector with uniform probabilities but fixed cardinality
Uses the algorithm proposed in [fan1962]_.
Parameters
----------
total_count
The nonnegative sizes of the individual binary vectors. Must broadcast with
`given_count`.
given_count
The cardinalities of the individual binary vectors. Must broadcast with
and not exceed the values of `total_count`.
out_size
The vector size. Must be at least the value of ``total_count.max()``. If unset,
will default to that value.
Returns
-------
b : torch.Tensor
A sample tensor of shape ``(*, out_size)``, where ``(*,)`` is the broadcasted
shape of `total_count` and `given_count`. The final dimension is the vector
dimension. The ``n``-th vector of `b` is right-padded with zero for all values
exceeding ``total_count[n]``, i.e. ``b[n, total_count[n]:].sum() == 0``. The
remaining ``total_count[n]`` elements of the vector sum to associated given
count, i.e. ``b[n, :total_count[n]].sum() == given_count[n]``.
See Also
--------
pydrobert.torch.distributions.SimpleRandomSamplingWithoutReplacement
For information on the distribution.
"""
total_count_max = int(total_count.max().item())
if out_size is None:
out_size = total_count_max
total_count, given_count = torch.broadcast_tensors(total_count, given_count)
if (given_count > total_count).any():
raise RuntimeError("given_count cannot exceed total_count")
if out_size < total_count_max:
raise RuntimeError(
f"out_size ({out_size}) must not be less than max of total_count "
f"({total_count_max})"
)
b = torch.empty(
torch.Size([out_size]) + total_count.shape, device=total_count.device
)
remainder_ell = given_count
remainder_t = total_count.clamp_min(1)
for t in range(out_size):
p = remainder_ell / remainder_t
b_t = torch.bernoulli(p)
b[t] = b_t
remainder_ell = remainder_ell - b_t
remainder_t = (remainder_t - 1).clamp_min_(1)
return b.view(out_size, -1).T.view(total_count.shape + torch.Size([out_size]))
[docs]
class BinaryCardinalityConstraint(constraints.Constraint):
"""Ensures a vector of binary values sums to the required cardinality"""
is_discrete = True
event_dim = 1
def __init__(
self,
given_count: torch.Tensor,
tmax: int,
total_count: Optional[torch.Tensor] = None,
) -> None:
tmax = argcheck.is_nat(tmax, "tmax")
given_count = argcheck.is_nonnegt(given_count, "given_count")
if total_count is None:
total_count_mask = torch.zeros(
1, dtype=torch.bool, device=given_count.device
)
else:
total_count = argcheck.is_nonnegt(total_count, "total_count")
total_count_mask = total_count.unsqueeze(-1) <= torch.arange(
tmax, device=total_count.device
)
super().__init__()
self.given_count, self.total_count_mask = given_count, total_count_mask
def check(self, value: torch.Tensor) -> torch.Tensor:
is_bool = ((value == 0) | (value == 1)).all(-1)
isnt_gte_tc = (self.total_count_mask.expand_as(value) * value).sum(-1) == 0
value_sum = value.sum(-1)
matches_count = value_sum == self.given_count.expand_as(value_sum)
return is_bool & isnt_gte_tc & matches_count
[docs]
@script
def binomial_coefficient(length: torch.Tensor, count: torch.Tensor) -> torch.Tensor:
r"""Compute the binomial coefficients (length choose count)
The binomial coefficient "`length` choose `count`" is calculated as
.. math::
\binom{length}{count} = \frac{length!}{length!(length - count)!} \\
x! = \begin{cases}
\prod_{x'=1}^x x' & x > 0 \\
1 & x = 0 \\
0 & x < 0
\end{cases}
Parameters
----------
length
A long tensor of the upper terms in the coefficient. Must broadcast with
`count`.
count
A long tensor of the lower terms in the coefficient. Must broadcast with
`length`.
Returns
-------
binom : torch.Tensor
A long tensor of the broadcasted shape of `length` and `count`. The value
at multi-index ``n``, ``binom[n]``, stores the binomial coefficient
``length[n]`` choose ``count[n]``, assuming `length` and `count` have already
been broadcast together.
Warnings
--------
As the values in `binom` can get very large, this function is susceptible to
overflow. For example, :math:`\binom{67}{33}` exceeds the long's maximum. Overflow
will be avoided by ensuring `length` does not exceed :obj:`66`. The binomial
coefficient is at its highest when ``count = length // 2`` and at its lowest when
``count == length`` or ``count == 0``.
Notes
-----
When the maximum `length` exceeds :obj:`20`, the implementation uses the recursion
defined in [howard1972]_.
"""
device = length.device
if ((count < 0) | (length < 0)).any():
raise RuntimeError("length and count must be non-negative")
length_ = int(length.max().item())
if length_ > 20:
count_ = int(count.max().item())
binom = torch.empty((count_ + 1, length_ + 1), device=device, dtype=torch.long)
binom[..., 0] = 0
binom[0] = 1
for c in range(1, count_ + 1):
binom[c, 1:] = binom[c - 1, :-1].cumsum(0)
binom = binom.flatten()[length + count * (length_ + 1)]
else:
# the factorials are guaranteed to lie within long precision; this algorithm
# saves some time
length_m_count = (length - count).clamp_min_(-1)
count = count.clamp_max(length_)
x = torch.arange(length_ + 2, device=device)
x[0] = 1
x = x.cumprod(0)
binom = trunc_divide(x[length], x[count] * x[length_m_count])
binom.masked_fill_(length_m_count == -1, 0)
return binom
# why bother with overloads? JIT. torch.dtype is technically an int right now. Later
# versions of pytorch are able to quietly convert the type, but not 1.5.1
# https://github.com/pytorch/pytorch/issues/65607
@overload
def enumerate_vocab_sequences(
length: int,
vocab_size: int,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.long,
) -> torch.Tensor:
...
[docs]
@script
def enumerate_vocab_sequences(
length: int,
vocab_size: int,
device: torch.device = torch.device("cpu"),
dtype: int = torch.long,
) -> torch.Tensor:
"""Enumerate all sequences of a finite range of values of a fixed length
This function generalizes :func:`enumerate_binary_sequences` to any positive
vocabulary size. Each step in each sequence takes on a value from 0-`vocab_size`
Parameters
----------
length
The non-negative length of the vocab sequence.
vocab_size
The positive number of values in the vocabulary.
device
What device to return the tensor on.
dtype
The data type of the returned tensor.
Returns
-------
support : torch.Tensor
A tensor of shape ``(vocab_size ** length, length)`` of all possible sequences
with that vocabulary. The sequences are ordered such that all configurations
where ``support[s, t] > 0`` must follow those where ``support[s', t] == 0``
(i.e. it implies ``s' < s``). Therefore all sequences of length ``length
- x`` are contained in ``support[2 ** (length - x), :length - x]``.
"""
if length < 0:
raise RuntimeError(f"length must be non-negative, got {length}")
if vocab_size <= 0:
raise RuntimeError(f"vocab_size must be positive, got {vocab_size}")
support = torch.empty(
(length, int(vocab_size ** length)), device=device, dtype=dtype
)
range_ = torch.arange(vocab_size, device=device, dtype=dtype).view(1, vocab_size, 1)
for t in range(length):
support.view(length, int(vocab_size ** t), vocab_size, -1)[
length - t - 1
] = range_
return support.T.contiguous()
@overload
def enumerate_binary_sequences(
length: int,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.long,
) -> torch.Tensor:
...
[docs]
def enumerate_binary_sequences(
length: int, device: torch.device = torch.device("cpu"), dtype: int = torch.long,
) -> torch.Tensor:
"""Enumerate all binary sequences of a fixed length
Parameters
----------
length
The non-negative length of the binary sequences.
device
What device to return the tensor on.
dtype
The data type of the returned tensor.
Returns
-------
support : torch.Tensor
A tensor of shape ``(2 ** length, length)`` of all possible binary sequences of
length `length`. The sequences are ordered such that all configurations where
``support[s, t] == 1`` must follow those where ``support[s', t] == 0`` (i.e. it
implies ``s' < s``). Therefore all binary sequences of length ``length - x`` are
contained in ``support[2 ** (length - x), :length - x]``.
Examples
--------
>>> support = enumerate_binary_sequences(3)
>>> print(support)
tensor([[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[1, 1, 0],
[0, 0, 1],
[1, 0, 1],
[0, 1, 1],
[1, 1, 1]])
>>> print(support[:4, :2])
tensor([[0, 0],
[1, 0],
[0, 1],
[1, 1]])
"""
return enumerate_vocab_sequences(length, 2, device, dtype)
@overload
def enumerate_binary_sequences_with_cardinality(
length: int,
count: int,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.long,
) -> torch.Tensor:
...
@overload
def enumerate_binary_sequences_with_cardinality(
length: torch.Tensor, count: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
...
@script
def _enumerate_binary_sequences_with_cardinality_int(
length: int, count: int, device: torch.device, dtype: int
) -> torch.Tensor:
support = enumerate_binary_sequences(length, device, dtype)
support = support[support.sum(1) == count]
return support
@script
def _enumerate_binary_sequences_with_cardinality_tensor(
length: torch.Tensor, count: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
device = length.device
length_ = int(length.max().item())
length, count = torch.broadcast_tensors(length, count)
binom = binomial_coefficient(length, count)
binom_ = int(binom.max().item())
# _enumerate_binary_sequences outputs sequences with b_t = 1 only after all
# sequences with b_t = 0. We therefore capture all the combos for a given length
# by limiting ourselves to the indices up to 2 ** length.
N = int(2 ** length_)
support = enumerate_binary_sequences(length_, device, length.dtype)
support = torch.cat([support, torch.empty_like(support)])
range_ = torch.arange(2 * N, device=device).expand(binom.shape + (2 * N,))
pad = (range_ >= N) & (range_ < N + (binom_ - binom).unsqueeze(-1))
keep = (range_ < (2 ** length).unsqueeze(-1)) & (
support.sum(-1).expand(binom.shape + (2 * N,)) == count.unsqueeze(-1)
)
support = support.expand(binom.shape + (-1, -1))[pad | keep]
support = support.view(binom.shape + (binom_, length_))
return support, binom
[docs]
def enumerate_binary_sequences_with_cardinality(
length: Any,
count: Any,
device: torch.device = torch.device("cpu"),
dtype: int = torch.long,
) -> Any:
r"""Enumerate the configurations of binary sequences with fixed sum
Parameters
----------
length
The number of elements in the binary sequence. Either a tensor or an int. Must
be the same type as `count`. If a tensor, must broadcast with `count`.
count
The number of elements with value 1. Either a tensor or an int. Must be the same
type as `length`. If a tensor, must broadcast with `length`.
device
If `length` and `count` are integers, `device` specifies the device to return
the tensor on. Otherwise the device of `length` is used.
dtype
If `length` and `count` are integers, `dtype` specifies the return type of
the tensor. Otherwise the type of `length` is used.
Returns
-------
support : torch.Tensor or tuple of torch.Tensor
If `length` and `count` are both integers, `support` is a tensor of shape
``(N, length)`` where :math:`N = \binom{length}{count}` is the number of unique
binary sequence configurations of length `length` such that for any ``n``,
``support[n].sum() == count``.
If `length` and `count` are both long tensors, `support` is a tuple of tensors
``support_, binom`` where `support_` is of shape ``(B*, N_, length_)`` and
`binom` is of shape ``(B*)``. ``B*`` refers to the broadcasted shape of `length`
and `count`, ``N_`` is the maximum value in `binom`, and ``length_`` is the
maximum value in ``length_``. For multi-index ``b``, ``support[b]`` stores the
unique binary sequence configurations for ``length[b]`` and ``count[b]``.
``binom[b]`` stores the number of unique configurations for ``length[b]`` and
``count[b]``, which is always :math:`\binom{length[b]}{count[b]}`. Sequences
are right-padded to the maximum length and count: for index ``b``, only values
in ``support[b, :binom[b], :length[b]]`` are valid.
Warnings
--------
The size of the returned support grows exponentially with `length`.
"""
if isinstance(length, torch.Tensor) and isinstance(count, torch.Tensor):
return _enumerate_binary_sequences_with_cardinality_tensor(length, count)
elif isinstance(length, int) and isinstance(count, int):
return _enumerate_binary_sequences_with_cardinality_int(
length, count, device, dtype
)
else:
raise RuntimeError("length and count must both be tensors or ints")
[docs]
class SimpleRandomSamplingWithoutReplacement(torch.distributions.ExponentialFamily):
r"""Draw binary vectors with uniform probability but fixed cardinality
`Simple Random Sampling Without Replacement
<https://en.wikipedia.org/wiki/Simple_random_sample>`__ (SRSWOR) is a uniform
distribution over binary vectors of length :math:`T` with a fixed sum :math:`L`:
.. math::
P(b|L) = I\left[\sum^T_{t=1} b_t = L\right] \frac{1}{T \mathrm{\>choose\>} L}
where :math:`I[\cdot]` is the indicator function. The distribution is a special
case of the Conditional Bernoulli [chen1994]_ and a member of the
`Exponential Family <https://en.wikipedia.org/wiki/Exponential_family>`__.
Parameters
----------
total_count
The value(s) :math:`T`. Must broadcast with `given_count`. Represents the sizes
of the sample vectors. If not all equal or less than `out_size`, samples will be
right-padded with zeros.
given_count
The value(s) :math:`L`. Must broadcast with and have values no greater than
`total_count`. Represents the cardinality constraints of the sample vectors.
out_size
The length of the binary vectors. If it exceeds some value of `total_count`,
that sample will be right-padded with zeros. Must be no less than
``total_count.max()``. If unset, defaults to that value.
validate_args
Notes
-----
The support can only be enumerated if all elements of `total_count` are equal;
likewise for `given_count`.
"""
arg_constraints = {
"total_count": constraints.nonnegative_integer,
"given_count": constraints.nonnegative_integer,
}
_mean_carrier_measure = 0
def __init__(
self,
given_count: Union[int, torch.Tensor],
total_count: Union[int, torch.Tensor],
out_size: Optional[int] = None,
validate_args: Optional[bool] = None,
):
device = None
if isinstance(given_count, torch.Tensor):
device = given_count.device
if (
isinstance(total_count, torch.Tensor)
and given_count.device != total_count.device
):
raise ValueError(
"given_count and total_count must be on the same device"
)
elif isinstance(total_count, torch.Tensor):
device = total_count.device
given_count = torch.as_tensor(given_count, device=device)
total_count = torch.as_tensor(total_count, device=device)
total_count_max = int(total_count.max().item())
if out_size is None:
out_size = total_count_max
given_count, total_count = torch.broadcast_tensors(given_count, total_count)
batch_shape = given_count.size()
event_shape = torch.Size([out_size])
self.total_count, self.given_count = total_count, given_count
super().__init__(batch_shape, event_shape, validate_args)
if self._validate_args:
given_count = argcheck.is_nonnegt(given_count, "given_count")
total_count = argcheck.is_nonnegt(total_count, "total_count")
argcheck.is_lte(given_count, total_count, "given_count", "total_count")
argcheck.is_gte(out_size, total_count, "out_size", "total_count")
@constraints.dependent_property
def support(self):
return BinaryCardinalityConstraint(
self.given_count, self.event_shape[0], self.total_count
)
@property
def has_enumerate_support(self) -> bool:
return (
(self.total_count == self.total_count.flatten()[0]).all()
& (self.given_count == self.given_count.flatten()[0]).all()
).item()
def enumerate_support(self, expand=True) -> torch.Tensor:
if not self.has_enumerate_support:
raise NotImplementedError(
"total_count must all be equal and given_count must all be equal to "
"enumerate support"
)
total = self.total_count.flatten()[0].item()
given = self.given_count.flatten()[0].item()
support = enumerate_binary_sequences_with_cardinality(
total, given, self.total_count.device, dtype=torch.float
)
out_size = self.event_shape[0]
if out_size != total:
support = torch.nn.functional.pad(support, (0, out_size - total))
support = support.view((-1,) + (1,) * len(self.batch_shape) + (out_size,))
if expand:
support = support.expand((-1,) + self.batch_shape + (out_size,))
return support
@lazy_property
def log_partition(self) -> torch.Tensor:
# log total_count choose given_count
log_factorial = (
torch.arange(
1,
self.event_shape[0] + 1,
device=self.total_count.device,
dtype=torch.float,
)
.log()
.cumsum(0)
)
t_idx = (self.total_count.long() - 1).clamp_min(0)
g_idx = (self.given_count.long() - 1).clamp_min(0)
tmg_idx = (self.total_count.long() - self.given_count.long() - 1).clamp_min(0)
return log_factorial[t_idx] - log_factorial[g_idx] - log_factorial[tmg_idx]
@property
def mean(self):
len_mask = self.total_count.unsqueeze(-1) <= torch.arange(
self.event_shape[0], device=self.total_count.device
)
return (
(self.given_count / self.total_count.clamp_min(1.0))
.unsqueeze(-1)
.expand(self.batch_shape + self.event_shape)
).masked_fill(len_mask, 0.0)
@property
def variance(self):
return self.mean * (1 - self.mean)
@property
def _natural_params(self):
return self.total_count.new_zeros(self.batch_shape + self.event_shape)
def _log_normalizer(self, logits):
if (logits == 0).all():
raise RuntimeError("Logits invalid")
return self.log_partition
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(
SimpleRandomSamplingWithoutReplacement, _instance
)
batch_shape = list(batch_shape)
new.given_count = self.given_count.expand(batch_shape)
new.total_count = self.total_count.expand(batch_shape)
if "log_partition" in self.__dict__:
new.log_partition = self.log_partition.expand(batch_shape)
super(SimpleRandomSamplingWithoutReplacement, new).__init__(
torch.Size(batch_shape), self.event_shape, validate_args=False
)
new._validate_args = self._validate_args
return new
def sample(self, sample_shape=torch.Size([])):
sample_shape = torch.Size(sample_shape)
shape = self._extended_shape(sample_shape)
with torch.no_grad():
total_count = self.total_count.expand(shape[:-1])
given_count = self.given_count.expand(shape[:-1])
b = simple_random_sampling_without_replacement(
total_count, given_count, self.event_shape[0]
)
return b
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
return (-self.log_partition).expand(value.shape[:-1])