distributions
PyTorch distributions and interfaces
Warning
Distributions cannot be JIT scripted or traced.
- class pydrobert.torch.distributions.BinaryCardinalityConstraint(given_count, tmax, total_count=None)[source]
Ensures a vector of binary values sums to the required cardinality
- class pydrobert.torch.distributions.ConditionalStraightThrough(batch_shape=(), event_shape=(), validate_args=None)[source]
Straight-throughs with a conditional dist on relaxed samples given discrete ones
In addition to the methods of StraightThrough, classes implementing this interface allow for drawing a relaxed sample given its discrete image via csample(), and provide a method for determining the log probability of that conditional, clog_prob().
- abstract clog_prob(zcond, b)[source]
Return the log probability of a relaxed sample conditioned on a discrete one
Returns \(lp = \log P(z^{cond}|b)\), where the conditional obeys the following equality:
\[\begin{split}P(z^{cond}|b)P(b) = P(z^{cond}, b) = \begin{cases} P(z^{cond}) & H(z^{cond}) = b \\ 0 & \mathrm{otherwise} \end{cases}\end{split}\]
where \(H\) is the threshold function. In other words, given a discrete sample b which is the output of some thresholded relaxed sample, what is the probability that zcond is that sample?
- Parameters:
- Returns:
lp (torch.Tensor) – The log probabilities, of shape sample_shape + batch_shape.
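As a rough illustration of the equality above, the following plain-Python sketch works it out for a scalar logistic relaxation thresholded at zero (the case described under LogisticBernoulli). The names logistic_log_pdf, log_sigmoid, clog_prob_sketch, and integrate_conditional are hypothetical and not part of the library; the sketch assumes only that \(P(b = 1)\) is the sigmoid of the logit. It is not the library's implementation.

```python
import math


def logistic_log_pdf(z, logit):
    """Log density of a unit-scale logistic centred at `logit`."""
    x = abs(z - logit)  # the density is symmetric around its location
    return -x - 2.0 * math.log1p(math.exp(-x))


def log_sigmoid(x):
    """Numerically stable log(1 / (1 + exp(-x)))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def clog_prob_sketch(zcond, b, logit):
    """log P(zcond | b) for the threshold H(z) = 1 if z >= 0 else 0."""
    if (1 if zcond >= 0 else 0) != b:
        return -math.inf  # zcond cannot be the relaxed sample behind b
    log_p_b = log_sigmoid(logit) if b == 1 else log_sigmoid(-logit)
    # P(zcond | b) = P(zcond) / P(b) whenever H(zcond) == b
    return logistic_log_pdf(zcond, logit) - log_p_b


def integrate_conditional(b, logit, width=40.0, steps=40000):
    """Midpoint-rule integral of P(z | b) over the half-line selected by b."""
    h = width / steps
    sign = 1.0 if b == 1 else -1.0
    total = 0.0
    for k in range(steps):
        z = sign * (k + 0.5) * h
        total += math.exp(clog_prob_sketch(z, b, logit)) * h
    return total
```

Because the conditional is a proper density on the half-line picked out by b, integrate_conditional should return a value close to 1 for either value of b.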
- abstract csample(b)[source]
Draw a relaxed sample conditioned on its thresholded (discrete) image
- Parameters:
b (Tensor) – A discrete sample. Usually the result of drawing a relaxed sample from this instance’s rsample() method, then applying a discrete threshold to it via threshold().
- Returns:
zcond (torch.Tensor) – A relaxed sample such that threshold(zcond) == b.
- class pydrobert.torch.distributions.Density[source]
Interface for a density function
A density is a non-negative function over some domain. A density implements the method log_prob(), which returns the log of the density applied to some number of samples.
Notes
While log_prob() is not necessarily a log probability for all densities, the name was chosen to match the method of torch.distributions.Distribution. All probability densities are densities.
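To make the distinction between a density and a probability density concrete, here is a minimal sketch of an unnormalized density. The class UnnormalizedGaussianKernel and the helper integrate are hypothetical; to stay dependency-free the class does not subclass the Density interface and only mimics its log_prob() method on Python floats.

```python
import math


class UnnormalizedGaussianKernel:
    """Hypothetical density exp(-x**2 / 2): non-negative, but not normalized."""

    def log_prob(self, x):
        # Log of the density value, not a log probability: the kernel
        # integrates to sqrt(2 * pi) rather than 1.
        return -0.5 * x * x


def integrate(density, lo=-10.0, hi=10.0, steps=20000):
    """Midpoint-rule estimate of the integral of exp(log_prob) over [lo, hi]."""
    h = (hi - lo) / steps
    return sum(
        math.exp(density.log_prob(lo + (k + 0.5) * h)) * h for k in range(steps)
    )


mass = integrate(UnnormalizedGaussianKernel())
```

Here mass approximates \(\sqrt{2\pi}\), so exponentiating log_prob() does not give a probability density even though the function is a valid density in the sense above.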
- class pydrobert.torch.distributions.GumbelOneHotCategorical(logits=None, probs=None, validate_args=None)[source]
Gumbel distributions with a categorical relaxation
This distribution should be treated as a series of independent Gumbel distributions, normalized along the final dimension of logits or probs. Samples can optionally be discretized to draws from a (one-hot) categorical distribution by taking the max Gumbel variable along the final axis. sample(), rsample(), and statistics like the mean and standard deviation are all relative to the Gumbel samples.
The relaxation, threshold, and conditional relaxed sample are defined in [tucker2017]. The relaxation \(z\) is sampled as
\[\begin{split}u_{i,j} \sim \mathrm{Uniform}([0, 1]) \\ z_{i,j} = \log probs_{i,j} - \log(-\log u_{i,j})\end{split}\]
which can be transformed into a (one-hot) categorical sample via the threshold
\[\begin{split}b_{i,j} = \begin{cases} 1 & j' \neq j \implies z_{i,j} > z_{i,j'} \\ 0 & \mathrm{otherwise} \end{cases}.\end{split}\]
A relaxed sample \(z^{cond}\) conditioned on categorical sample \(b\) can be drawn by
\[\begin{split}v_{i,j} \sim \mathrm{Uniform}([0, 1]) \\ z^{cond}_{i,j} = \begin{cases} -\log(-\log v_{i,j}) & b_{i,j} = 1 \\ -\log\left( -\frac{\log v_{i,j}}{probs_{i,j}} - \log \sum_{j'} b_{i,j'} v_{i,j'} \right) & b_{i,j} = 0 \end{cases}.\end{split}\]
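The three formulas above can be sketched for a single probability vector in plain Python. This is an illustration of the equations only, not the library's implementation: the function names are hypothetical, the inputs are Python lists rather than tensors, and probs is assumed to be strictly positive and to sum to one.

```python
import math
import random


def open_uniform(rng):
    """Uniform draw from the open interval (0, 1), so the logs stay finite."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return u


def relaxed_sample(probs, rng):
    """z_j = log probs_j - log(-log u_j) with u_j ~ Uniform(0, 1)."""
    return [math.log(p) - math.log(-math.log(open_uniform(rng))) for p in probs]


def threshold(z):
    """One-hot vector marking the largest entry of z."""
    k = max(range(len(z)), key=z.__getitem__)
    return [1 if j == k else 0 for j in range(len(z))]


def conditional_sample(probs, b, rng):
    """Relaxed sample zcond whose threshold is the one-hot vector b."""
    v = [open_uniform(rng) for _ in probs]
    k = b.index(1)
    log_v_k = math.log(v[k])  # log sum_j' b_j' v_j' reduces to log v_k
    return [
        -math.log(-math.log(v[j]))
        if j == k
        else -math.log(-math.log(v[j]) / probs[j] - log_v_k)
        for j in range(len(probs))
    ]


rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
z = relaxed_sample(probs, rng)
b = threshold(z)
zcond = conditional_sample(probs, b, rng)
```

In the conditional draw, every entry with \(b_{i,j} = 0\) has a strictly larger argument inside the outer logarithm than the entry with \(b_{i,j} = 1\), so it is strictly smaller after negation and threshold(zcond) recovers b.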
- class pydrobert.torch.distributions.LogisticBernoulli(probs=None, logits=None, validate_args=None)[source]
A Logistic distribution which can be thresholded to Bernoulli samples
This distribution should be treated as a (normalized) Logistic distribution with the option to discretize to Bernoulli values, not the other way around. sample(), rsample(), and statistics like the mean and standard deviation are all relative to the relaxed sample.
The relaxation, threshold, and conditional relaxed sample are defined in [tucker2017]. The relaxation \(z\) is sampled as
\[\begin{split}u_i \sim \mathrm{Uniform}([0, 1]) \\ z_i = logits_i + \log(u_i) - \log(1 - u_i)\end{split}\]
which can be transformed into a Bernoulli sample by the threshold
\[\begin{split}b_i = \begin{cases} 1 & z_i \geq 0 \\ 0 & z_i < 0 \end{cases}.\end{split}\]
A relaxed sample \(z^{cond}\) conditioned on the Bernoulli sample \(b\) can be drawn by
\[\begin{split}v_i \sim \mathrm{Uniform}([0, 1]) \\ z^{cond}_i = \begin{cases} \log\left(\frac{v_i}{(1 - v_i)(1 - probs_i)} + 1 \right) & b_i = 1 \\ -\log\left(\frac{v_i}{(1 - v_i)probs_i} + 1\right) & b_i = 0 \end{cases}.\end{split}\]
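A scalar, plain-Python sketch of the three formulas above follows. The function names are hypothetical and the code works on Python floats rather than tensors; it assumes probs lies strictly between 0 and 1 and that logits is the log-odds of probs. It is an illustration of the equations, not the library's code.

```python
import math
import random


def open_uniform(rng):
    """Uniform draw from the open interval (0, 1), so the logs stay finite."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return u


def relaxed_sample(logit, rng):
    """z = logit + log(u) - log(1 - u) with u ~ Uniform(0, 1)."""
    u = open_uniform(rng)
    return logit + math.log(u) - math.log(1.0 - u)


def threshold(z):
    """Bernoulli value of a relaxed sample: 1 if z >= 0, else 0."""
    return 1 if z >= 0 else 0


def conditional_sample(prob, b, rng):
    """Relaxed sample zcond such that threshold(zcond) == b."""
    v = open_uniform(rng)
    odds = v / (1.0 - v)
    if b == 1:
        # log1p(x) computes log(x + 1) accurately when x is tiny
        return math.log1p(odds / (1.0 - prob))
    return -math.log1p(odds / prob)


rng = random.Random(0)
prob = 0.3
logit = math.log(prob / (1.0 - prob))
z = relaxed_sample(logit, rng)
b = threshold(z)
zcond = conditional_sample(prob, b, rng)
```

Using log1p keeps the b = 0 branch strictly negative even for very small v, so the sign of zcond always agrees with b.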
- class pydrobert.torch.distributions.SequentialLanguageModelDistribution(random_walk, batch_size=None, initial_state=None, max_iters=None, cache_samples=False, validate_args=None)[source]
A SequentialLanguageModel as a Distribution
This class wraps a pydrobert.torch.modules.RandomWalk instance, itself wrapping a pydrobert.torch.modules.SequentialLanguageModel, treating it as a torch.distributions.distribution.Distribution. It relies on the walk to sample.
Among other things, the resulting distribution can be passed as an argument to a pydrobert.torch.estimators.Estimator.
- Parameters:
random_walk (RandomWalk) – The RandomWalk instance with language model random_walk.lm.
batch_shape – The batch shape to use when calling the underlying language model or the walk. If empty, the number of samples being drawn (or passed to log_prob()) is treated as the batch size. See the below note for more information.
initial_state (Optional[Dict[str, Tensor]]) – If specified, any calls to the underlying language model or the walk will be passed this value.
max_iters (Optional[int]) – Specifies the maximum number of steps to take in the walk. Either random_walk.lm.eos or max_iters must be set.
cache_samples (bool) – If True, calls to sample() or log_prob() will save the last samples and their log probabilities. This can avoid expensive recomputations if, for example, the log probability of a sample is always queried after it is sampled:
>>> sample = dist.sample()
>>> log_prob = dist.log_prob(sample)
The cache is stored until a new sample takes its place or it is manually cleared with clear_cache(). See the below warning for complications with the cache.
Warning
This wrapper does not handle any changes to the distribution which may occur for subclasses of RandomWalk with non-default implementations of pydrobert.torch.modules.RandomWalk.update_log_probs_for_step(). log_prob() will in general reflect the default, unadjusted log probabilities. The situation is complicated if cache_samples is enabled: if sample() is called when enabled, the adjusted log probabilities are cached, but if log_prob() is called prior to caching, the unadjusted log probabilities are cached.
In short, do not use custom RandomWalk instances with this class unless you know what you’re doing.
Notes
We expect most SequentialLanguageModel instances to be able to handle an arbitrary number of sequences at once. In this case, batch_shape should be left empty (its default), and the samples will be flattened into a single batch dimension before being passed to the underlying language model. For example:
>>> dist = SequentialLanguageModelDistribution(walk)
>>> sample = dist.sample()  # a single sequence. Of shape (sequence_length,)
>>> sample = dist.sample([M])  # M sequences. Of shape (M, sequence_length)
However, some language models will require sampling or computing the log probabilities of a fixed number of samples at a time, particularly when there’s some implicit conditioning on some other batched input passed via initial_state. For example, an acoustic model for ASR will condition its sequences on some batched audio or feature input. An NMT system will condition its target sequence output on batched source sequence input. In this case, batch_shape can be set to a 1-dimensional shape containing the number of batch elements, like so:
>>> dist = SequentialLanguageModelDistribution(walk, [N], initial_state)
>>> sample = dist.sample()  # N sequences, 1 per batch elem. (N, sequence_length)
>>> sample = dist.sample([M])  # M * N sequences, M per batch elem. (N, M, sequence_length)
To accomplish this, the walk/lm will be queried M times sequentially with batch size N, and the results stacked (and padded with eos, if necessary).
Since the batch_shape method performs sequential sampling along M, it will tend to be slower than sampling M * N samples via the other method. However, sequential sampling will also tend to have a smaller memory footprint.
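The query-M-times-then-stack procedure described above can be sketched with a toy stand-in for the walk. Everything here is hypothetical: toy_walk, sample_sequentially, and the EOS value are illustrative names, the sequences are random integers rather than language model output, and nested Python lists take the place of tensors. The indexing order simply mirrors the (N, M, sequence_length) shape quoted in the example above.

```python
import random

EOS = 0  # hypothetical end-of-sequence token id


def toy_walk(batch_size, rng, max_iters=6):
    """Stand-in for one query of the walk: `batch_size` sequences of equal length.

    Every sequence ends with EOS; shorter ones are padded with EOS so that
    the batch is rectangular.
    """
    seqs = []
    for _ in range(batch_size):
        body = [rng.randint(1, 9) for _ in range(rng.randint(0, max_iters - 1))]
        seqs.append(body + [EOS])
    length = max(len(s) for s in seqs)
    return [s + [EOS] * (length - len(s)) for s in seqs]


def sample_sequentially(num_draws, batch_size, rng):
    """Query the toy walk `num_draws` times, then pad and stack the results."""
    draws = [toy_walk(batch_size, rng) for _ in range(num_draws)]
    length = max(len(draw[0]) for draw in draws)
    # Indexed as [batch element][draw][step]; draws of different lengths
    # are right-padded with EOS to the longest one.
    return [
        [draws[m][n] + [EOS] * (length - len(draws[m][n])) for m in range(num_draws)]
        for n in range(batch_size)
    ]


M, N = 3, 4
stacked = sample_sequentially(M, N, random.Random(0))
```

Only one batch of N sequences is generated at a time, which is the source of the smaller memory footprint mentioned above, at the cost of M sequential calls.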
- class pydrobert.torch.distributions.SimpleRandomSamplingWithoutReplacement(given_count, total_count, out_size=None, validate_args=None)[source]
Draw binary vectors with uniform probability but fixed cardinality
Simple Random Sampling Without Replacement (SRSWOR) is a uniform distribution over binary vectors of length \(T\) with a fixed sum \(L\):
\[P(b|L) = I\left[\sum^T_{t=1} b_t = L\right] \frac{1}{\binom{T}{L}}\]
where \(I[\cdot]\) is the indicator function. The distribution is a special case of the Conditional Bernoulli [chen1994] and a member of the Exponential Family.
- Parameters:
total_count (Union[int, Tensor]) – The value(s) \(T\). Must broadcast with given_count. Represents the sizes of the sample vectors. If the values are not all equal, or are less than out_size, samples will be right-padded with zeros.
given_count (Union[int, Tensor]) – The value(s) \(L\). Must broadcast with and have values no greater than total_count. Represents the cardinality constraints of the sample vectors.
out_size (Optional[int]) – The length of the binary vectors. If it exceeds some value of total_count, that sample will be right-padded with zeros. Must be no less than total_count.max(). If unset, defaults to that value.
Notes
The support can only be enumerated if all elements of total_count are equal; likewise for given_count.
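For a single pair of scalar counts, the distribution can be sketched in plain Python as follows. The functions srswor_sample and srswor_log_prob are hypothetical illustrations of the formula and the right-padding rule described above; they do not handle the batched, broadcast tensors that the class accepts and are not its implementation.

```python
import math
import random


def srswor_sample(given_count, total_count, out_size=None, rng=random):
    """Uniformly draw a binary vector of length total_count with given_count ones."""
    if out_size is None:
        out_size = total_count
    if not 0 <= given_count <= total_count <= out_size:
        raise ValueError("need 0 <= given_count <= total_count <= out_size")
    # Choosing the positions of the ones without replacement makes every
    # valid vector equally likely.
    ones = set(rng.sample(range(total_count), given_count))
    b = [1 if t in ones else 0 for t in range(total_count)]
    return b + [0] * (out_size - total_count)  # right-pad with zeros


def srswor_log_prob(b, given_count, total_count):
    """log P(b | L): -log C(T, L) on the support, -inf elsewhere."""
    in_support = (
        all(x in (0, 1) for x in b)
        and sum(b[:total_count]) == given_count
        and not any(b[total_count:])
    )
    if not in_support:
        return -math.inf
    return -math.log(math.comb(total_count, given_count))


b = srswor_sample(2, 5, out_size=7, rng=random.Random(0))
lp = srswor_log_prob(b, 2, 5)
```

With \(T = 5\) and \(L = 2\) there are \(\binom{5}{2} = 10\) vectors in the support, so lp equals \(-\log 10\) for any draw.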
- class pydrobert.torch.distributions.StraightThrough(batch_shape=(), event_shape=(), validate_args=None)[source]
Interface for distributions for which a straight through estimate is possible
Classes implementing this interface supply both a method for drawing a relaxed sample, rsample() (dist.has_rsample == True), and a method for thresholding it into a discrete sample, threshold().
- abstract threshold(z, straight_through=False)[source]
Convert a relaxed sample into a discrete sample
- Parameters:
- Returns:
b (torch.Tensor) – The discrete sample acquired by applying a threshold function to z.
- abstract tlog_prob(b)[source]
The log probability of a thresholded sample
- Parameters:
b (Tensor) – A discrete sample. Usually the result of drawing a relaxed sample from this instance’s rsample() method, then applying a discrete threshold to it via threshold().
- Returns:
lp (torch.Tensor) – The log probability of the sample. Of shape sample_shape + batch_shape.
- class pydrobert.torch.distributions.TokenSequenceConstraint(vocab_size, eos=None, max_iters=None)[source]
Distribution constraint for token sequences
A token sequence is a vector of integer values in the range [0, vocab_size - 1]. A token sequence must be completed, which requires at least one of the two following conditions to be met:
The sequence dimension matches max_iters.
The sequence dimension is greater than 0, less than max_iters if set, and each sequence contains at least one eos.
If eos is included, any value beyond the first eos in each sequence is ignored.
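The completion rules above can be sketched as a check on a single Python list. The function is_complete_token_sequence is a hypothetical illustration, not the constraint's implementation. As an assumption, it validates only the tokens before the first eos against the vocabulary range, so an eos id outside [0, vocab_size - 1] is tolerated.

```python
def is_complete_token_sequence(seq, vocab_size, eos=None, max_iters=None):
    """Check one sequence against the two completion conditions described above."""
    length = len(seq)
    # Values past the first eos are ignored, so only the prefix is validated.
    if eos is not None and eos in seq:
        checked = seq[: seq.index(eos)]
        has_eos = True
    else:
        checked = seq
        has_eos = False
    if any(not (isinstance(tok, int) and 0 <= tok < vocab_size) for tok in checked):
        return False
    # Condition 1: the sequence dimension matches max_iters.
    if max_iters is not None and length == max_iters:
        return True
    # Condition 2: non-empty, shorter than max_iters (if set), and contains eos.
    return length > 0 and (max_iters is None or length < max_iters) and has_eos
```

For example, with vocab_size=10 and eos=0 the sequence [3, 1, 2, 0, 7, 99] passes because the out-of-range 99 comes after the first eos, whereas [3, 11, 0] fails because 11 precedes it.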