# Copyright 2022 Sean Robertson
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import abc
from typing import Optional, Sequence, Tuple
import torch
from . import config, argcheck
from .distributions import Density, StraightThrough, ConditionalStraightThrough
from ._estimators import Estimator, FunctionOnSample
from ._compat import logaddexp
[docs]
class MonteCarloEstimator(Estimator, metaclass=abc.ABCMeta):
    r"""A Monte Carlo estimator base class

    A Monte Carlo estimator estimates an expectation with some form of random sampling
    from a proposal. Let :math:`N` be some number of Monte Carlo samples and
    :math:`b^{(1:N)}` be `N` samples drawn from a proposal (usually, but not
    necessarily, :math:`P`). The estimator :func:`G` is unbiased iff

    .. math::

        \mathbb{E}_{b^{(1:N)} \sim P}[G(b)] = \mathbb{E}_{b \sim P}[f(b)].

    A toy example can be found in the `source repository
    <https://github.com/sdrobert/pydrobert-pytorch/blob/master/tests/test_mc.py>`_ under
    the test name ``test_benchmark``. It can be run from the repository root as

    .. code-block:: shell

        DO_MC_BENCHMARK=1 pytest tests/test_mc.py -k benchmark -s

    Parameters
    ----------
    proposal
        The distribution that samples are drawn from.
    func
        The function whose expectation is being estimated.
    mc_samples
        The number of samples to draw from `proposal`, :math:`N`. Must be positive.
    is_log
        If :obj:`True`, operate in log space. `func` defines :math:`\log f` instead of
        :math:`f`, and the returned value is an estimate of :math:`\log v` rather than
        of :math:`v`. That estimate is simply the log of the MC estimate of :math:`v`,
        which is biased: by Jensen's inequality its expectation is a lower bound on
        :math:`\log v`. Exponentiating the returned value recovers an unbiased
        estimate of :math:`v`.

    Returns
    -------
    v : torch.Tensor
        The estimate of :math:`v` (or of :math:`\log v` when `is_log` is
        :obj:`True`).
    """

    mc_samples: int

    def __init__(
        self,
        proposal: torch.distributions.Distribution,
        func: FunctionOnSample,
        mc_samples: int,
        is_log: bool = False,
    ) -> None:
        # Validate every argument up front (in this order) before the base class
        # stores anything.
        checked_proposal = argcheck.is_a(
            proposal, torch.distributions.Distribution, "proposal"
        )
        checked_samples = argcheck.is_posi(mc_samples, "mc_samples")
        checked_is_log = argcheck.is_bool(is_log, "is_log")
        super().__init__(checked_proposal, func, checked_is_log)
        self.mc_samples = checked_samples
