# Page header captured from the hosting web page (not part of the module):
# pytorch / von_mises.py (forks: 0) — 209 lines, 5.9 KB
1
import math
2

3
import torch
4
import torch.jit
5
from torch.distributions import constraints
6
from torch.distributions.distribution import Distribution
7
from torch.distributions.utils import broadcast_all, lazy_property
8

9
# Public API of this module: only the distribution class is exported; the
# underscore-prefixed helpers defined in this module are private.
__all__ = ["VonMises"]
10

11

12
def _eval_poly(y, coef):
    """
    Evaluate ``coef[0] + coef[1] * y + coef[2] * y**2 + ...`` by Horner's rule.

    ``coef`` may be any iterable of coefficients (lowest degree first); it is
    copied, never modified. ``y`` may be a Python number or a tensor, in which
    case the result is a tensor. An empty ``coef`` raises ``IndexError``.
    """
    terms = list(coef)
    acc = terms[-1]
    # Fold from the highest-degree coefficient down to the constant term.
    for c in reversed(terms[:-1]):
        acc = c + y * acc
    return acc
18

19

20
# Polynomial coefficients (lowest degree first, as consumed by ``_eval_poly``)
# for the modified Bessel approximations used in ``_log_modified_bessel_fn``.
# They appear to be the Abramowitz & Stegun eqs. 9.8.1-9.8.4 coefficients.
#
# *_SMALL tables: used for x < 3.75, in the variable t = (x / 3.75) ** 2,
#     I_0(x)       ~ poly(t, _I0_COEF_SMALL)
#     I_1(x) / |x| ~ poly(t, _I1_COEF_SMALL)
# *_LARGE tables: used otherwise, in the variable t = 3.75 / x,
#     sqrt(x) * exp(-x) * I_0(x) ~ poly(t, _I0_COEF_LARGE)
#     sqrt(x) * exp(-x) * I_1(x) ~ poly(t, _I1_COEF_LARGE)
_I0_COEF_SMALL = [
    1.0, 3.5156229, 3.0899424, 1.2067492,
    0.2659732, 0.0360768, 0.0045813,
]
_I0_COEF_LARGE = [
    0.39894228, 0.01328592, 0.00225319, -0.00157565, 0.00916281,
    -0.02057706, 0.02635537, -0.01647633, 0.00392377,
]
_I1_COEF_SMALL = [
    0.5, 0.87890594, 0.51498869, 0.15084934,
    0.02658733, 0.00301532, 0.00032411,
]
_I1_COEF_LARGE = [
    0.39894228, -0.03988024, -0.00362018, 0.00163801, -0.01031555,
    0.02282967, -0.02895312, 0.01787654, -0.00420059,
]

# Lookup tables indexed by Bessel order (0 or 1).
_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
63

64

