import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)

__all__ = ["Binomial"]


def _clamp_by_zero(x):
    # works like clamp(x, min=0) but the gradient at 0 is 0.5
    return (x.clamp(min=0) + x - x.clamp(max=0)) / 2
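
# A quick sanity check of the subgradient choice above (illustrative sketch,
# not part of the original file): the left derivative (0) and right derivative
# (1) of clamp(x, min=0) at 0 are averaged, so autograd reports 0.5 rather
# than an arbitrary 0 or 1.
#
#     x = torch.zeros((), requires_grad=True)
#     _clamp_by_zero(x).backward()
#     assert x.grad.item() == 0.5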


class Binomial(Distribution):
    r"""
    Creates a Binomial distribution parameterized by :attr:`total_count` and
    either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be
    broadcastable with :attr:`probs`/:attr:`logits`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Binomial(100, torch.tensor([0., .2, .8, 1.]))
        >>> x = m.sample()
        tensor([   0.,   22.,   71.,  100.])

        >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
        >>> x = m.sample()
        tensor([[ 4.,  5.],
                [ 7.,  6.]])

    Args:
        total_count (int or Tensor): number of Bernoulli trials
        probs (Tensor): Event probabilities
        logits (Tensor): Event log-odds
    """
    arg_constraints = {
        "total_count": constraints.nonnegative_integer,
        "probs": constraints.unit_interval,
        "logits": constraints.real,
    }
    has_enumerate_support = True

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            (
                self.total_count,
                self.probs,
            ) = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.probs)
        else:
            (
                self.total_count,
                self.logits,
            ) = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)

        self._param = self.probs if probs is not None else self.logits
        batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)
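
    # Exactly one of `probs`/`logits` is stored by `__init__`; the other is
    # derived on first access by the `lazy_property` pair further down. For
    # example (illustrative sketch, not part of the original file):
    #
    #     d = Binomial(10, logits=torch.zeros(3))
    #     d.probs  # computed lazily -> tensor([0.5000, 0.5000, 0.5000])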

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Binomial, _instance)
        batch_shape = torch.Size(batch_shape)
        new.total_count = self.total_count.expand(batch_shape)
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(Binomial, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._param.new(*args, **kwargs)

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        return constraints.integer_interval(0, self.total_count)

    @property
    def mean(self):
        return self.total_count * self.probs

    @property
    def mode(self):
        return ((self.total_count + 1) * self.probs).floor().clamp(max=self.total_count)
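
    # The mode of Binomial(n, p) is floor((n + 1) * p); the clamp covers the
    # edge case p = 1, where floor((n + 1) * p) = n + 1 lies outside the
    # support {0, ..., n}.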

    @property
    def variance(self):
        return self.total_count * self.probs * (1 - self.probs)

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        return self._param.size()

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.binomial(
                self.total_count.expand(shape), self.probs.expand(shape)
            )
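
    # Note: `torch.binomial` operates on floating-point tensors, which is why
    # `__init__` casts `total_count` to the dtype of `probs`/`logits`; draws
    # come back as floats of that same dtype.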

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        log_factorial_n = torch.lgamma(self.total_count + 1)
        log_factorial_k = torch.lgamma(value + 1)
        log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
        #                                   = k * logit - n * logit - n * log1p(e^-logit)
        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
        normalize_term = (
            self.total_count * _clamp_by_zero(self.logits)
            + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits)))
            - log_factorial_n
        )
        return (
            value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
        )
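
    # Sanity check for the numerically stable form above (illustrative sketch,
    # not part of the original file): probabilities over the full support
    # should sum to one even for extreme logits.
    #
    #     d = Binomial(10, logits=torch.tensor([-30.0, 0.0, 30.0]))
    #     total = d.log_prob(d.enumerate_support()).exp().sum(0)
    #     assert torch.allclose(total, torch.ones(3))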

    def entropy(self):
        total_count = int(self.total_count.max())
        if self.total_count.min() != total_count:
            raise NotImplementedError(
                "Inhomogeneous total count not supported by `entropy`."
            )

        log_prob = self.log_prob(self.enumerate_support(False))
        return -(torch.exp(log_prob) * log_prob).sum(0)

    def enumerate_support(self, expand=True):
        total_count = int(self.total_count.max())
        if self.total_count.min() != total_count:
            raise NotImplementedError(
                "Inhomogeneous total count not supported by `enumerate_support`."
            )
        values = torch.arange(
            1 + total_count, dtype=self._param.dtype, device=self._param.device
        )
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        if expand:
            values = values.expand((-1,) + self._batch_shape)
        return values
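

# A minimal end-to-end sketch of the API above (illustrative only, not part of
# the original module); run the file directly to exercise it.
if __name__ == "__main__":
    m = Binomial(torch.tensor([10.0, 10.0]), probs=torch.tensor([0.3, 0.9]))
    draws = m.sample((5,))  # shape (5, 2); float draws in {0, ..., 10}
    log_p = m.log_prob(draws)  # elementwise log-probabilities, shape (5, 2)
    print("samples:", draws)
    print("log_prob:", log_p)
    print("mean/variance:", m.mean, m.variance)
    print("entropy:", m.entropy())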