# pytorch
# 209 lines · 5.9 KB
import math

import torch
import torch.jit
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all, lazy_property

# Public API of this module: only the distribution class is exported.
__all__ = ["VonMises"]


def _eval_poly(y, coef):
    """
    Evaluate ``coef[0] + coef[1] * y + coef[2] * y**2 + ...`` by Horner's scheme.

    ``coef`` may be any iterable; it is copied into a list, never modified.
    ``y`` may be a Python number or a tensor, since only ``+`` and ``*`` are used.
    An empty ``coef`` raises ``IndexError``.
    """
    terms = list(coef)
    acc = terms[-1]  # start from the highest-order coefficient
    for c in reversed(terms[:-1]):
        acc = c + y * acc
    return acc


# Polynomial coefficients for the modified Bessel functions I_0 and I_1, from
# Abramowitz & Stegun, "Handbook of Mathematical Functions", 9.8.1-9.8.4.
# Each table is ordered lowest power first, as consumed by ``_eval_poly``.
#   *_SMALL: polynomial in (x / 3.75)**2, used for x < 3.75
#            (the I_1 table approximates I_1(x) / x);
#   *_LARGE: polynomial in 3.75 / x, used for x >= 3.75, approximating
#            sqrt(x) * exp(-x) * I_order(x).
# Stored as tuples so these module-wide tables cannot be mutated by accident.
_I0_COEF_SMALL = (
    1.0,
    3.5156229,
    3.0899424,
    1.2067492,
    0.2659732,
    0.360768e-1,
    0.45813e-2,
)
_I0_COEF_LARGE = (
    0.39894228,
    0.1328592e-1,
    0.225319e-2,
    -0.157565e-2,
    0.916281e-2,
    -0.2057706e-1,
    0.2635537e-1,
    -0.1647633e-1,
    0.392377e-2,
)
_I1_COEF_SMALL = (
    0.5,
    0.87890594,
    0.51498869,
    0.15084934,
    0.2658733e-1,
    0.301532e-2,
    0.32411e-3,
)
_I1_COEF_LARGE = (
    0.39894228,
    -0.3988024e-1,
    -0.362018e-2,
    0.163801e-2,
    -0.1031555e-1,
    0.2282967e-1,
    -0.2895312e-1,
    0.1787654e-1,
    -0.420059e-2,
)

# Indexed by Bessel order (0 or 1).
_COEF_SMALL = (_I0_COEF_SMALL, _I1_COEF_SMALL)
_COEF_LARGE = (_I0_COEF_LARGE, _I1_COEF_LARGE)


def _log_modified_bessel_fn(x, order=0):
    """
    Returns ``log(I_order(x))`` for ``x > 0``,
    where `order` is either 0 or 1.

    ``x`` must be a tensor (``.clamp()``, ``.abs()`` and ``.log()`` are called
    on it). Two approximations are evaluated elementwise and combined with
    ``torch.where``:

    * ``x < 3.75``: a polynomial in ``(x / 3.75)**2`` (times ``|x|`` for order 1);
    * ``x >= 3.75``: ``x - log(x) / 2 + log(poly(3.75 / x))``, which stays in
      log space so large ``x`` does not overflow.
    """
    assert order == 0 or order == 1

    # Autograd backpropagates through *both* inputs of torch.where; the
    # discarded one just receives a zero gradient. Unclamped, that branch can
    # overflow (huge x in the small-x polynomial, tiny x in the polynomial of
    # 3.75 / x). The backward pass then computes 0 * inf = NaN, so the gradient
    # is NaN even though the forward value is fine. Clamping each branch's input
    # to its own domain keeps every intermediate finite. Values actually
    # selected by torch.where, and their gradients, are unchanged.
    x_small = x.clamp(max=3.75)
    x_large = x.clamp(min=3.75)

    # compute small solution
    y = x_small / 3.75
    y = y * y
    small = _eval_poly(y, _COEF_SMALL[order])
    if order == 1:
        small = x_small.abs() * small
    small = small.log()

    # compute large solution
    y = 3.75 / x_large
    large = x_large - 0.5 * x_large.log() + _eval_poly(y, _COEF_LARGE[order]).log()

    result = torch.where(x < 3.75, small, large)
    return result


@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
    """
    Draw von Mises angles with the Best & Fisher (1979) rejection sampler.

    Every round draws three uniforms per element of ``x`` and tests the
    proposal. Rounds repeat until each element has accepted at least once.
    An element is overwritten whenever a proposal for it is accepted, even
    if it was already done. The result is shifted by ``loc`` and wrapped to
    ``[-pi, pi)``.

    Args:
        loc: location angle(s), broadcast against ``x``; its dtype and device
            are used for the uniform draws.
        concentration: kappa, broadcast against ``x``.
        proposal_r: the sampler's envelope parameter ``s``.
        x: tensor fixing the output shape; its initial values are replaced.

    Only TorchScript-friendly constructs are used, because the decorator
    scripts this function when it is called under tracing.
    """
    finished = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
    while not finished.all():
        # (U1, U2, U3) for every element in a single draw.
        uniforms = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1, u2, u3 = uniforms.unbind()
        cos_u1 = torch.cos(math.pi * u1)
        f = (1 + proposal_r * cos_u1) / (proposal_r + cos_u1)
        c = concentration * (proposal_r - f)
        # Cheap squeeze test, then the exact logarithmic test.
        squeeze_ok = (c * (2 - c) - u2) > 0
        log_ok = (c / u2).log() + 1 - c >= 0
        accept = squeeze_ok | log_ok
        if accept.any():
            # Accepted angle: arccos(f) with a random sign taken from U3.
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            finished = finished | accept
    shifted = x + math.pi + loc
    return shifted % (2 * math.pi) - math.pi


