# mypy: allow-untyped-defs
import torch
from torch.distributions.distribution import Distribution


__all__ = ["ExponentialFamily"]


class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form defined below

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
    :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
    measure.

    Note:
        This class is an intermediary between the `Distribution` class and distributions which belong
        to an exponential family, mainly to check the correctness of the `.entropy()` and analytic KL
        divergence methods. We use this class to compute the entropy and KL divergence using the AD
        framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
        Cross-entropies of Exponential Families).
    """

    @property
    def _natural_params(self):
        """
        Abstract method for natural parameters. Returns a tuple of Tensors based
        on the distribution.
        """
        raise NotImplementedError

    def _log_normalizer(self, *natural_params):
        """
        Abstract method for the log normalizer function. Returns the log normalizer based
        on the distribution and input.
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self):
        """
        Abstract method for the expected carrier measure, which is required for computing
        entropy.
        """
        raise NotImplementedError

    def entropy(self):
        """
        Method to compute the entropy using Bregman divergence of the log normalizer.
        """
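        # For a density exp(<t(x), theta> - F(theta) + k(x)), the entropy is
        #     H(p) = F(theta) - <theta, grad F(theta)> - E[k(x)],
        # where grad F(theta) equals the expected sufficient statistic E[t(x)];
        # the gradient is obtained below via autograd.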
        result = -self._mean_carrier_measure
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
        result += lg_normal
        for np, g in zip(nparams, gradients):
            result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
        return result
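

# --- Usage sketch (illustrative; not part of the upstream file) ---
# A minimal concrete subclass using the Exponential-distribution
# parameterization worked out above: theta = -rate, F(theta) = -log(-theta),
# t(x) = x, k(x) = 0. The class name `_SketchExponential` is hypothetical.
# The entropy computed via the autograd/Bregman machinery in
# ExponentialFamily.entropy() should match the closed form 1 - log(rate).
if __name__ == "__main__":
    from torch.distributions import constraints

    class _SketchExponential(ExponentialFamily):
        arg_constraints = {"rate": constraints.positive}
        support = constraints.positive

        def __init__(self, rate):
            self.rate = rate
            super().__init__(batch_shape=rate.shape)

        @property
        def _natural_params(self):
            # Natural parameter: theta = -rate
            return (-self.rate,)

        def _log_normalizer(self, x):
            # Log normalizer: F(theta) = -log(-theta)
            return -torch.log(-x)

        @property
        def _mean_carrier_measure(self):
            # k(x) = 0, so its expectation is 0
            return 0

    rate = torch.tensor([0.5, 2.0])
    d = _SketchExponential(rate)
    print(d.entropy())          # computed via autograd + Bregman divergence
    print(1 - torch.log(rate))  # closed-form entropy, for comparison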