65
def _log_modified_bessel_fn(x, order=0):
    """
    Returns ``log(I_order(x))`` for ``x > 0``,
    where `order` is either 0 or 1.

    Two polynomial approximations are combined with ``torch.where``: one in
    ``(x / 3.75) ** 2`` for entries with ``x < 3.75`` and one in ``3.75 / x``
    for the remaining entries.

    Each branch is evaluated on a copy of ``x`` whose entries that the branch
    does NOT supply are replaced by a harmless value inside that branch's own
    range. Forward results are unchanged, but the discarded branch can no
    longer overflow to ``inf`` (e.g. the ``3.75 / x`` polynomial for a very
    small ``x``). Such an overflow would turn the zero gradient that
    ``torch.where`` routes to the unused branch into ``0 * inf = nan`` during
    backpropagation.

    Args:
        x (Tensor): argument, expected to be positive.
        order (int): Bessel order, 0 or 1.

    Returns:
        Tensor: ``log(I_order(x))`` with the same shape as ``x``.

    Raises:
        ValueError: if ``order`` is not 0 or 1.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if order not in (0, 1):
        raise ValueError(f"order must be 0 or 1, got {order}")

    use_small = x < 3.75
    # Placeholders lie inside each branch's valid range, keeping both finite.
    x_small = torch.where(use_small, x, torch.ones_like(x))
    x_large = torch.where(use_small, torch.full_like(x, 3.75), x)

    # compute small solution
    y = x_small / 3.75
    y = y * y
    small = _eval_poly(y, _COEF_SMALL[order])
    if order == 1:
        # The order-1 table approximates I_1(x) / |x|.
        small = x_small.abs() * small
    small = small.log()

    # compute large solution: log(I(x)) = x - log(x) / 2 + log(poly(3.75 / x))
    y = 3.75 / x_large
    large = x_large - 0.5 * x_large.log() + _eval_poly(y, _COEF_LARGE[order]).log()

    return torch.where(use_small, small, large)
86

87

88
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
    """
    Draw von Mises angles with the Best & Fisher (1979) rejection scheme.

    Each round draws three fresh uniforms per element of ``x`` and accepts or
    rejects the resulting proposal. Accepted elements take the new angle;
    an element accepted in an earlier round is overwritten whenever a later
    round accepts it again. Rounds repeat until every element has been
    accepted at least once. Finally the angles are shifted by ``loc`` and
    wrapped with period 2 pi so they start at -pi.

    ``x`` fixes the output shape and holds the initial values. ``loc``,
    ``concentration`` and ``proposal_r`` are combined with it elementwise
    and must broadcast against it.
    """
    accepted_once = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
    while not accepted_once.all():
        uniforms = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1 = uniforms[0]
        u2 = uniforms[1]
        u3 = uniforms[2]
        z = torch.cos(math.pi * u1)
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        # Cheap squeeze test first, then the exact logarithmic test.
        quick_accept = (c * (2 - c) - u2) > 0
        exact_accept = (c / u2).log() + 1 - c >= 0
        accept = quick_accept | exact_accept
        if accept.any():
            # A random sign picks the half-circle of the proposed angle.
            angle = (u3 - 0.5).sign() * f.acos()
            x = torch.where(accept, angle, x)
            accepted_once = accepted_once | accept
    return (x + math.pi + loc) % (2 * math.pi) - math.pi
102

103

104
class VonMises(Distribution):
    """
    A circular von Mises distribution.

    Angles are handled in polar form: ``loc`` and ``value`` may be any real
    numbers (which keeps unconstrained optimization simple) and are understood
    modulo 2 pi.

    Example::
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
        tensor([1.9777])

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter
    """

    arg_constraints = dict(loc=constraints.real, concentration=constraints.positive)
    support = constraints.real
    has_rsample = False

    def __init__(self, loc, concentration, validate_args=None):
        loc, concentration = broadcast_all(loc, concentration)
        # Parameters must be set before the base initializer validates them.
        self.loc = loc
        self.concentration = concentration
        super().__init__(loc.shape, torch.Size(), validate_args)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # kappa * cos(value - loc) - log(2 pi) - log(I_0(kappa))
        kappa = self.concentration
        alignment = kappa * torch.cos(value - self.loc)
        return alignment - math.log(2 * math.pi) - _log_modified_bessel_fn(kappa, order=0)

    @lazy_property
    def _loc(self):
        # float64 copy of ``loc``; sampling always runs in double precision.
        return self.loc.double()

    @lazy_property
    def _concentration(self):
        # float64 copy of ``concentration`` for the sampler.
        return self.concentration.double()

    @lazy_property
    def _proposal_r(self):
        # Envelope parameter r of the Best & Fisher sampler, in float64.
        k = self._concentration
        tau = 1 + (1 + 4 * k**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * k)
        r_exact = (1 + rho**2) / (2 * rho)
        # For tiny kappa, ``tau - sqrt(2 * tau)`` cancels badly, so the
        # series r ~ 1 / kappa + kappa (expansion around 0) is used instead.
        r_series = 1 / k + k
        return torch.where(k < 1e-5, r_series, r_exact)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        Draw samples with the rejection algorithm from D.J. Best and
        N.I. Fisher, "Efficient simulation of the von Mises distribution."
        Applied Statistics (1979): 152-157.

        The rejection loop always runs in float64, and the result is cast back
        to the dtype of ``loc``. In single precision, _rejection_sample() can
        hang for small concentrations, starting around 1e-4 (see
        issue #88443).
        """
        out_shape = self._extended_shape(sample_shape)
        init = torch.empty(out_shape, dtype=self._loc.dtype, device=self.loc.device)
        angles = _rejection_sample(self._loc, self._concentration, self._proposal_r, init)
        return angles.to(self.loc.dtype)

    def expand(self, batch_shape):
        try:
            return super().expand(batch_shape)
        except NotImplementedError:
            pass
        # Fallback: rebuild the distribution from expanded parameters.
        validate_args = self.__dict__.get("_validate_args")
        return type(self)(
            self.loc.expand(batch_shape),
            self.concentration.expand(batch_shape),
            validate_args=validate_args,
        )

    @property
    def mean(self):
        """
        Circular mean of the distribution, which is ``loc`` itself.
        """
        return self.loc

    @property
    def mode(self):
        # The density peaks at ``loc``.
        return self.loc

    @lazy_property
    def variance(self):
        """
        Circular variance, ``1 - I_1(concentration) / I_0(concentration)``.
        """
        log_ratio = _log_modified_bessel_fn(
            self.concentration, order=1
        ) - _log_modified_bessel_fn(self.concentration, order=0)
        return 1 - torch.exp(log_ratio)
210

# --- Cookie notice captured from the hosting web page (not part of the module) ---
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal
# data in order to improve our website and the GitVerse service and to make
# them more convenient to use.
#
# You can disable the use of cookies yourself in your browser settings.