[docs]
class DirectEstimator(MonteCarloEstimator):
    r"""Direct MC estimate using REINFORCE gradient estimate

    The expectation :math:`v = \mathbb{E}_{b \sim P}[f(b)]` is estimated by drawing
    :math:`N` samples :math:`b^{(1:N)}` i.i.d. from :math:`P` and taking the sample
    average:

    .. math::

        v \approx \frac{1}{N} \sum_{n=1}^N f\left(b^{(n)}\right).

    An optional control variate :math:`c` can be specified:

    .. math::

        v \approx \frac{1}{N} \sum_{n=1}^N
            f\left(b^{(n)}\right) - c\left(b^{(n)}\right) + \mu_c,

    which is unbiased when :math:`\mathbb{E}_{b \sim P}[c(b)] = \mu_c`.

    In the backward pass, the gradient of the expectation is estimated using REINFORCE
    [williams1992]_:

    .. math::

        \nabla v \approx \frac{1}{N} \sum_{n=1}^N
            \left(f\left(b^{(n)}\right) - c\left(b^{(n)}\right) + \mu_c\right)
            \nabla \log P\left(b^{(n)}\right),

    with the control variate terms excluded if they were not specified. Any gradient
    flowing directly through :math:`f`, :math:`c`, or :math:`\mu_c` is added to this.

    Parameters
    ----------
    proposal
        The distribution :math:`P` that samples are drawn from.
    func
        The function :math:`f` (or :math:`\log f` if `is_log` is :obj:`True`).
    mc_samples
        The number of samples to draw from `proposal`, :math:`N`.
    cv
        The function :math:`c` (or :math:`\log c` if `is_log` is :obj:`True`).
    cv_mean
        The value :math:`\mu_c` (or :math:`\log \mu_c` if `is_log` is :obj:`True`).
        Must be specified whenever `cv` is.
    is_log
        If :obj:`True`, operate in log space and return an estimate of
        :math:`\log v`.

    Returns
    -------
    v : torch.Tensor
        The estimate, with the sample dimension reduced away. It has the same shape
        whether or not `is_log` is set.

    Raises
    ------
    ValueError
        If `cv` is specified without `cv_mean`.
    """

    cv: Optional[FunctionOnSample]
    cv_mean: Optional[torch.Tensor]

    def __init__(
        self,
        proposal: torch.distributions.Distribution,
        func: FunctionOnSample,
        mc_samples: int,
        cv: Optional[FunctionOnSample] = None,
        cv_mean: Optional[torch.Tensor] = None,
        is_log: bool = False,
    ):
        cv_mean = argcheck.is_tensor(cv_mean, "cv_mean", True)
        if cv is not None and cv_mean is None:
            # Without mu_c the control-variate correction cannot be formed; fail
            # here rather than with an opaque error on the first call.
            raise ValueError("cv_mean must be specified when cv is specified")
        super().__init__(proposal, func, mc_samples, is_log)
        self.cv, self.cv_mean = cv, cv_mean

    def __call__(self) -> torch.Tensor:
        b = self.proposal.sample([self.mc_samples])
        fb = self.func(b)
        if self.is_log:
            # fb holds log-values. Subtract the (clamped) max over samples before
            # exponentiating so exp() stays representable; it is added back after
            # the final log. keepdim=True keeps it broadcastable against fb.
            fb_lmax = fb.max(0, keepdim=True)[0].clamp(
                torch.finfo(fb.dtype).min / 2, torch.finfo(fb.dtype).max / 2
            )
            fb = (fb - fb_lmax).clamp(config.EPS_NINF, config.EPS_INF).exp()
        if self.cv is not None:
            c = self.cv_mean
            cvb = self.cv(b)
            if self.is_log:
                # Bring the log-space control variate into the same shifted linear
                # space as fb.
                c = (
                    (c.unsqueeze(0) - fb_lmax)
                    .clamp(config.EPS_NINF, config.EPS_INF)
                    .exp()
                )
                cvb = (cvb - fb_lmax).clamp(config.EPS_NINF, config.EPS_INF).exp()
            fb = fb - cvb + c
        log_pb = self.proposal.log_prob(b)
        # REINFORCE term; the (detached) sample values weight the score function
        deriv = (fb.detach() * log_pb).mean(0)
        fb = fb.mean(0)
        if self.is_log:
            # The gradient of log v is grad(v) / v. Clamp so that log() stays finite.
            fb = fb.clamp_min(math.exp(config.EPS_NINF))
            deriv = deriv / fb.detach()
            # fb has already had the sample dimension reduced away. Drop the
            # keepdim'd dimension of fb_lmax so the log estimate has the same shape
            # as the linear one.
            v = fb.log() + deriv - deriv.detach() + fb_lmax.squeeze(0)
        else:
            # deriv - deriv.detach() is zero-valued but carries deriv's gradient
            v = fb + deriv - deriv.detach()
        return v
[docs]
class ReparameterizationEstimator(MonteCarloEstimator):
    r"""MC estimation of continuous variables with reparameterization gradient

    This estimator applies to distributions over continuous random variables
    :math:`z \sim P` whose values can be decomposed into the sum of a deterministic,
    learnable (i.e. with gradient) part :math:`\theta` and a random, unlearnable part
    :math:`\epsilon`:

    .. math::

        z = \theta + \epsilon,\>\epsilon \sim P' \\
        \nabla P(z) = \nabla P'(\epsilon) = 0

    The expectation :math:`v = \mathbb{E}_{z \sim P}[f(z)]` is estimated by drawing
    :math:`N` samples :math:`z^{(1:N)}` i.i.d. from :math:`P` and taking the sample
    average:

    .. math::

        v \approx \frac{1}{N} \sum^N_{n=1} f\left(z^{(n)}\right).

    Because :math:`\nabla P(z) = 0`, the probabilities can be ignored in the backward
    direction, leaving the unbiased gradient estimate

    .. math::

        \nabla v \approx \frac{1}{N} \sum^N_{n=1} \nabla f\left(z^{(n)}\right).

    Parameters
    ----------
    proposal
        The distribution over which the expectation is taken, :math:`P` (not
        :math:`P'`). `proposal` must implement the :func:`Distribution.rsample` method
        (``proposal.has_rsample == True``).
    func
        The function :math:`f` (or :math:`\log f` if `is_log` is :obj:`True`).
    mc_samples
        The number of samples to draw from `proposal`, :math:`N`.
    is_log
        If :obj:`True`, return the log of the sample average of :math:`\exp(func)`.

    Returns
    -------
    v : torch.Tensor

    Raises
    ------
    ValueError
        If `proposal` does not support ``rsample``.
    """

    def __init__(
        self,
        proposal: torch.distributions.Distribution,
        func: FunctionOnSample,
        mc_samples: int,
        is_log: bool = False,
    ) -> None:
        # Gradients reach the distribution parameters only through rsample().
        if not proposal.has_rsample:
            raise ValueError("proposal must implement rsample")
        super().__init__(proposal, func, mc_samples, is_log)

    def __call__(self) -> torch.Tensor:
        values = self.func(self.proposal.rsample([self.mc_samples]))
        if not self.is_log:
            return values.mean(0)
        # log of the sample mean, computed stably as logsumexp - log(N)
        return values.logsumexp(0) - math.log(values.size(0))
