# mypy: allow-untyped-defs
import torch
from torch.distributions.distribution import Distribution


__all__ = ["ExponentialFamily"]

class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form defined below

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient
    statistic, :math:`F(\theta)` is the log normalizer function for a given family and
    :math:`k(x)` is the carrier measure.

    Note:
        This class is an intermediary between the `Distribution` class and distributions which
        belong to an exponential family, mainly to check the correctness of the `.entropy()` and
        analytic KL divergence methods. We use this class to compute the entropy and KL divergence
        using the AD framework and Bregman divergences (courtesy of: Frank Nielsen and Richard
        Nock, Entropies and Cross-entropies of Exponential Families).
    """

    @property
    def _natural_params(self):
        """
        Abstract method for natural parameters. Returns a tuple of Tensors based
        on the distribution.
        """
        raise NotImplementedError

    def _log_normalizer(self, *natural_params):
        """
        Abstract method for the log normalizer function. Returns the log normalizer
        based on the distribution and input.
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self):
        """
        Abstract method for expected carrier measure, which is required for computing
        entropy.
        """
        raise NotImplementedError

    def entropy(self):
        """
        Method to compute the entropy using Bregman divergence of the log normalizer.
        """
        # For an exponential family, H(p) = -E[log p(X)]
        #   = F(theta) - <theta, grad F(theta)> - E[k(X)],
        # using the identity E[t(X)] = grad F(theta). The gradient of the log
        # normalizer is obtained via autograd rather than per-family formulas.
        result = -self._mean_carrier_measure
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
        result += lg_normal
        for np, g in zip(nparams, gradients):
            result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
        return result
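

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the original module): a
# minimal ExponentialFamily subclass for the Exponential(rate) distribution,
# which in the form above has
#     theta = -rate,  t(x) = x,  F(theta) = -log(-theta),  k(x) = 0,
# so `entropy()` should reproduce the closed form 1 - log(rate). The name
# `_ExponentialSketch` is hypothetical; torch.distributions.Exponential is
# the real implementation.
if __name__ == "__main__":

    class _ExponentialSketch(ExponentialFamily):
        def __init__(self, rate):
            self.rate = torch.as_tensor(rate)
            # validate_args=False: this sketch defines no arg_constraints.
            super().__init__(batch_shape=self.rate.shape, validate_args=False)

        @property
        def _natural_params(self):
            # Natural parameter theta = -rate, as a one-element tuple.
            return (-self.rate,)

        def _log_normalizer(self, theta):
            # F(theta) = -log(-theta) = -log(rate).
            return -torch.log(-theta)

        @property
        def _mean_carrier_measure(self):
            # The exponential density has carrier measure k(x) = 0.
            return 0

    d = _ExponentialSketch([0.5, 2.0])
    print(d.entropy())            # tensor([1.6931, 0.3069])
    print(1 - torch.log(d.rate))  # closed form; matches the line above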