class VonMises(Distribution):
    """
    A circular von Mises distribution.

    This implementation uses polar coordinates. The ``loc`` and ``value`` args
    can be any real number (to facilitate unconstrained optimization), but are
    interpreted as angles modulo 2 pi.

    Example::
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
        tensor([1.9777])

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter (must be
        positive, see ``arg_constraints``).
    """

    arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
    support = constraints.real
    has_rsample = False

    def __init__(self, loc, concentration, validate_args=None):
        # loc and concentration are broadcast together; their common shape is
        # the batch shape, and each event is a scalar angle.
        self.loc, self.concentration = broadcast_all(loc, concentration)
        batch_shape = self.loc.shape
        event_shape = torch.Size()
        super().__init__(batch_shape, event_shape, validate_args)

    def log_prob(self, value):
        """
        Return ``kappa * cos(value - loc) - log(2 pi) - log(I_0(kappa))``,
        where ``kappa`` is ``self.concentration``.
        """
        if self._validate_args:
            self._validate_sample(value)
        log_prob = self.concentration * torch.cos(value - self.loc)
        log_prob = (
            log_prob
            - math.log(2 * math.pi)
            - _log_modified_bessel_fn(self.concentration, order=0)
        )
        return log_prob

    @lazy_property
    def _loc(self):
        """Double-precision copy of ``loc`` used by :meth:`sample`."""
        return self.loc.to(torch.double)

    @lazy_property
    def _concentration(self):
        """Double-precision copy of ``concentration`` used by :meth:`sample`."""
        return self.concentration.to(torch.double)

    @lazy_property
    def _proposal_r(self):
        """
        Envelope parameter ``s = (1 + rho**2) / (2 * rho)`` of the Best & Fisher
        sampler, computed in double precision from the concentration.
        """
        kappa = self._concentration
        tau = 1 + (1 + 4 * kappa**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
        _proposal_r = (1 + rho**2) / (2 * rho)
        # second order Taylor expansion around 0 for small kappa; the exact
        # expression subtracts two nearly equal terms (tau and sqrt(2 * tau),
        # both close to 2) when kappa is small.
        _proposal_r_taylor = 1 / kappa + kappa
        return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        The sampling algorithm for the von Mises distribution is based on the
        following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
        von Mises distribution." Applied Statistics (1979): 152-157.

        Sampling is always done in double precision internally to avoid a hang
        in _rejection_sample() for small values of the concentration, which
        starts to happen for single precision around 1e-4 (see issue #88443).
        """
        shape = self._extended_shape(sample_shape)
        # Placeholder buffer; every element is overwritten by an accepted draw.
        x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
        return _rejection_sample(
            self._loc, self._concentration, self._proposal_r, x
        ).to(self.loc.dtype)

    def expand(self, batch_shape, _instance=None):
        """
        Return a distribution whose parameters are expanded to ``batch_shape``.

        ``_instance`` follows the :meth:`Distribution.expand` convention. A
        subclass that overrides ``expand`` may pass a bare instance of itself,
        and it is populated and returned. Otherwise a new ``type(self)`` is
        constructed. In both cases this instance's explicit ``validate_args``
        setting is carried over.
        """
        try:
            return super().expand(batch_shape, _instance)
        except NotImplementedError:
            validate_args = self.__dict__.get("_validate_args")
            loc = self.loc.expand(batch_shape)
            concentration = self.concentration.expand(batch_shape)
            if _instance is None:
                return type(self)(loc, concentration, validate_args=validate_args)
            # Fill the caller-supplied instance in place. VonMises.__init__ is
            # named explicitly so a subclass __init__ with another signature is
            # not invoked.
            VonMises.__init__(
                _instance, loc, concentration, validate_args=validate_args
            )
            return _instance

    @property
    def mean(self):
        """
        The provided mean is the circular one.
        """
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        """
        The provided variance is the circular one:
        ``1 - I_1(concentration) / I_0(concentration)``.
        """
        log_ratio = _log_modified_bessel_fn(
            self.concentration, order=1
        ) - _log_modified_bessel_fn(self.concentration, order=0)
        # 1 - exp(d) == -expm1(d). expm1 avoids the cancellation of
        # ``1 - exp(d)`` when the Bessel ratio is close to 1 (large concentration).
        return -torch.expm1(log_ratio)
