# pytorch
# 110 lines · 3.7 KB
1# mypy: allow-untyped-defs
2from numbers import Number, Real3
4import torch5from torch.distributions import constraints6from torch.distributions.dirichlet import Dirichlet7from torch.distributions.exp_family import ExponentialFamily8from torch.distributions.utils import broadcast_all9from torch.types import _size10
11
12__all__ = ["Beta"]13
14
class Beta(ExponentialFamily):
    r"""
    Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.

    Internally the distribution wraps a two-component
    :class:`~torch.distributions.dirichlet.Dirichlet` whose last dimension
    stacks ``(concentration1, concentration0)``; sampling, ``log_prob``,
    ``entropy`` and ``mode`` all delegate to that object.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
        tensor([ 0.1046])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
        validate_args (bool, optional): forwarded unchanged to the wrapped
            Dirichlet and to the base-class constructor
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        if isinstance(concentration1, Real) and isinstance(concentration0, Real):
            # Two plain Python reals: pack them straight into a length-2 tensor.
            concentration1_concentration0 = torch.tensor(
                [float(concentration1), float(concentration0)]
            )
        else:
            # At least one tensor argument: broadcast both to a common shape,
            # then stack along a new trailing dimension of size 2.
            concentration1, concentration0 = broadcast_all(
                concentration1, concentration0
            )
            concentration1_concentration0 = torch.stack(
                [concentration1, concentration0], -1
            )
        # The Dirichlet must exist before the base-class constructor runs,
        # because the ``concentration1`` / ``concentration0`` properties read
        # from it.
        self._dirichlet = Dirichlet(
            concentration1_concentration0, validate_args=validate_args
        )
        super().__init__(self._dirichlet._batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new distribution whose wrapped Dirichlet is expanded to ``batch_shape``.

        Args:
            batch_shape: the desired batch shape (anything ``torch.Size`` accepts).
            _instance: optional pre-allocated instance to populate; passed to
                ``_get_checked_instance``.

        Returns:
            The populated ``Beta`` instance.
        """
        new = self._get_checked_instance(Beta, _instance)
        batch_shape = torch.Size(batch_shape)
        new._dirichlet = self._dirichlet.expand(batch_shape)
        # Initialise the base class with validation off, then copy over the
        # validation flag of the distribution being expanded.
        super(Beta, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """Mean: ``concentration1 / (concentration1 + concentration0)``."""
        return self.concentration1 / (self.concentration1 + self.concentration0)

    @property
    def mode(self):
        """Mode, taken as component 0 of the wrapped Dirichlet's mode."""
        return self._dirichlet.mode[..., 0]

    @property
    def variance(self):
        """Variance: ``c1 * c0 / ((c1 + c0)**2 * (c1 + c0 + 1))``."""
        total = self.concentration1 + self.concentration0
        return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
        """Draw a reparameterized sample.

        Samples from the wrapped Dirichlet and keeps component 0 of the last
        dimension.
        """
        return self._dirichlet.rsample(sample_shape).select(-1, 0)

    def log_prob(self, value):
        """Return the log-density at ``value``.

        ``value`` is checked with ``_validate_sample`` when argument
        validation is enabled, then evaluated by the wrapped Dirichlet as the
        pair ``(value, 1 - value)``.
        """
        if self._validate_args:
            self._validate_sample(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        """Return the entropy computed by the wrapped Dirichlet."""
        return self._dirichlet.entropy()

    @property
    def concentration1(self):
        """First concentration parameter: component 0 of the Dirichlet concentration."""
        # The Dirichlet is always built from a tensor in ``__init__`` /
        # ``expand``, so indexing it yields a tensor; no scalar wrapping is
        # needed.
        return self._dirichlet.concentration[..., 0]

    @property
    def concentration0(self):
        """Second concentration parameter: component 1 of the Dirichlet concentration."""
        # See ``concentration1``: the indexed result is always a tensor.
        return self._dirichlet.concentration[..., 1]

    @property
    def _natural_params(self):
        """Return the tuple ``(concentration1, concentration0)``."""
        return (self.concentration1, self.concentration0)

    def _log_normalizer(self, x, y):
        """Return ``lgamma(x) + lgamma(y) - lgamma(x + y)``, i.e. ``log B(x, y)``."""
        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)