pytorch
87 строк · 2.5 Кб
1# mypy: allow-untyped-defs
2from numbers import Number3
4import torch5from torch.distributions import constraints6from torch.distributions.exp_family import ExponentialFamily7from torch.distributions.utils import broadcast_all8from torch.types import _size9
10
# Names exported by ``from <this module> import *``.
__all__ = ["Exponential"]
13
class Exponential(ExponentialFamily):
    r"""
    Creates an Exponential distribution parameterized by :attr:`rate`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Exponential(torch.tensor([1.0]))
        >>> m.sample()  # Exponential distributed with rate=1
        tensor([ 0.1046])

    Args:
        rate (float or Tensor): rate = 1 / scale of the distribution
    """

    arg_constraints = {"rate": constraints.positive}
    support = constraints.nonnegative
    has_rsample = True
    _mean_carrier_measure = 0

    @property
    def mean(self):
        """Mean of the distribution, ``1 / rate``."""
        return self.rate.reciprocal()

    @property
    def mode(self):
        """Mode of the distribution: zeros with the same shape as ``rate``."""
        return torch.zeros_like(self.rate)

    @property
    def stddev(self):
        """Standard deviation of the distribution, ``1 / rate``."""
        return self.rate.reciprocal()

    @property
    def variance(self):
        """Variance of the distribution, ``rate ** -2``."""
        return self.rate.pow(-2)

    def __init__(self, rate, validate_args=None):
        """Store ``rate`` as a tensor and derive the batch shape from it.

        Args:
            rate: Python number or tensor; passed through ``broadcast_all``.
            validate_args: forwarded unchanged to the parent constructor.
        """
        (self.rate,) = broadcast_all(rate)
        # A plain Python number yields a scalar (empty) batch shape.
        batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a distribution whose ``rate`` is expanded to ``batch_shape``.

        Args:
            batch_shape: target batch shape (anything ``torch.Size`` accepts).
            _instance: optional pre-built instance to fill in, handed to
                ``_get_checked_instance``.
        """
        new = self._get_checked_instance(Exponential, _instance)
        batch_shape = torch.Size(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        # Initialise the parent with validation off, then copy this
        # instance's validation flag onto the new object.
        super(Exponential, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw a reparameterized sample of shape ``_extended_shape(sample_shape)``.

        A standard exponential draw (filled in place by ``exponential_()``)
        is divided by ``rate``.
        """
        shape = self._extended_shape(sample_shape)
        # new_empty keeps the dtype and device of ``rate``.
        return self.rate.new_empty(shape).exponential_() / self.rate

    def log_prob(self, value):
        """Return the log density at ``value``: ``log(rate) - rate * value``."""
        if self._validate_args:
            self._validate_sample(value)
        return self.rate.log() - self.rate * value

    def cdf(self, value):
        """Return the cumulative distribution function, ``1 - exp(-rate * value)``."""
        if self._validate_args:
            self._validate_sample(value)
        # expm1 avoids the catastrophic cancellation of ``1 - exp(-x)`` for
        # small ``x`` (mirrors the log1p used in ``icdf``).
        return -torch.expm1(-self.rate * value)

    def icdf(self, value):
        """Return the inverse CDF, ``-log1p(-value) / rate``."""
        return -torch.log1p(-value) / self.rate

    def entropy(self):
        """Return the entropy, ``1 - log(rate)``."""
        return 1.0 - torch.log(self.rate)

    @property
    def _natural_params(self):
        """Natural parameter tuple ``(-rate,)``."""
        return (-self.rate,)

    def _log_normalizer(self, x):
        """Log normalizer as a function of the natural parameter: ``-log(-x)``."""
        return -torch.log(-x)