[docs]
class StraightThroughEstimator(MonteCarloEstimator):
    r"""MC estimation of discrete variables with continuous relaxation's reparam grad

    A straight-through estimator [bengio2013]_ is like a
    :class:`ReparameterizationEstimator` but fudges the fact that the samples are
    actually discrete to compute the gradient. To estimate
    :math:`v = \mathbb{E}_{b \sim P}[f(b)]`, we need a distribution over the discrete
    samples' continuous relaxations,

    .. math::

        z = \theta + \epsilon,\>\epsilon \sim P',

    and a threshold function :math:`H(z) = b` such that :math:`P(H(z)) = P(b)`. The
    estimate of :math:`v` is computed by drawing :math:`N` relaxed values
    :math:`z^{(1:N)}` and taking the sample average on thresholded values:

    .. math::

        v \approx \frac{1}{N} \sum^N_{n=1} f\left(H\left(z^{(n)}\right)\right).

    This estimate is unbiased. In the backward direction, we approximate

    .. math::

        \nabla P(H(z)) \approx \nabla P(z) = \nabla P'(\epsilon) = 0

    and end up with a biased estimate of the gradient resembling that of
    :class:`ReparameterizationEstimator`:

    .. math::

        \nabla v \approx \frac{1}{N} \sum^N_{n=1}
            \nabla f\left(H\left(z^{(n)}\right)\right).

    Parameters
    ----------
    proposal
        The distribution over which the expectation is taken, :math:`P` (not
        :math:`P'`). `proposal` must implement
        :class:`pydrobert.torch.distributions.StraightThrough`.
    func
        The function :math:`f` (or :math:`\log f` if `is_log` is :obj:`True`).
    mc_samples
        The number of relaxed samples to draw, :math:`N`.
    is_log
        If :obj:`True`, return the log of the sample average of :math:`\exp(func)`.

    Returns
    -------
    v : torch.Tensor
    """

    proposal: StraightThrough

    def __init__(
        self,
        proposal: StraightThrough,
        func: FunctionOnSample,
        mc_samples: int,
        is_log: bool = False,
    ) -> None:
        super().__init__(
            argcheck.is_a(proposal, StraightThrough, "proposal"),
            func,
            mc_samples,
            is_log,
        )

    def __call__(self) -> torch.Tensor:
        relaxed = self.proposal.rsample([self.mc_samples])
        # threshold the relaxations, passing True as the proposal's second argument
        values = self.func(self.proposal.threshold(relaxed, True))
        if self.is_log:
            return values.logsumexp(0) - math.log(values.size(0))
        return values.mean(0)
