pytorch

Форк
0
/
multinomial.py 
135 строк · 5.3 Кб
1
import torch
2
from torch import inf
3
from torch.distributions import Categorical, constraints
4
from torch.distributions.binomial import Binomial
5
from torch.distributions.distribution import Distribution
6
from torch.distributions.utils import broadcast_all
7

8
# Public API of this module: only the Multinomial class is exported.
__all__ = ["Multinomial"]
9

10

11
class Multinomial(Distribution):
12
    r"""
13
    Creates a Multinomial distribution parameterized by :attr:`total_count` and
14
    either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
15
    :attr:`probs` indexes over categories. All other dimensions index over batches.
16

17
    Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
18
    called (see example below)
19

20
    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
21
              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
22
              will return this normalized value.
23
              The `logits` argument will be interpreted as unnormalized log probabilities
24
              and can therefore be any real number. It will likewise be normalized so that
25
              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
26
              will return this normalized value.
27

28
    -   :meth:`sample` requires a single shared `total_count` for all
29
        parameters and samples.
30
    -   :meth:`log_prob` allows different `total_count` for each parameter and
31
        sample.
32

33
    Example::
34

35
        >>> # xdoctest: +SKIP("FIXME: found invalid values")
36
        >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
37
        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
38
        tensor([ 21.,  24.,  30.,  25.])
39

40
        >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
41
        tensor([-4.1338])
42

43
    Args:
44
        total_count (int): number of trials
45
        probs (Tensor): event probabilities
46
        logits (Tensor): event log probabilities (unnormalized)
47
    """
48
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
49
    total_count: int
50

51
    @property
52
    def mean(self):
53
        return self.probs * self.total_count
54

55
    @property
56
    def variance(self):
57
        return self.total_count * self.probs * (1 - self.probs)
58

59
    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
60
        if not isinstance(total_count, int):
61
            raise NotImplementedError("inhomogeneous total_count is not supported")
62
        self.total_count = total_count
63
        self._categorical = Categorical(probs=probs, logits=logits)
64
        self._binomial = Binomial(total_count=total_count, probs=self.probs)
65
        batch_shape = self._categorical.batch_shape
66
        event_shape = self._categorical.param_shape[-1:]
67
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
68

69
    def expand(self, batch_shape, _instance=None):
70
        new = self._get_checked_instance(Multinomial, _instance)
71
        batch_shape = torch.Size(batch_shape)
72
        new.total_count = self.total_count
73
        new._categorical = self._categorical.expand(batch_shape)
74
        super(Multinomial, new).__init__(
75
            batch_shape, self.event_shape, validate_args=False
76
        )
77
        new._validate_args = self._validate_args
78
        return new
79

80
    def _new(self, *args, **kwargs):
81
        return self._categorical._new(*args, **kwargs)
82

83
    @constraints.dependent_property(is_discrete=True, event_dim=1)
84
    def support(self):
85
        return constraints.multinomial(self.total_count)
86

87
    @property
88
    def logits(self):
89
        return self._categorical.logits
90

91
    @property
92
    def probs(self):
93
        return self._categorical.probs
94

95
    @property
96
    def param_shape(self):
97
        return self._categorical.param_shape
98

99
    def sample(self, sample_shape=torch.Size()):
100
        sample_shape = torch.Size(sample_shape)
101
        samples = self._categorical.sample(
102
            torch.Size((self.total_count,)) + sample_shape
103
        )
104
        # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
105
        # (sample_shape, batch_shape, total_count)
106
        shifted_idx = list(range(samples.dim()))
107
        shifted_idx.append(shifted_idx.pop(0))
108
        samples = samples.permute(*shifted_idx)
109
        counts = samples.new(self._extended_shape(sample_shape)).zero_()
110
        counts.scatter_add_(-1, samples, torch.ones_like(samples))
111
        return counts.type_as(self.probs)
112

113
    def entropy(self):
114
        n = torch.tensor(self.total_count)
115

116
        cat_entropy = self._categorical.entropy()
117
        term1 = n * cat_entropy - torch.lgamma(n + 1)
118

119
        support = self._binomial.enumerate_support(expand=False)[1:]
120
        binomial_probs = torch.exp(self._binomial.log_prob(support))
121
        weights = torch.lgamma(support + 1)
122
        term2 = (binomial_probs * weights).sum([0, -1])
123

124
        return term1 + term2
125

126
    def log_prob(self, value):
127
        if self._validate_args:
128
            self._validate_sample(value)
129
        logits, value = broadcast_all(self.logits, value)
130
        logits = logits.clone(memory_format=torch.contiguous_format)
131
        log_factorial_n = torch.lgamma(value.sum(-1) + 1)
132
        log_factorial_xs = torch.lgamma(value + 1).sum(-1)
133
        logits[(value == 0) & (logits == -inf)] = 0
134
        log_powers = (logits * value).sum(-1)
135
        return log_factorial_n - log_factorial_xs + log_powers
136

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.