# pytorch
# 79 lines · 2.2 KB
# mypy: allow-untyped-defs
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all

__all__ = ["Poisson"]


class Poisson(ExponentialFamily):
    r"""
    Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.

    Samples are nonnegative integers, with a pmf given by

    .. math::
      \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}

    Example::

        >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
        >>> m = Poisson(torch.tensor([4]))
        >>> m.sample()
        tensor([ 3.])

    Args:
        rate (Number, Tensor): the rate parameter
    """

    arg_constraints = {"rate": constraints.nonnegative}
    support = constraints.nonnegative_integer

    @property
    def mean(self):
        """Mean of the distribution; for a Poisson it equals ``rate``."""
        return self.rate

    @property
    def mode(self):
        """Mode of the distribution, ``floor(rate)``.

        For a positive integer ``rate``, ``rate - 1`` is an equally likely mode;
        this property returns the larger one.
        """
        return self.rate.floor()

    @property
    def variance(self):
        """Variance of the distribution; for a Poisson it equals ``rate``."""
        return self.rate

    def __init__(self, rate, validate_args=None):
        (self.rate,) = broadcast_all(rate)
        # A plain Python number gives a scalar (empty) batch shape; a tensor
        # contributes its own shape as the batch shape.
        if isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.rate.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Poisson`` whose ``rate`` is expanded to ``batch_shape``."""
        new = self._get_checked_instance(Poisson, _instance)
        batch_shape = torch.Size(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        # Skip re-validating the (already validated) rate, then carry over the
        # original instance's validation setting.
        super(Poisson, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape`` via ``torch.poisson``.

        Sampling runs under ``torch.no_grad()``, so the result carries no
        gradient history.
        """
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.poisson(self.rate.expand(shape))

    def log_prob(self, value):
        """Log-pmf: ``value * log(rate) - rate - log(value!)``."""
        if self._validate_args:
            self._validate_sample(value)
        rate, value = broadcast_all(self.rate, value)
        # xlogy yields 0 where value == 0, so rate == 0 stays finite;
        # lgamma(value + 1) == log(value!).
        return value.xlogy(rate) - rate - (value + 1).lgamma()

    def cdf(self, value):
        r"""
        Returns the cumulative distribution function :math:`P(X \le \mathrm{value})`.

        Uses the identity
        :math:`P(X \le k) = Q(\lfloor k \rfloor + 1, \mathrm{rate})`, where
        :math:`Q` is the regularized upper incomplete gamma function
        (:func:`torch.special.gammaincc`). Values below zero have probability 0.

        Args:
            value (Tensor): points at which to evaluate the CDF.
        """
        if self._validate_args:
            self._validate_sample(value)
        rate, value = broadcast_all(self.rate, value)
        # Clamp so gammaincc always receives a shape parameter >= 1 (it is
        # NaN/undefined for shape <= 0); negative values are masked to 0 below.
        k = value.floor().clamp(min=0)
        return torch.special.gammaincc(k + 1, rate).masked_fill(value < 0, 0.0)

    @property
    def _natural_params(self):
        # Natural parameter of the exponential-family form: eta = log(rate).
        return (torch.log(self.rate),)

    def _log_normalizer(self, x):
        # Log-normalizer A(eta) = exp(eta), i.e. rate expressed via eta.
        return torch.exp(x)