[docs]
class ImportanceSamplingEstimator(MonteCarloEstimator):
    r"""Importance Sampling MC estimate

    The expectation :math:`v = \mathbb{E}_{b \sim P}[f(b)]` is estimated by drawing
    :math:`N` samples :math:`b^{(1:N)}` i.i.d. from the proposal distribution
    :math:`Q`, weighing the values :math:`f(b)` according to the likelihood ratio of
    :math:`P(b)` over :math:`Q(b)`, and taking the sample average:

    .. math::

        v \approx \frac{1}{N} \sum_{n=1}^N w_n f\left(b^{(n)}\right) \\
        w_n = \frac{P\left(b^{(n)}\right)}{Q\left(b^{(n)}\right)}.

    The estimate is unbiased iff :math:`Q` dominates :math:`P`, that is

    .. math::

        \forall b \quad P(b) > 0 \implies Q(b) > 0.

    The gradient is estimated as

    .. math::

        \nabla v \approx \frac{1}{N} \sum_{n=1}^N \frac{1}{Q\left(b^{(n)}\right)}
            \nabla P\left(b^{(n)}\right)f\left(b^{(n)}\right).

    If `self_normalize` is set to :obj:`True`, :math:`v` is instead estimated as:

    .. math::

        v \approx \sum_{n=1}^N \omega_n f\left(b^{(n)}\right) \\
        \omega_n = \frac{w_n}{\sum_{n'=1}^N w_{n'}}

    with gradients defined as

    .. math::

        \nabla v \approx \nabla \sum_{n=1}^N \omega_n f\left(b^{(n)}\right).

    The self-normalized estimate is biased, but the bias decreases (assuming the
    proposal dominates) as :math:`N \to \infty`. This property holds even if
    :math:`P` is not a probability density (i.e. :math:`\sum_b P(b) \neq 1`).

    Parameters
    ----------
    proposal
        The distribution the samples are drawn from. In this case, `proposal` has
        probability density :math:`Q`, not :math:`P`.
    func
        The function :math:`f` (or :math:`\log f` if `is_log` is :obj:`True`).
    mc_samples
        The number of samples to draw from `proposal`, :math:`N`.
    density
        The density :math:`P`. Can be unnormalized.
    self_normalize
        Whether to use the self-normalized estimator.
    is_log
        If :obj:`True`, operate in log space and return an estimate of
        :math:`\log v`.

    Returns
    -------
    v : torch.Tensor

    Notes
    -----
    The gradient with respect to the proposal's parameters is always set to
    :math:`0`. This is an unbiased estimate for :math:`v` and its unnormalized IS
    estimator. It is a biased estimate of the gradient of the self-normalized
    estimate, but only because the self-normalized estimate is biased itself.
    """

    density: Density
    self_normalize: bool

    def __init__(
        self,
        proposal: torch.distributions.Distribution,
        func: FunctionOnSample,
        mc_samples: int,
        density: Density,
        self_normalize: bool = False,
        is_log: bool = False,
    ) -> None:
        self_normalize = argcheck.is_bool(self_normalize, "self_normalize")
        super().__init__(proposal, func, mc_samples, is_log)
        self.density, self.self_normalize = density, self_normalize

    def __call__(self) -> torch.Tensor:
        samples = self.proposal.sample([self.mc_samples])
        log_p = self.density.log_prob(samples)
        log_q = self.proposal.log_prob(samples)
        # Cut the gradient to the proposal (see Notes) while keeping the proposal's
        # parameters connected to the graph through a zero-valued term.
        log_q = log_q.detach() + 0 * log_q.sum()
        values = self.func(samples)
        if self.self_normalize:
            # log(omega_n): normalize the log-weights over the sample dimension
            log_weight = (log_p - log_q).log_softmax(0)
        else:
            # log(w_n / N)
            log_weight = log_p - log_q - math.log(self.mc_samples)
        if not self.is_log:
            return (values * log_weight.exp()).sum(0)
        return (values + log_weight).logsumexp(0)
