# pytorch
# 165 lines · 5.8 KB
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)

# Public API of this module: only the Binomial distribution is exported.
__all__ = ["Binomial"]


def _clamp_by_zero(x):
    """Elementwise ``max(x, 0)`` whose gradient at exactly ``x == 0`` is 0.5.

    Forward values match ``x.clamp(min=0)``. The result is assembled from two
    one-sided clamps so that the gradient at zero is the average of the left
    (0) and right (1) derivatives.

    Args:
        x (Tensor): input tensor.

    Returns:
        Tensor: same shape as ``x``, with negative entries replaced by 0.
    """
    upper = x.clamp(min=0)  # x where x > 0, else 0
    lower = x.clamp(max=0)  # x where x < 0, else 0
    # upper + x - lower equals 2 * max(x, 0). At x == 0 each clamp passes a
    # gradient of 1, so the combined gradient is (1 + 1 - 1) / 2 = 0.5.
    return (upper + x - lower) / 2


class Binomial(Distribution):
    r"""
    Binomial distribution: the number of successes in :attr:`total_count`
    independent Bernoulli trials.

    Exactly one of :attr:`probs` or :attr:`logits` must be supplied.
    :attr:`total_count` must be broadcastable with that parameter. After
    broadcasting, :attr:`total_count` is cast to the parameter's dtype.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Binomial(100, torch.tensor([0, 0.2, 0.8, 1]))
        >>> m.sample()
        tensor([   0.,   22.,   71.,  100.])

        >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
        >>> m.sample()
        tensor([[ 4.,  5.],
                [ 7.,  6.]])

    Args:
        total_count (int or Tensor): number of Bernoulli trials
        probs (Tensor): per-trial success probabilities
        logits (Tensor): per-trial success log-odds
    """

    arg_constraints = {
        "total_count": constraints.nonnegative_integer,
        "probs": constraints.unit_interval,
        "logits": constraints.real,
    }
    has_enumerate_support = True

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        # Exactly one parameterization is accepted.
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        given_probs = probs is not None
        count, param = broadcast_all(total_count, probs if given_probs else logits)
        # Store the supplied parameter under its own name. The other one is
        # derived on first access by the matching lazy property below.
        if given_probs:
            self.probs = param
        else:
            self.logits = param
        self.total_count = count.type_as(param)
        self._param = param
        super().__init__(param.size(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution expanded to ``batch_shape``.

        Every parameter already materialised on ``self`` (``probs`` and/or
        ``logits``) is expanded. ``_param`` ends up pointing at the last one
        handled; ``logits`` wins when both are present.
        """
        new = self._get_checked_instance(Binomial, _instance)
        batch_shape = torch.Size(batch_shape)
        new.total_count = self.total_count.expand(batch_shape)
        for name in ("probs", "logits"):
            if name in self.__dict__:
                expanded = getattr(self, name).expand(batch_shape)
                setattr(new, name, expanded)
                new._param = expanded
        super(Binomial, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Build a tensor with the dtype/device of the stored parameter."""
        template = self._param
        return template.new(*args, **kwargs)

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        """Integers from 0 to ``total_count`` inclusive."""
        return constraints.integer_interval(lower_bound=0, upper_bound=self.total_count)

    @property
    def mean(self):
        """``n * p``."""
        n, p = self.total_count, self.probs
        return n * p

    @property
    def mode(self):
        """``floor((n + 1) * p)``, clamped to at most ``n``."""
        n = self.total_count
        return torch.clamp(torch.floor((n + 1) * self.probs), max=n)

    @property
    def variance(self):
        """``n * p * (1 - p)``."""
        n, p = self.total_count, self.probs
        return n * p * (1 - p)

    @lazy_property
    def logits(self):
        # Derived from probs when the distribution was built from probs.
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        # Derived from logits when the distribution was built from logits.
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        """Shape of the stored parameter (``probs`` or ``logits``)."""
        return self._param.shape

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape``.

        The result is produced under ``torch.no_grad()``, so it carries no
        autograd history.
        """
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            trials = self.total_count.expand(shape)
            success_prob = self.probs.expand(shape)
            return torch.binomial(trials, success_prob)

    def log_prob(self, value):
        """Log probability mass of ``value`` successes."""
        if self._validate_args:
            self._validate_sample(value)
        n = self.total_count
        logits = self.logits
        log_fact_n = torch.lgamma(n + 1)
        log_fact_k = torch.lgamma(value + 1)
        log_fact_n_minus_k = torch.lgamma(n - value + 1)
        # With l = logit(p):
        #   k*log(p) + (n-k)*log(1-p) = k*l + n*log(1-p) = k*l - n*log1p(exp(l)),
        # and log1p(exp(l)) = max(l, 0) + log1p(exp(-|l|)), which never
        # exponentiates a large positive number.
        # _clamp_by_zero is used instead of a plain clamp because abs() has a
        # zero gradient at l == 0. max(l, 0) must therefore supply the full
        # softplus slope of 0.5 there.
        log_norm = (
            n * _clamp_by_zero(logits)
            + n * logits.abs().neg().exp().log1p()
            - log_fact_n
        )
        return value * logits - log_fact_k - log_fact_n_minus_k - log_norm

    def entropy(self):
        """Entropy computed by summing over the full support.

        Raises:
            NotImplementedError: if ``total_count`` is not the same across the
                batch.
        """
        n_max = int(self.total_count.max())
        if self.total_count.min() != n_max:
            raise NotImplementedError(
                "Inhomogeneous total count not supported by `entropy`."
            )
        log_pmf = self.log_prob(self.enumerate_support(False))
        return -(log_pmf.exp() * log_pmf).sum(0)

    def enumerate_support(self, expand=True):
        """Return ``0, 1, ..., total_count`` along a new leading dimension.

        Without ``expand`` the batch dimensions are kept as singletons.

        Raises:
            NotImplementedError: if ``total_count`` is not the same across the
                batch.
        """
        n_max = int(self.total_count.max())
        if self.total_count.min() != n_max:
            raise NotImplementedError(
                "Inhomogeneous total count not supported by `enumerate_support`."
            )
        ref = self._param
        batch_rank = len(self._batch_shape)
        values = torch.arange(n_max + 1, dtype=ref.dtype, device=ref.device)
        values = values.view(-1, *(1,) * batch_rank)
        return values.expand(-1, *self._batch_shape) if expand else values