# pytorch
# (original listing: 33 lines · 974.0 bytes)
from typing import Optional, Union

from torch import Tensor
from torch.distributions import constraints
from torch.distributions.gamma import Gamma

# Public API of this module: only the Chi2 distribution is exported.
__all__ = ["Chi2"]


class Chi2(Gamma):
    r"""
    Creates a Chi-squared distribution parameterized by its degrees of
    freedom :attr:`df` (the shape parameter).

    This is exactly equivalent to ``Gamma(concentration=0.5 * df, rate=0.5)``.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Chi2(torch.tensor([1.0]))
        >>> m.sample()  # Chi2 distributed with shape df=1
        tensor([ 0.1046])

    Args:
        df (float or Tensor): degrees of freedom (shape parameter) of the
            distribution; must be positive.
        validate_args (bool, optional): forwarded to :class:`Gamma` to enable
            or disable argument validation.
    """

    # Validation is expressed in terms of ``df`` (read through the ``df``
    # property below) rather than Gamma's ``concentration``/``rate``.
    arg_constraints = {"df": constraints.positive}

    def __init__(
        self,
        df: Union[Tensor, float],
        validate_args: Optional[bool] = None,
    ) -> None:
        # Chi2(df) is Gamma with concentration df / 2 and a fixed rate of 1 / 2.
        # Only the Gamma parameters are stored; ``df`` is derived on demand.
        super().__init__(0.5 * df, 0.5, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Return a new :class:`Chi2` with batch dimensions expanded to
        ``batch_shape``.

        A subclass with a custom ``__init__`` must provide its own ``expand``.
        This override obtains the ``Chi2`` instance to populate and lets
        ``Gamma.expand`` fill in the expanded parameters, so the result is
        still a ``Chi2``.
        """
        new = self._get_checked_instance(Chi2, _instance)
        return super().expand(batch_shape, new)

    @property
    def df(self) -> Tensor:
        """Degrees of freedom, recovered as ``2 * concentration``."""
        return self.concentration * 2