[docs]
class RelaxEstimator(MonteCarloEstimator):
    r"""RELAX estimator

    The RELAX estimator [grathwohl2017]_ estimates the expectation
    :math:`v = \mathbb{E}_{b \sim P}[f(b)]` for discrete :math:`b` via MC estimation,
    attempting to minimize the variance of the estimator using control variates over
    their continuous relaxations.

    Let :math:`z^{(1:N)}` be :math:`N` continuous relaxation variables drawn i.i.d.
    :math:`z^{(n)} \sim P(\cdot)` and :math:`b^{(1:N)}` be their discretizations
    :math:`H(z^{(n)}) = b^{(n)}`. Let :math:`P(z|b)` be a conditional s.t. the joint
    distribution satisfies :math:`P(b)P(z|b) = P(z, b) = P(z)I[H(z) = b]` and
    :math:`\tilde{z}^{(1:N)}` be samples drawn from those conditionals
    :math:`\tilde{z}^{(n)} \sim P(\cdot|b^{(n)})`. Let :math:`c` be some control
    variate accepting relaxed samples. Then the (unbiased) RELAX estimator is:

    .. math::

        v \approx \frac{1}{N} \sum^N_{n=1}
            f\left(b^{(n)}\right) - c\left(\tilde{z}^{(n)}\right)
            + c\left(z^{(n)}\right).

    Pairing this estimator with one of the REBAR control variates from
    :mod:`pydrobert.torch.modules` yields the REBAR estimator [tucker2017]_.

    We offer two ways of estimating the gradient :math:`\nabla v`. The first is a
    REINFORCE-style estimate:

    .. math::

        \nabla v \approx \frac{1}{N} \sum^N_{n=1} \nabla \left(
            \left(f\left(b^{(n)}\right) - c\left(\tilde{z}^{(n)}\right)\right)\log P(b)
            + c\left(z^{(n)}\right)\right).

    The above estimate requires no special consideration for any variable for which
    the gradient is being calculated. The second, following [grathwohl2017]_,
    specially optimizes the control variate parameters to minimize the variance of
    the gradient estimates of the parameters involved in drawing :math:`z`. Let
    :math:`\theta_{1:K}` be the set of such parameters,
    :math:`g_{\theta_k} \approx \nabla_{\theta_k} v` be a REINFORCE-style estimate of
    the :math:`k`-th :math:`z` parameter using the equation above, and let
    :math:`\gamma` be a control variate parameter. Then the variance-minimizing loss
    can be approximated by:

    .. math::

        \nabla_\gamma \mathrm{Var}(v) \approx \frac{1}{K} \nabla_\gamma
            \left(\sum_{k=1}^K g^2_{\theta_k}\right).

    The remaining parameters are calculated with the REINFORCE-style estimator above.
    The proposal parameters `proposal_params` and control variate parameters
    `cv_params` must be specified to use this loss function.

    Parameters
    ----------
    proposal : ConditionalStraightThrough
        The distribution over which the expectation is taken, :math:`P`. Must
        implement :class:`pydrobert.torch.distributions.ConditionalStraightThrough`.
    func
        The function :math:`f` (or :math:`\log f` if `is_log` is :obj:`True`). It is
        called on the discrete samples :math:`b`.
    mc_samples
        The number of relaxed samples to draw, :math:`N`.
    cv
        The control variate :math:`c`. It is called on both :math:`z` and
        :math:`\tilde{z}`.
    proposal_params
        A sequence of parameters used in the computation of :math:`z` and
        :math:`P(H(z))`. Does not have to be specified unless using the
        variance-minimizing control variate objective. If non-empty, `cv_params` must
        be non-empty as well.
    cv_params
        A sequence of parameters used in the computation of control variate values.
        Does not have to be specified unless using the variance-minimizing control
        variate objective. If non-empty, `proposal_params` must be non-empty as well.
    is_log
        If :obj:`True`, operate in log space and return an estimate of
        :math:`\log v`.

    Returns
    -------
    v : torch.Tensor

    Raises
    ------
    ValueError
        If exactly one of `proposal_params` and `cv_params` is empty.

    Warnings
    --------
    The current implementation does not support auxiliary loss functions for the
    control variate parameters when the variance-minimizing objective is used
    (`proposal_params` and `cv_params` are specified). Auxiliary loss functions for
    parameters other than `cv_params` are fine.
    """

    proposal: ConditionalStraightThrough
    cv: FunctionOnSample
    proposal_params: Tuple[torch.Tensor, ...]
    cv_params: Tuple[torch.Tensor, ...]

    def __init__(
        self,
        proposal: ConditionalStraightThrough,
        func: FunctionOnSample,
        mc_samples: int,
        cv: FunctionOnSample,
        proposal_params: Sequence[torch.Tensor] = tuple(),
        cv_params: Sequence[torch.Tensor] = tuple(),
        is_log: bool = False,
    ) -> None:
        proposal = argcheck.is_a(proposal, ConditionalStraightThrough, "proposal")
        # materialize as tuples: they are iterated every call and passed to autograd
        proposal_params = tuple(proposal_params)
        cv_params = tuple(cv_params)
        if (len(proposal_params) > 0) != (len(cv_params) > 0):
            raise ValueError(
                "either both proposal_params and cv_params must be specified or neither"
            )
        super().__init__(proposal, func, mc_samples, is_log)
        self.cv, self.proposal_params, self.cv_params = cv, proposal_params, cv_params

    def __call__(self) -> torch.Tensor:
        # relaxed samples z^{(n)} ~ P(.)
        z = self.proposal.rsample([self.mc_samples])
        # discretizations b^{(n)} = H(z^{(n)})
        b = self.proposal.threshold(z)
        # conditional relaxations \tilde{z}^{(n)} ~ P(.|b^{(n)})
        zcond = self.proposal.csample(b)
        # log P(b), the score-function (REINFORCE) term
        log_pb = self.proposal.tlog_prob(b)
        fb = self.func(b)
        cvz = self.cv(z)
        cvzcond = self.cv(zcond)
        if self.is_log:
            # func and cv return log-values here. Subtract the (clamped) max of fb
            # over samples before exponentiating so the exponentials stay
            # representable; fb_lmax is added back after the final log.
            fb_lmax = fb.max(0, keepdim=True)[0].clamp(
                torch.finfo(fb.dtype).min / 2, torch.finfo(fb.dtype).max / 2
            )
            fb = (fb - fb_lmax).clamp(config.EPS_NINF, config.EPS_INF).exp()
            cvz = (cvz - fb_lmax).clamp(config.EPS_NINF, config.EPS_INF).exp()
            cvzcond = (cvzcond - fb_lmax).clamp(config.EPS_NINF, config.EPS_INF).exp()
        fb_cvzcond = fb - cvzcond
        if self.cv_params:
            # Variance-minimizing objective. v_ is the REINFORCE-style surrogate from
            # the docstring. Its gradients w.r.t. proposal_params (the g_theta_k) are
            # kept in the graph (create_graph=True) so they can in turn be
            # differentiated w.r.t. cv_params.
            v_ = (fb_cvzcond * log_pb + cvz).mean(0)
            if self.is_log:
                # NOTE(review): v_ is not guaranteed to be positive (log_pb is
                # typically negative), so v_.log() may produce NaN -- confirm
                v_ = v_.log() + fb_lmax
            gs_proposal = torch.autograd.grad(
                v_,
                self.proposal_params,
                torch.ones_like(v_),
                create_graph=True,
                retain_graph=True,
            )
            # accumulate, per cv parameter, the gradient of each g_theta_k's size
            gs_cv = [0.0] * len(self.cv_params)
            for gp in gs_proposal:
                # NOTE(review): this differentiates the L2 norm of g_theta_k rather
                # than its square (and omits the 1/K factor) in the docstring
                # formula -- confirm this is intended
                gp = gp.norm(2)
                gs_cv = [
                    x + y
                    for (x, y) in zip(
                        gs_cv,
                        torch.autograd.grad(gp, self.cv_params, retain_graph=True),
                    )
                ]
            for gc, c in zip(gs_cv, self.cv_params):
                # _attach_grad registers a one-shot hook that *replaces* the gradient
                # computed for c in the next backward pass (hence the Warnings
                # section about auxiliary losses on cv_params)
                _attach_grad(c, gc.detach())
        # per-sample REINFORCE term (f(b) - c(\tilde{z})) log P(b); it is not
        # averaged here, as the final v.mean(0) does that
        deriv = fb_cvzcond.detach() * log_pb
        fb = (fb_cvzcond + cvz).mean(0)
        if self.is_log:
            # the gradient of log v is grad(v) / v; clamp keeps log() finite
            fb = fb.clamp_min(math.exp(config.EPS_NINF))
            deriv = deriv / fb.detach()
            v = fb.log() + deriv - deriv.detach() + fb_lmax
        else:
            # deriv - deriv.detach() is zero-valued but contributes deriv's gradient
            v = fb + deriv - deriv.detach()
        # v still carries the per-sample leading dimension (from deriv); averaging
        # leaves the value unchanged and averages the REINFORCE gradients
        return v.mean(0)
