pytorch

Форк
0
/
fishersnedecor.py 
101 строка · 3.4 Кб
1
# mypy: allow-untyped-defs
2
from numbers import Number
3

4
import torch
5
from torch import nan
6
from torch.distributions import constraints
7
from torch.distributions.distribution import Distribution
8
from torch.distributions.gamma import Gamma
9
from torch.distributions.utils import broadcast_all
10
from torch.types import _size
11

12

13
__all__ = ["FisherSnedecor"]
14

15

16
class FisherSnedecor(Distribution):
17
    r"""
18
    Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`.
19

20
    Example::
21

22
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
23
        >>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
24
        >>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
25
        tensor([ 0.2453])
26

27
    Args:
28
        df1 (float or Tensor): degrees of freedom parameter 1
29
        df2 (float or Tensor): degrees of freedom parameter 2
30
    """
31
    arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
32
    support = constraints.positive
33
    has_rsample = True
34

35
    def __init__(self, df1, df2, validate_args=None):
36
        self.df1, self.df2 = broadcast_all(df1, df2)
37
        self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
38
        self._gamma2 = Gamma(self.df2 * 0.5, self.df2)
39

40
        if isinstance(df1, Number) and isinstance(df2, Number):
41
            batch_shape = torch.Size()
42
        else:
43
            batch_shape = self.df1.size()
44
        super().__init__(batch_shape, validate_args=validate_args)
45

46
    def expand(self, batch_shape, _instance=None):
47
        new = self._get_checked_instance(FisherSnedecor, _instance)
48
        batch_shape = torch.Size(batch_shape)
49
        new.df1 = self.df1.expand(batch_shape)
50
        new.df2 = self.df2.expand(batch_shape)
51
        new._gamma1 = self._gamma1.expand(batch_shape)
52
        new._gamma2 = self._gamma2.expand(batch_shape)
53
        super(FisherSnedecor, new).__init__(batch_shape, validate_args=False)
54
        new._validate_args = self._validate_args
55
        return new
56

57
    @property
58
    def mean(self):
59
        df2 = self.df2.clone(memory_format=torch.contiguous_format)
60
        df2[df2 <= 2] = nan
61
        return df2 / (df2 - 2)
62

63
    @property
64
    def mode(self):
65
        mode = (self.df1 - 2) / self.df1 * self.df2 / (self.df2 + 2)
66
        mode[self.df1 <= 2] = nan
67
        return mode
68

69
    @property
70
    def variance(self):
71
        df2 = self.df2.clone(memory_format=torch.contiguous_format)
72
        df2[df2 <= 4] = nan
73
        return (
74
            2
75
            * df2.pow(2)
76
            * (self.df1 + df2 - 2)
77
            / (self.df1 * (df2 - 2).pow(2) * (df2 - 4))
78
        )
79

80
    def rsample(self, sample_shape: _size = torch.Size(())) -> torch.Tensor:
81
        shape = self._extended_shape(sample_shape)
82
        #   X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2)
83
        #   Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2)
84
        X1 = self._gamma1.rsample(sample_shape).view(shape)
85
        X2 = self._gamma2.rsample(sample_shape).view(shape)
86
        tiny = torch.finfo(X2.dtype).tiny
87
        X2.clamp_(min=tiny)
88
        Y = X1 / X2
89
        Y.clamp_(min=tiny)
90
        return Y
91

92
    def log_prob(self, value):
93
        if self._validate_args:
94
            self._validate_sample(value)
95
        ct1 = self.df1 * 0.5
96
        ct2 = self.df2 * 0.5
97
        ct3 = self.df1 / self.df2
98
        t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma()
99
        t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value)
100
        t3 = (ct1 + ct2) * torch.log1p(ct3 * value)
101
        return t1 + t2 - t3
102

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.