distributions

PyTorch distributions and interfaces

Warning

Distributions cannot be JIT scripted or traced.

class pydrobert.torch.distributions.BinaryCardinalityConstraint(given_count, tmax, total_count=None)[source]

Ensures a vector of binary values sums to the required cardinality
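
As with other torch.distributions.constraints.Constraint objects, membership can be queried via check(). A minimal sketch (the argument semantics here are assumptions based on the signature above):

>>> constraint = BinaryCardinalityConstraint(torch.tensor(2), 4)
>>> ok = constraint.check(torch.tensor([1., 0., 1., 0.]))   # sums to 2: satisfied
>>> bad = constraint.check(torch.tensor([1., 1., 1., 0.]))  # sums to 3: violated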

class pydrobert.torch.distributions.ConditionalStraightThrough(batch_shape=(), event_shape=(), validate_args=None)[source]

Straight-throughs with a conditional dist on relaxed samples given discrete ones

In addition to the methods of StraightThrough, classes implementing this interface allow for drawing a relaxed sample conditioned on its discrete image via csample(), and provide a method for determining the log probability of that conditional sample, clog_prob().

abstract clog_prob(zcond, b)[source]

Return the log probability of a relaxed sample conditioned on a discrete one

Returns \(lp = \log P(z^{cond}|b)\), where the conditional obeys the following equality:

\[\begin{split}P(z^{cond}|b)P(b) = P(z^{cond}, b) = \begin{cases} P(z^{cond}) & H(z^{cond}) = b \\ 0 & \mathrm{otherwise} \end{cases}\end{split}\]

where \(H\) is the threshold function. In other words, given a discrete sample b, itself the output of thresholding some relaxed sample, what is the probability that zcond was that relaxed sample?

Parameters:
  • zcond (Tensor) – A relaxed sample.

  • b (Tensor) – A discrete sample. Usually the result of drawing a relaxed sample from this instance’s rsample() method, then applying a discrete threshold to it via threshold().

Returns:

lp (torch.Tensor) – The log probabilities of shape sample_shape + batch_shape.

abstract csample(b)[source]

Draw a relaxed sample conditioned on its thresholded (discrete) image

Parameters:

b (Tensor) – A discrete sample. Usually the result of drawing a relaxed sample from this instance’s rsample() method, then applying a discrete threshold to it via threshold().

Returns:

zcond (torch.Tensor) – A relaxed sample such that threshold(zcond) == b.
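
With any distribution implementing this interface (LogisticBernoulli is used illustratively below), a conditional relaxed sample always thresholds back to the discrete sample it was conditioned on. A sketch:

>>> dist = LogisticBernoulli(probs=torch.full((5,), 0.3))
>>> z = dist.rsample()             # relaxed sample
>>> b = dist.threshold(z)          # its discrete image
>>> zcond = dist.csample(b)        # new relaxed sample consistent with b
>>> assert (dist.threshold(zcond) == b).all()
>>> lp = dist.clog_prob(zcond, b)  # log P(zcond|b)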

class pydrobert.torch.distributions.Density[source]

Interface for a density function

A density is a non-negative function over some domain. A density implements the method log_prob() which returns the log of the density applied to some number of samples.

Notes

While log_prob() is not necessarily a log probability for all densities, the name was chosen to match the method of torch.distributions.Distribution. All probability densities are densities.
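
Since the interface requires only log_prob(), a custom (possibly unnormalized) density takes only a few lines. A minimal sketch:

>>> class UnnormalizedGaussian(Density):
...     def log_prob(self, x):
...         return -(x ** 2) / 2  # drops the -log(sqrt(2 pi)) normalizer
>>> density = UnnormalizedGaussian()
>>> lp = density.log_prob(torch.linspace(-1., 1., 5))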

class pydrobert.torch.distributions.GumbelOneHotCategorical(logits=None, probs=None, validate_args=None)[source]

Gumbel distributions with a categorical relaxation

This distribution should be treated as a series of independent Gumbel distributions, normalized along the final dimension of logits or probs. Samples can optionally be discretized to draws from a (one-hot) categorical distribution by taking the max Gumbel variable along the final axis. sample(), rsample(), and statistics like the mean and standard deviation are all relative to the Gumbel samples.

The relaxation, threshold, and conditional relaxed sample are as defined in [tucker2017]. The relaxation \(z\) is sampled as

\[\begin{split}u_{i,j} \sim \mathrm{Uniform}([0, 1]) \\ z_{i,j} = \log probs_{i,j} - \log(-\log u_{i,j})\end{split}\]

which can be transformed into a (one-hot) categorical sample via the threshold

\[\begin{split}b_{i,j} = \begin{cases} 1 & j' \neq j \implies z_{i,j} > z_{i,j'} \\ 0 & \mathrm{otherwise} \end{cases}.\end{split}\]

A relaxed sample \(z^{cond}\) conditioned on categorical sample \(b\) can be drawn by

\[\begin{split}v_{i,j} \sim \mathrm{Uniform}([0, 1]) \\ z^{cond}_{i,j} = \begin{cases} -\log(-\log v_{i,j}) & b_{i,j} = 1 \\ -\log\left( -\frac{\log v_{i,j}}{probs_{i,j}} - \log \sum_{j'} b_{i,j'} v_{i,j'} \right) & b_{i,j} = 0 \end{cases}.\end{split}\]
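
A sketch of drawing relaxed samples and thresholding them to one-hot categorical samples:

>>> dist = GumbelOneHotCategorical(logits=torch.randn(10, 4))
>>> z = dist.rsample([5])          # Gumbel samples, shape (5, 10, 4)
>>> b = dist.threshold(z)          # one-hot samples of the same shape
>>> assert (b.sum(-1) == 1).all()  # exactly one category selected per row
>>> zcond = dist.csample(b)        # relaxed samples whose threshold is b
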
class pydrobert.torch.distributions.LogisticBernoulli(probs=None, logits=None, validate_args=None)[source]

A Logistic distribution which can be thresholded to Bernoulli samples

This distribution should be treated as a (normalized) Logistic distribution with the option to discretize to Bernoulli values, not the other way around. sample(), rsample(), and statistics like the mean and standard deviation are all relative to the relaxed sample.

The relaxation, threshold, and conditional relaxed sample are as defined in [tucker2017]. The relaxation \(z\) is sampled as

\[\begin{split}u_i \sim \mathrm{Uniform}([0, 1]) \\ z_i = logits_i + \log(u_i) - \log(1 - u_i)\end{split}\]

which can be transformed into a Bernoulli sample via the threshold

\[\begin{split}b_i = \begin{cases} 1 & z_i \geq 0 \\ 0 & z_i < 0 \end{cases}.\end{split}\]

A relaxed sample \(z^{cond}\) conditioned on the Bernoulli sample \(b\) can be drawn by

\[\begin{split}v_i \sim \mathrm{Uniform}([0, 1]) \\ z^{cond}_i = \begin{cases} \log\left(\frac{v_i}{(1 - v_i)(1 - probs_i)} + 1 \right) & b_i = 1 \\ -\log\left(\frac{v_i}{(1 - v_i)probs_i} + 1\right) & b_i = 0 \end{cases}.\end{split}\]
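
Thresholded relaxed samples follow the underlying Bernoulli distribution, which can be checked empirically. A sketch:

>>> dist = LogisticBernoulli(probs=torch.tensor([0.25, 0.75]))
>>> z = dist.rsample([10000])  # relaxed (Logistic) samples, shape (10000, 2)
>>> b = dist.threshold(z)      # Bernoulli samples in {0, 1}
>>> b.mean(0)                  # should be close to [0.25, 0.75]
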
class pydrobert.torch.distributions.SequentialLanguageModelDistribution(random_walk, batch_size=None, initial_state=None, max_iters=None, cache_samples=False, validate_args=None)[source]

A SequentialLanguageModel as a Distribution

This class wraps a pydrobert.torch.modules.RandomWalk instance, itself wrapping a pydrobert.torch.modules.SequentialLanguageModel, treating it as a torch.distributions.distribution.Distribution. It relies on the walk to sample.

Among other things, the resulting distribution can be passed as an argument to a pydrobert.torch.estimators.Estimator.

Parameters:
  • random_walk (RandomWalk) – The RandomWalk instance with language model random_walk.lm.

  • batch_size (Optional[int]) – The batch size to use when calling the underlying language model or the walk. If unset, the number of samples being drawn (or passed to log_prob()) is treated as the batch size. See the below note for more information.

  • initial_state (Optional[Dict[str, Tensor]]) – If specified, any calls to the underlying language model or the walk will be passed this value.

  • max_iters (Optional[int]) – Specifies the maximum number of steps to take in the walk. Either random_walk.lm.eos or max_iters must be set.

  • cache_samples (bool) –

    If True, calls to sample() or log_prob() will save the last samples and their log probabilities. This can avoid expensive recomputations if, for example, the log probability of a sample is always queried after it is sampled:

    >>> sample = dist.sample()
    >>> log_prob = dist.log_prob(sample)
    

    The cache is stored until a new sample takes its place or it is manually cleared with clear_cache(). See the below warning for complications with the cache.

  • validate_args (Optional[bool]) –

Warning

This wrapper does not handle any changes to the distribution which may occur for subclasses of RandomWalk with non-default implementations of pydrobert.torch.modules.RandomWalk.update_log_probs_for_step(). log_prob() will in general reflect the default, unadjusted log probabilities. The situation is complicated if cache_samples is enabled: if sample() is called when enabled, the adjusted log probabilities are cached, but if log_prob() is called prior to caching, the unadjusted log probabilities are cached.

In short, do not use custom RandomWalk instances with this class unless you know what you’re doing.

Notes

We expect most SequentialLanguageModel instances to be able to handle an arbitrary number of sequences at once. In this case, batch_size should be left unset (its default), and the samples will be flattened into a single batch dimension before being passed to the underlying language model. For example:

>>> dist = SequentialLanguageModelDistribution(walk)
>>> sample = dist.sample()  # a single sequence. Of shape (sequence_length,)
>>> sample = dist.sample([M])  # M sequences. Of shape (M, sequence_length)

However, some language models require sampling or computing the log probabilities of a fixed number of samples at a time, particularly when there is implicit conditioning on some other batched input passed via initial_state. For example, an acoustic model for ASR will condition its sequences on some batched audio or feature input; an NMT system will condition its target sequence output on batched source sequence input. In this case, batch_size can be set to the number of batch elements, like so:

>>> dist = SequentialLanguageModelDistribution(walk, N, initial_state)
>>> sample = dist.sample()  # N sequences, 1 per batch elem. (N, sequence_length)
>>> sample = dist.sample([M])  # M * N sequences, M per batch elem. (M, N, sequence_length)

To accomplish this, the walk/lm will be queried M times sequentially with batch size N, the results stacked (and padded with eos, if necessary).

Since setting batch_size entails sampling sequentially along M, it tends to be slower than drawing M * N samples with batch_size unset. However, sequential sampling also tends to have a smaller memory footprint.
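
A sketch of the unbatched case, assuming lm is some SequentialLanguageModel with eos set (the RandomWalk constructor is abbreviated here):

>>> walk = RandomWalk(lm)
>>> dist = SequentialLanguageModelDistribution(walk, cache_samples=True)
>>> sample = dist.sample([4])         # 4 sequences, shape (4, sequence_length)
>>> log_prob = dist.log_prob(sample)  # served from the cache, shape (4,)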

clear_cache()[source]

Manually clear the sample cache

class pydrobert.torch.distributions.SimpleRandomSamplingWithoutReplacement(given_count, total_count, out_size=None, validate_args=None)[source]

Draw binary vectors with uniform probability but fixed cardinality

Simple Random Sampling Without Replacement (SRSWOR) is a uniform distribution over binary vectors of length \(T\) with a fixed sum \(L\):

\[P(b|L) = I\left[\sum^T_{t=1} b_t = L\right] \frac{1}{\binom{T}{L}}\]

where \(I[\cdot]\) is the indicator function. The distribution is a special case of the Conditional Bernoulli [chen1994] and a member of the Exponential Family.

Parameters:
  • total_count (Union[int, Tensor]) – The value(s) \(T\). Must broadcast with given_count. Represents the sizes of the sample vectors. Samples whose total_count is less than out_size will be right-padded with zeros.

  • given_count (Union[int, Tensor]) – The value(s) \(L\). Must broadcast with and have values no greater than total_count. Represents the cardinality constraints of the sample vectors.

  • out_size (Optional[int]) – The length of the binary vectors. If it exceeds some value of total_count, that sample will be right-padded with zeros. Must be no less than total_count.max(). If unset, defaults to that value.

  • validate_args (Optional[bool]) –

Notes

The support can only be enumerated if all elements of total_count are equal; likewise for given_count.
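
For example, with \(T = 5\) and \(L = 2\), every sample has cardinality 2 and log probability \(-\log \binom{5}{2} = -\log 10\):

>>> dist = SimpleRandomSamplingWithoutReplacement(2, 5)
>>> b = dist.sample([8])   # (8, 5) binary vectors
>>> assert (b.sum(-1) == 2).all()
>>> lp = dist.log_prob(b)  # all entries equal -log(10), about -2.303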

class pydrobert.torch.distributions.StraightThrough(batch_shape=(), event_shape=(), validate_args=None)[source]

Interface for distributions for which a straight through estimate is possible

Classes implementing this interface supply both a method for drawing a relaxed sample rsample() (dist.has_rsample == True) and a method for thresholding it into a discrete sample threshold().

abstract threshold(z, straight_through=False)[source]

Convert a relaxed sample into a discrete sample

Parameters:
  • z (Tensor) – A relaxed sample, usually drawn via this instance’s rsample() method.

  • straight_through (bool, optional) – If True, attach the gradient of z to the discrete sample.

Returns:

b (torch.Tensor) – The discrete sample acquired by applying a threshold function to z.

abstract tlog_prob(b)[source]

The log probability of a thresholded sample

Parameters:

b (Tensor) – A discrete sample. Usually the result of drawing a relaxed sample from this instance’s rsample() method, then applying a discrete threshold to it via threshold().

Returns:

lp (torch.Tensor) – The log probability of the sample. Of shape sample_shape + batch_shape.
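
Together, these methods enable straight-through gradient estimation: the forward pass uses the discrete sample while gradients flow through the relaxed one. A sketch using LogisticBernoulli:

>>> logits = torch.randn(8, requires_grad=True)
>>> dist = LogisticBernoulli(logits=logits)
>>> z = dist.rsample()  # reparameterized, differentiable w.r.t. logits
>>> b = dist.threshold(z, straight_through=True)
>>> loss = (b - 1).pow(2).sum()  # some downstream objective
>>> loss.backward()  # gradients reach logits through z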

class pydrobert.torch.distributions.TokenSequenceConstraint(vocab_size, eos=None, max_iters=None)[source]

Distribution constraint for token sequences

A token sequence is a vector of integer values in the range [0, vocab_size - 1]. A token sequence must be completed, which requires at least one of the following two conditions to be met:

  • The sequence dimension matches max_iters.

  • The sequence dimension is greater than 0 (and less than max_iters, if set), and each sequence contains at least one eos.

If eos is included, any value beyond the first eos in each sequence is ignored.
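
A sketch of membership checking, assuming the usual Constraint.check() semantics:

>>> constraint = TokenSequenceConstraint(vocab_size=4, eos=3)
>>> ok = constraint.check(torch.tensor([0, 2, 3, 1]))   # completed: contains eos
>>> bad = constraint.check(torch.tensor([0, 2, 1, 1]))  # incomplete: no eos, no max_iters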