[docs]
class IndependentMetropolisHastingsEstimator(MonteCarloEstimator):
    r"""Independent Metropolis Hastings MCMC estimator

    Independent Metropolis Hastings (IMH) is a Markov Chain Monte Carlo (MCMC)
    technique for estimating the value :math:`v = \mathbb{E}_{b \sim P}[f(b)] < \infty`.
    A Markov Chain of :math:`N` samples :math:`b^{(1:N)}` is constructed sequentially
    by iteratively drawing samples from a proposal :math:`b' \sim Q` and either
    accepting it or rejecting it and taking :math:`b^{(n-1)}` as the next sample in
    the chain, :math:`b^{(n)}`, according to the following rules:

    .. math::

        u \sim \mathrm{Uniform}([0, 1]) \\
        b^{(n)} = \begin{cases}
            b' & \alpha(b', b^{(n-1)}) > u \\
            b^{(n-1)} & \mathrm{otherwise}
        \end{cases} \\
        \alpha(b', b^{(n-1)}) = \min\left(
            \frac{P(b')Q(b^{(n-1)})}{P(b^{(n-1)})Q(b')}, 1\right).

    The sample estimate from the Markov Chain

    .. math::

        v \approx \frac{1}{N - M} \sum_{n=M + 1}^N f\left(b^{(n)}\right)

    for a fixed number of burn-in samples :math:`M \in [0, N)` is biased but converges
    asymptotically (as :math:`N \to \infty`) to :math:`v` with strong guarantees as
    long as there exists some constant :math:`\epsilon` such that [mengerson1996]_

    .. math::

        P(b) > 0 \implies \frac{P(b)}{Q(b)} \leq \epsilon.

    Parameters
    ----------
    proposal
        The proposal distribution :math:`Q`.
    func
        The function :math:`f` (or :math:`\log f` if `is_log` is :obj:`True`).
    mc_samples
        The length of the chain, :math:`N`.
    density
        The density :math:`P`. Does not have to be a probability distribution (can be
        unnormalized).
    burn_in
        The number of samples in the chain discarded from the estimate, :math:`M`.
    initial_sample
        If specified, `initial_sample` is used as the value :math:`b^{(0)}` to start
        the chain. Of shape either ``proposal.batch_shape + proposal.event_shape`` or
        ``(1,) + proposal.batch_shape + proposal.event_shape``. A :class:`ValueError`
        will be thrown if any elements are outside the support of :math:`P`
        (`density`). If unspecified, :math:`b^{(0)}` will be decided by randomly
        drawing from `proposal` until all elements are in the support of `density`.
    initial_sample_tries
        If `initial_sample` is unspecified, `initial_sample_tries` dictates the
        maximum number of draws from `proposal` allowed in order to find elements in
        the support of `density` before a :class:`RuntimeError` is thrown.
    is_log
        If :obj:`True`, operate in log space and return an estimate of
        :math:`\log v`.

    Returns
    -------
    v : torch.Tensor

    Warnings
    --------
    The resulting estimate has no gradient attached to it and therefore cannot be
    backpropagated through.
    """

    density: Density
    initial_sample: Optional[torch.Tensor]
    initial_sample_tries: int
    burn_in: int

    def __init__(
        self,
        proposal: torch.distributions.Distribution,
        func: FunctionOnSample,
        mc_samples: int,
        density: Density,
        burn_in: int = 0,
        initial_sample: Optional[torch.Tensor] = None,
        initial_sample_tries: int = 1000,
        is_log: bool = False,
    ) -> None:
        # proposal's shape attributes are read below, so validate it first
        proposal = argcheck.is_a(proposal, torch.distributions.Distribution, "proposal")
        burn_in = argcheck.is_nonnegi(burn_in, "burn_in")
        mc_samples = argcheck.is_posi(mc_samples, "mc_samples")
        argcheck.is_lt(burn_in, mc_samples, "burn_in", "mc_samples")
        if initial_sample is not None:
            # self.proposal / self.density are not set until later in __init__, so
            # the validation must use the arguments directly
            sample_shape = proposal.batch_shape + proposal.event_shape
            if initial_sample.shape == sample_shape:
                # add the leading sample dimension used by the chain
                initial_sample = initial_sample.unsqueeze(0)
            elif initial_sample.shape != (1,) + sample_shape:
                raise ValueError(
                    f"Expected initial_sample to have shape {(1,) + sample_shape} or "
                    f"{sample_shape} "
                )
            if not torch.isfinite(density.log_prob(initial_sample)).all():
                raise ValueError(
                    "all values in initial_sample must lie in the support of density"
                )
        elif initial_sample_tries < 1:
            raise ValueError(
                "initial_sample_tries must be positive when initial_sample is None"
            )
        super().__init__(proposal, func, mc_samples, is_log)
        self.density, self.initial_sample = density, initial_sample
        self.initial_sample_tries, self.burn_in = initial_sample_tries, burn_in

    def find_initial_sample(self, tries: Optional[int] = None) -> torch.Tensor:
        """Find an initial sample by randomly sampling from the proposal

        Elements of the sample outside the support of `density` (non-finite log
        density) are redrawn from `proposal`, keeping elements already found to be in
        the support, until every element is in the support.

        Parameters
        ----------
        tries
            The maximum number of draws from `proposal`. Defaults to
            ``self.initial_sample_tries``.

        Returns
        -------
        sample : torch.Tensor
            A sample with a leading dimension of size 1, as drawn by
            ``proposal.sample([1])``.

        Raises
        ------
        ValueError
            If `tries` is not positive.
        RuntimeError
            If no valid sample is found within `tries` draws.
        """
        if tries is None:
            tries = self.initial_sample_tries
        if tries < 1:
            raise ValueError("tries must be positive")
        with torch.no_grad():
            sample = self.proposal.sample([1])
            keep = torch.isfinite(self.density.log_prob(sample))
            if keep.all():
                return sample
            for _ in range(tries - 1):
                cur_sample = self.proposal.sample([1])
                while keep.dim() < cur_sample.dim():
                    keep = keep.unsqueeze(-1)  # event dims
                sample = torch.where(keep, sample, cur_sample)
                keep = torch.isfinite(self.density.log_prob(sample))
                if keep.all():
                    return sample
        raise RuntimeError(
            f"Unable to find initial sample in {tries} draws. "
            f"Either specify initial_sample on instantiation or increase "
            f"initial_sample_tries."
        )

    def __call__(self) -> torch.Tensor:
        with torch.no_grad():
            if self.initial_sample is None:
                last_sample = self.find_initial_sample()
            else:
                last_sample = self.initial_sample
            v = None
            num_kept = self.mc_samples - self.burn_in
            # log of the importance ratio P(b) / Q(b) of the current chain state
            last_ratio = self.density.log_prob(last_sample) - self.proposal.log_prob(
                last_sample
            )
            # log u, one draw per chain step and batch element
            log_uniform_draws = torch.rand(
                (self.mc_samples,) + self.proposal.batch_shape,
                device=last_sample.device,
            ).log()
            for n in range(self.mc_samples):
                cur_sample = self.proposal.sample([1])
                cur_ratio = self.density.log_prob(cur_sample) - self.proposal.log_prob(
                    cur_sample
                )
                # alpha > u  <=>  log(ratio'/ratio) > log(u)
                accept = (cur_ratio - last_ratio) > log_uniform_draws[n]
                # Select rather than blend: accept * cur + (~accept) * last evaluates
                # 0 * -inf = NaN when a rejected proposal lies outside the support
                # of the density, which would stall the chain forever.
                cur_ratio = torch.where(accept, cur_ratio, last_ratio)
                while accept.dim() < cur_sample.dim():
                    accept = accept.unsqueeze(-1)  # event dims
                cur_sample = torch.where(accept, cur_sample, last_sample)
                if n >= self.burn_in:
                    fb = self.func(cur_sample).squeeze(0)
                    if v is None:
                        v = fb
                    elif self.is_log:
                        v = logaddexp(v, fb)
                    else:
                        v = v + fb
                last_sample, last_ratio = cur_sample, cur_ratio
            # out-of-place so a tensor returned by func is never mutated
            if self.is_log:
                v = v - math.log(num_kept)
            else:
                v = v / num_kept
        return v
