pytorch

Форк
0
/
exponential.py 
84 строки · 2.4 Кб
1
from numbers import Number
2

3
import torch
4
from torch.distributions import constraints
5
from torch.distributions.exp_family import ExponentialFamily
6
from torch.distributions.utils import broadcast_all
7

8
__all__ = ["Exponential"]
9

10

11
class Exponential(ExponentialFamily):
12
    r"""
13
    Creates a Exponential distribution parameterized by :attr:`rate`.
14

15
    Example::
16

17
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
18
        >>> m = Exponential(torch.tensor([1.0]))
19
        >>> m.sample()  # Exponential distributed with rate=1
20
        tensor([ 0.1046])
21

22
    Args:
23
        rate (float or Tensor): rate = 1 / scale of the distribution
24
    """
25
    arg_constraints = {"rate": constraints.positive}
26
    support = constraints.nonnegative
27
    has_rsample = True
28
    _mean_carrier_measure = 0
29

30
    @property
31
    def mean(self):
32
        return self.rate.reciprocal()
33

34
    @property
35
    def mode(self):
36
        return torch.zeros_like(self.rate)
37

38
    @property
39
    def stddev(self):
40
        return self.rate.reciprocal()
41

42
    @property
43
    def variance(self):
44
        return self.rate.pow(-2)
45

46
    def __init__(self, rate, validate_args=None):
47
        (self.rate,) = broadcast_all(rate)
48
        batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
49
        super().__init__(batch_shape, validate_args=validate_args)
50

51
    def expand(self, batch_shape, _instance=None):
52
        new = self._get_checked_instance(Exponential, _instance)
53
        batch_shape = torch.Size(batch_shape)
54
        new.rate = self.rate.expand(batch_shape)
55
        super(Exponential, new).__init__(batch_shape, validate_args=False)
56
        new._validate_args = self._validate_args
57
        return new
58

59
    def rsample(self, sample_shape=torch.Size()):
60
        shape = self._extended_shape(sample_shape)
61
        return self.rate.new(shape).exponential_() / self.rate
62

63
    def log_prob(self, value):
64
        if self._validate_args:
65
            self._validate_sample(value)
66
        return self.rate.log() - self.rate * value
67

68
    def cdf(self, value):
69
        if self._validate_args:
70
            self._validate_sample(value)
71
        return 1 - torch.exp(-self.rate * value)
72

73
    def icdf(self, value):
74
        return -torch.log1p(-value) / self.rate
75

76
    def entropy(self):
77
        return 1.0 - torch.log(self.rate)
78

79
    @property
80
    def _natural_params(self):
81
        return (-self.rate,)
82

83
    def _log_normalizer(self, x):
84
        return -torch.log(-x)
85

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.