pytorch
107 строк · 3.6 Кб
1from numbers import Number, Real
2
3import torch
4from torch.distributions import constraints
5from torch.distributions.dirichlet import Dirichlet
6from torch.distributions.exp_family import ExponentialFamily
7from torch.distributions.utils import broadcast_all
8
# Names exported by ``from <this module> import *``; only the Beta class is public.
__all__ = ["Beta"]
10
11
class Beta(ExponentialFamily):
    r"""
    Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.

    Internally the distribution is represented as a two-component
    :class:`~torch.distributions.dirichlet.Dirichlet` whose concentration is
    ``[concentration1, concentration0]`` stacked along the last dimension.
    A Beta sample is the first coordinate of a Dirichlet sample. Sampling,
    ``log_prob``, ``entropy`` and ``mode`` are delegated to that Dirichlet.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
        tensor([ 0.1046])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        if isinstance(concentration1, Real) and isinstance(concentration0, Real):
            # Both parameters are plain Python scalars: pack them into a single
            # length-2 tensor (default dtype) for the Dirichlet.
            concentration1_concentration0 = torch.tensor(
                [float(concentration1), float(concentration0)]
            )
        else:
            # At least one parameter is a tensor: broadcast both to a common
            # shape, then place them side by side in a new trailing dimension.
            concentration1, concentration0 = broadcast_all(
                concentration1, concentration0
            )
            concentration1_concentration0 = torch.stack(
                [concentration1, concentration0], -1
            )
        self._dirichlet = Dirichlet(
            concentration1_concentration0, validate_args=validate_args
        )
        # The Beta's batch shape is whatever batch shape the Dirichlet derived.
        super().__init__(self._dirichlet._batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a view of this distribution expanded to ``batch_shape``."""
        new = self._get_checked_instance(Beta, _instance)
        batch_shape = torch.Size(batch_shape)
        new._dirichlet = self._dirichlet.expand(batch_shape)
        # Parameters were already validated on ``self``; skip re-validation in
        # the base initializer and copy the original flag afterwards.
        super(Beta, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """Mean ``alpha / (alpha + beta)``."""
        return self.concentration1 / (self.concentration1 + self.concentration0)

    @property
    def mode(self):
        """Mode, taken as the first coordinate of the Dirichlet's mode."""
        return self._dirichlet.mode[..., 0]

    @property
    def variance(self):
        """Variance ``alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))``."""
        total = self.concentration1 + self.concentration0
        return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))

    def rsample(self, sample_shape=()):
        """Draw a reparameterized sample of shape ``sample_shape + batch_shape``.

        The Dirichlet sample carries a trailing size-2 dimension. The first
        coordinate along it is the Beta-distributed value.
        """
        return self._dirichlet.rsample(sample_shape).select(-1, 0)

    def log_prob(self, value):
        """Log-density of ``value`` (expected to lie in the unit interval).

        ``value`` is mapped to the 2-simplex point ``(value, 1 - value)`` and
        scored under the underlying Dirichlet.
        """
        if self._validate_args:
            self._validate_sample(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        """Entropy, delegated to the underlying Dirichlet."""
        return self._dirichlet.entropy()

    @property
    def concentration1(self):
        """First concentration parameter (alpha), as a tensor of ``batch_shape``."""
        # Indexing a tensor always yields a tensor (0-dim for a scalar Beta),
        # so no special case for Python numbers is needed here.
        return self._dirichlet.concentration[..., 0]

    @property
    def concentration0(self):
        """Second concentration parameter (beta), as a tensor of ``batch_shape``."""
        return self._dirichlet.concentration[..., 1]

    @property
    def _natural_params(self):
        return (self.concentration1, self.concentration0)

    def _log_normalizer(self, x, y):
        # log B(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)
        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
108