def _attach_grad(x: torch.Tensor, g: torch.Tensor) -> None:
    """Replace the next gradient computed for `x` with `g`

    Registers a one-shot backward hook on `x`. The first time a gradient for `x` is
    computed, the hook removes itself and returns `g` in place of that gradient. Later
    backward passes are unaffected.

    Parameters
    ----------
    x
        The tensor whose gradient is replaced.
    g
        The replacement gradient.
    """
    handle = None

    def hook(grad):
        nonlocal handle
        handle.remove()
        # AttributeError if this hook were ever invoked again after removal
        handle = None
        return g

    handle = x.register_hook(hook)
class _RebarControlVariate(torch.nn.Module):
    # Shared base of the REBAR control variates. It stores the callable `func`, a
    # learnable temperature (kept as its log, `log_temp`), and a learnable scaling
    # coefficient `eta`. Subclasses implement `forward`.

    __constants__ = ("func", "start_temp", "start_eta")

    func: FunctionOnSample
    start_temp: float
    start_eta: float

    def __init__(self, func, start_temp: float = 0.1, start_eta: float = 1.0) -> None:
        start_temp = argcheck.is_posf(start_temp, "start_temp")
        start_eta = argcheck.is_float(start_eta, "start_eta")
        super().__init__()
        self.func = func
        self.start_temp = start_temp
        self.start_eta = start_eta
        # single-element parameters; their values are set by reset_parameters()
        self.log_temp = torch.nn.Parameter(torch.Tensor(1))
        self.eta = torch.nn.Parameter(torch.Tensor(1))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Restore log_temp to log(start_temp) and eta to start_eta"""
        log_temp_data, eta_data = self.log_temp.data, self.eta.data
        log_temp_data.fill_(self.start_temp)
        log_temp_data.log_()
        eta_data.fill_(self.start_eta)

    def extra_repr(self) -> str:
        return f"start_temp={self.start_temp},start_eta={self.start_eta}"
# Docstring template shared by the REBAR control variates. Format it with `dist` (the
# relaxed distribution's name), `sigma` (the name of the squashing function), and
# `shape` (the expected shape of the relaxed sample). Literal braces are doubled.
_REBAR_DOCS = """REBAR control variate for {dist} relaxation

