pytorch

Форк
0
/
inverse_gamma.py 
80 строк · 2.4 Кб
1
import torch
2
from torch.distributions import constraints
3
from torch.distributions.gamma import Gamma
4
from torch.distributions.transformed_distribution import TransformedDistribution
5
from torch.distributions.transforms import PowerTransform
6

7

8
__all__ = ["InverseGamma"]
9

10

11
class InverseGamma(TransformedDistribution):
12
    r"""
13
    Creates an inverse gamma distribution parameterized by :attr:`concentration` and :attr:`rate`
14
    where::
15

16
        X ~ Gamma(concentration, rate)
17
        Y = 1 / X ~ InverseGamma(concentration, rate)
18

19
    Example::
20

21
        >>> # xdoctest: +IGNORE_WANT("non-deterinistic")
22
        >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))
23
        >>> m.sample()
24
        tensor([ 1.2953])
25

26
    Args:
27
        concentration (float or Tensor): shape parameter of the distribution
28
            (often referred to as alpha)
29
        rate (float or Tensor): rate = 1 / scale of the distribution
30
            (often referred to as beta)
31
    """
32
    arg_constraints = {
33
        "concentration": constraints.positive,
34
        "rate": constraints.positive,
35
    }
36
    support = constraints.positive
37
    has_rsample = True
38

39
    def __init__(self, concentration, rate, validate_args=None):
40
        base_dist = Gamma(concentration, rate, validate_args=validate_args)
41
        neg_one = -base_dist.rate.new_ones(())
42
        super().__init__(
43
            base_dist, PowerTransform(neg_one), validate_args=validate_args
44
        )
45

46
    def expand(self, batch_shape, _instance=None):
47
        new = self._get_checked_instance(InverseGamma, _instance)
48
        return super().expand(batch_shape, _instance=new)
49

50
    @property
51
    def concentration(self):
52
        return self.base_dist.concentration
53

54
    @property
55
    def rate(self):
56
        return self.base_dist.rate
57

58
    @property
59
    def mean(self):
60
        result = self.rate / (self.concentration - 1)
61
        return torch.where(self.concentration > 1, result, torch.inf)
62

63
    @property
64
    def mode(self):
65
        return self.rate / (self.concentration + 1)
66

67
    @property
68
    def variance(self):
69
        result = self.rate.square() / (
70
            (self.concentration - 1).square() * (self.concentration - 2)
71
        )
72
        return torch.where(self.concentration > 2, result, torch.inf)
73

74
    def entropy(self):
75
        return (
76
            self.concentration
77
            + self.rate.log()
78
            + self.concentration.lgamma()
79
            - (1 + self.concentration) * self.concentration.digamma()
80
        )
81

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.