pytorch
132 строки · 4.1 Кб
1# mypy: allow-untyped-defs
2from numbers import Number3
4import torch5from torch import nan6from torch.distributions import constraints7from torch.distributions.exp_family import ExponentialFamily8from torch.distributions.utils import (9broadcast_all,10lazy_property,11logits_to_probs,12probs_to_logits,13)
14from torch.nn.functional import binary_cross_entropy_with_logits15
16
17__all__ = ["Bernoulli"]18
19
class Bernoulli(ExponentialFamily):
    r"""
    Creates a Bernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both).

    Samples are binary (0 or 1). They take the value `1` with probability `p`
    and `0` with probability `1 - p`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Bernoulli(torch.tensor([0.3]))
        >>> m.sample()  # 30% chance 1; 70% chance 0
        tensor([ 0.])

    Args:
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`
        validate_args (bool, optional): forwarded unchanged to the base
            class constructor.
    """

    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.boolean
    has_enumerate_support = True
    _mean_carrier_measure = 0

    def __init__(self, probs=None, logits=None, validate_args=None):
        """Store exactly one of ``probs`` / ``logits`` and derive the batch shape.

        Raises:
            ValueError: if both or neither of ``probs`` and ``logits`` are given.
        """
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            (self.probs,) = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            (self.logits,) = broadcast_all(logits)
        # ``_param`` always refers to whichever parameterization the caller
        # supplied; the other one is derived lazily on first access.
        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            # A plain Python number yields an empty batch shape.
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new distribution whose parameters are expanded to ``batch_shape``.

        Only parameterizations already present in the instance ``__dict__``
        (supplied at construction or already materialized lazily) are expanded.
        """
        new = self._get_checked_instance(Bernoulli, _instance)
        batch_shape = torch.Size(batch_shape)
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        # Initialise the base class without validation, then carry over the
        # validation flag of the source instance.
        super(Bernoulli, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Create a tensor via ``self._param.new`` with the given arguments."""
        return self._param.new(*args, **kwargs)

    @property
    def mean(self):
        """Mean of the distribution, equal to ``probs``."""
        return self.probs

    @property
    def mode(self):
        """Mode of the distribution: 1 where ``probs > 0.5``, 0 where
        ``probs < 0.5`` and ``nan`` where ``probs == 0.5`` (no unique mode)."""
        mode = (self.probs >= 0.5).to(self.probs)
        mode[self.probs == 0.5] = nan
        return mode

    @property
    def variance(self):
        """Variance of the distribution, ``probs * (1 - probs)``."""
        return self.probs * (1 - self.probs)

    @lazy_property
    def logits(self):
        """Log-odds derived from ``probs`` when only ``probs`` was supplied."""
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        """Probabilities derived from ``logits`` when only ``logits`` was supplied."""
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        """Shape of the parameter tensor the distribution was built from."""
        return self._param.size()

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape`` without tracking gradients."""
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.bernoulli(self.probs.expand(shape))

    def log_prob(self, value):
        """Return the log-probability of ``value``.

        Computed as the negated element-wise binary cross entropy between the
        logits and ``value`` after broadcasting them against each other.
        """
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        return -binary_cross_entropy_with_logits(logits, value, reduction="none")

    def entropy(self):
        """Return the entropy, computed as the element-wise binary cross
        entropy of the logits against the distribution's own ``probs``."""
        return binary_cross_entropy_with_logits(
            self.logits, self.probs, reduction="none"
        )

    def enumerate_support(self, expand=True):
        """Return the support values ``0`` and ``1`` along a new leading dimension.

        The result has shape ``(2, 1, ..., 1)`` with one singleton dimension
        per batch dimension; when ``expand`` is true it is expanded to
        ``(2,) + batch_shape``. Dtype and device follow the parameter tensor.
        """
        values = torch.arange(2, dtype=self._param.dtype, device=self._param.device)
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        if expand:
            values = values.expand((-1,) + self._batch_shape)
        return values

    @property
    def _natural_params(self):
        """Natural parameters as a 1-tuple containing ``logit(probs)``."""
        return (torch.logit(self.probs),)

    def _log_normalizer(self, x):
        """Return the log normalizer ``log(1 + exp(x))`` for natural parameter ``x``.

        Evaluated as ``logaddexp(x, 0)``, which is the same quantity but does
        not overflow: the naive ``log1p(exp(x))`` yields ``inf`` (and a ``nan``
        gradient) once ``exp(x)`` exceeds the dtype's range.
        """
        return torch.logaddexp(x, torch.zeros_like(x))