REBAR [tucker2017]_ is a special case of the RELAX estimator [grathwohl2017]_ with
a control variate that passes a temperature-based transformation of the relaxed sample
to the function :math:`f` the expectation is being taken over. That is:

.. math::

    c_{{\\lambda,\\eta}}(z) = \\eta f(\\sigma(z / \\lambda))

For the {dist} distribution, :math:`\\sigma` is the {sigma} function.

Parameters
----------
func
    The function :math:`f`. Must be able to accept relaxed samples.
start_temp
    The temperature the :math:`\\lambda` parameter is initialized to.
start_eta
    The coefficient the :math:`\\eta` parameter is initialized to.

Variables
---------
log_temp
    A single-element parameter initialized to ``log(start_temp)``.
eta
    A single-element parameter initialized to ``start_eta``.

Call Parameters
---------------
z : torch.Tensor
    A tensor of shape ``{shape}`` representing the relaxed sample.

Returns
-------
z_temp : torch.Tensor
    The value :math:`c_{{\\lambda,\\eta}}(z)`: the output of `func` on the
    transformed sample, scaled by :math:`\\eta`. Its shape is determined by `func`.

Warnings
--------
This control variate can be traced but not scripted. Note that
:class:`pydrobert.torch.estimators.RelaxEstimator` is unable to be traced or scripted.

See Also
--------
pydrobert.torch.estimators.RelaxEstimator
    For where to use this control variate.
"""
[docs]
class LogisticBernoulliRebarControlVariate(_RebarControlVariate):
    # the class documentation is generated from the shared REBAR template
    __doc__ = _REBAR_DOCS.format(dist="LogisticBernoulli", sigma="sigmoid", shape="(*)")

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # sigma(z / lambda) with lambda = exp(log_temp), then scale f's output by eta
        temperature = self.log_temp.exp()
        squashed = torch.sigmoid(z / temperature)
        return self.eta * self.func(squashed)
[docs]
class GumbelOneHotCategoricalRebarControlVariate(_RebarControlVariate):
    # the class documentation is generated from the shared REBAR template
    __doc__ = _REBAR_DOCS.format(
        dist="GumbelOneHotCategorical", sigma="softmax", shape="(*, V)"
    )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # softmax over the last (category) dimension of z / lambda, lambda = exp(log_temp)
        temperature = self.log_temp.exp()
        squashed = torch.softmax(z / temperature, -1)
        return self.eta * self.func(squashed)