pytorch

Форк
0
77 строк · 2.2 Кб
1
from numbers import Number
2

3
import torch
4
from torch.distributions import constraints
5
from torch.distributions.exp_family import ExponentialFamily
6
from torch.distributions.utils import broadcast_all
7

8
__all__ = ["Poisson"]
9

10

11
class Poisson(ExponentialFamily):
12
    r"""
13
    Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
14

15
    Samples are nonnegative integers, with a pmf given by
16

17
    .. math::
18
      \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
19

20
    Example::
21

22
        >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
23
        >>> m = Poisson(torch.tensor([4]))
24
        >>> m.sample()
25
        tensor([ 3.])
26

27
    Args:
28
        rate (Number, Tensor): the rate parameter
29
    """
30
    arg_constraints = {"rate": constraints.nonnegative}
31
    support = constraints.nonnegative_integer
32

33
    @property
34
    def mean(self):
35
        return self.rate
36

37
    @property
38
    def mode(self):
39
        return self.rate.floor()
40

41
    @property
42
    def variance(self):
43
        return self.rate
44

45
    def __init__(self, rate, validate_args=None):
46
        (self.rate,) = broadcast_all(rate)
47
        if isinstance(rate, Number):
48
            batch_shape = torch.Size()
49
        else:
50
            batch_shape = self.rate.size()
51
        super().__init__(batch_shape, validate_args=validate_args)
52

53
    def expand(self, batch_shape, _instance=None):
54
        new = self._get_checked_instance(Poisson, _instance)
55
        batch_shape = torch.Size(batch_shape)
56
        new.rate = self.rate.expand(batch_shape)
57
        super(Poisson, new).__init__(batch_shape, validate_args=False)
58
        new._validate_args = self._validate_args
59
        return new
60

61
    def sample(self, sample_shape=torch.Size()):
62
        shape = self._extended_shape(sample_shape)
63
        with torch.no_grad():
64
            return torch.poisson(self.rate.expand(shape))
65

66
    def log_prob(self, value):
67
        if self._validate_args:
68
            self._validate_sample(value)
69
        rate, value = broadcast_all(self.rate, value)
70
        return value.xlogy(rate) - rate - (value + 1).lgamma()
71

72
    @property
73
    def _natural_params(self):
74
        return (torch.log(self.rate),)
75

76
    def _log_normalizer(self, x):
77
        return torch.exp(x)
78

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.