pytorch

Форк
0
339 строк · 13.4 Кб
1
# mypy: allow-untyped-defs
2
import math
3
import warnings
4
from numbers import Number
5
from typing import Optional, Union
6

7
import torch
8
from torch import nan
9
from torch.distributions import constraints
10
from torch.distributions.exp_family import ExponentialFamily
11
from torch.distributions.multivariate_normal import _precision_to_scale_tril
12
from torch.distributions.utils import lazy_property
13
from torch.types import _size
14

15

16
__all__ = ["Wishart"]
17

18
_log_2 = math.log(2)
19

20

21
def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
    """Evaluate the multivariate digamma function of order ``p`` at ``x``.

    The result is ``sum(digamma(x - i / 2) for i in range(p))``, computed
    elementwise over ``x``; the output has the same shape as ``x``.

    Args:
        x: points of evaluation; every entry must exceed ``(p - 1) / 2``.
        p: order (dimension) of the multivariate digamma function.

    Raises:
        AssertionError: if any entry of ``x`` is outside the domain.
    """
    in_domain = x.gt((p - 1) / 2).all()
    assert in_domain, "Wrong domain for multivariate digamma function."
    # Offsets 0, 1/2, ..., (p - 1)/2 broadcast against a new trailing axis of x.
    half_steps = torch.arange(p, dtype=x.dtype, device=x.device) / 2
    shifted = x[..., None] - half_steps
    return torch.digamma(shifted).sum(dim=-1)
27

28

29
def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` with every entry raised to at least the machine epsilon of its dtype.

    Args:
        x: floating-point tensor (``torch.finfo`` is applied to ``x.dtype``).

    Returns:
        A tensor of the same shape whose entries are ``>= torch.finfo(x.dtype).eps``.
    """
    # The input is assumed to be positive; only a lower bound is enforced.
    smallest = torch.finfo(x.dtype).eps
    return torch.clamp(x, min=smallest)
32

33

34
class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>>             # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1]

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """

    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "df": constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(
        self,
        df: Union[torch.Tensor, Number],
        covariance_matrix: Optional[torch.Tensor] = None,
        precision_matrix: Optional[torch.Tensor] = None,
        scale_tril: Optional[torch.Tensor] = None,
        validate_args=None,
    ):
        """Build the distribution from ``df`` and exactly one matrix parameterization.

        Args:
            df: degrees of freedom; a Python number or a tensor broadcastable
                against the batch dimensions of the matrix parameter.
            covariance_matrix: matrix parameter given as a covariance matrix.
            precision_matrix: matrix parameter given as a precision matrix.
            scale_tril: matrix parameter given as a lower Cholesky factor.
            validate_args: forwarded to the base-class constructor.

        Raises:
            AssertionError: if not exactly one matrix parameter is given.
            ValueError: if the matrix parameter has fewer than two dimensions,
                or if any ``df`` value is not greater than ``ndim - 1``.
        """
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        # The single matrix argument that was actually supplied.
        param = next(
            p
            for p in (covariance_matrix, precision_matrix, scale_tril)
            if p is not None
        )

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            # Expand the scalar to ``batch_shape``: ``mean``, ``mode`` and
            # ``variance`` reshape ``self.df`` with
            # ``.view(batch_shape + (1, 1))``, which fails on a 0-dim tensor
            # whenever the matrix parameter carries batch dimensions.
            self.df = torch.tensor(
                df, dtype=param.dtype, device=param.device
            ).expand(batch_shape)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
            )

        # Store the user-supplied parameterization broadcast to the batch shape;
        # the other two are derived lazily from ``_unbroadcasted_scale_tril``.
        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        # NOTE(review): this writes into the ``arg_constraints`` dict defined on
        # the class, so the bound set here is visible to every instance.
        self.arg_constraints["df"] = constraints.greater_than(event_shape[-1] - 1)
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        # Negative indices of the batch dimensions counted from the end
        # (-1, -2, ...); used by ``rsample`` to reduce a per-matrix flag.
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        # (one Chi2 per diagonal entry, with df, df - 1, ..., df - (p - 1)).
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new distribution whose parameters are expanded to ``batch_shape``.

        Args:
            batch_shape: the desired batch shape.
            _instance: optional pre-allocated instance to fill in, passed to
                ``_get_checked_instance``.

        Returns:
            A ``Wishart`` with ``df``, the scale factor, any already-materialized
            matrix parameterizations and the internal Chi2 distribution expanded.
        """
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        # Only carry over parameterizations that were already computed/stored on
        # this instance; the others remain lazy on the new one.
        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        """Lower Cholesky factor expanded to ``batch_shape + event_shape``."""
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        """``L @ L^T`` for the stored factor ``L``, expanded to ``batch_shape + event_shape``."""
        return (
            self._unbroadcasted_scale_tril
            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        """Inverse of the covariance matrix, expanded to ``batch_shape + event_shape``.

        Obtained by solving against the identity with the stored Cholesky factor.
        """
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape
        )

    @property
    def mean(self):
        """``df * covariance_matrix``, with shape ``batch_shape + event_shape``."""
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def mode(self):
        """``(df - p - 1) * covariance_matrix``; ``nan`` wherever ``df - p - 1 <= 0``."""
        # The subtraction allocates a fresh tensor, so the in-place masking
        # below does not touch ``self.df``.
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        factor[factor <= 0] = nan
        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        """Elementwise variance: ``df * (V_ij ** 2 + V_ii * V_jj)`` for covariance ``V``."""
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
        )

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Builds a lower-triangular noise matrix (square roots of Chi2 draws on
        the diagonal, standard normal draws strictly below it), multiplies it
        by the scale factor and returns the product with its own transpose.
        """
        p = self._event_shape[-1]  # has singleton shape

        # Implemented Sampling using Bartlett decomposition
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        # Fill the p * (p - 1) / 2 strictly-lower-triangular entries.
        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def rsample(
        self, sample_shape: _size = torch.Size(), max_try_correction=None
    ) -> torch.Tensor:
        r"""
        Draw reparameterized samples, retrying draws that fall outside the support.

        Args:
            sample_shape: leading shape of the returned sample.
            max_try_correction: maximum number of resampling rounds; defaults to
                3 while tracing and 10 otherwise.

        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape``.

        .. warning::
            In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust `max_try_correction` value for argument in `.rsample` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # Below part is to improve numerical stability temporarily and should be removed in the future
        # ``support.check`` is True for valid samples, so it must be negated to
        # flag the singular ones (as the recomputations further down do).
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            # Collapse the batch dimensions: an entry is flagged if any matrix
            # in its batch is singular.
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                sample = torch.where(is_singular, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    # Redraw only as many samples as are currently flagged.
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(self._batch_dims)
                    # Index with a copy of the mask so it can be updated in place.
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

    def log_prob(self, value):
        """Return the log density evaluated at ``value``.

        Args:
            value: tensor of matrices; validated with ``_validate_sample`` when
                argument validation is enabled.
        """
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            -nu
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
            .diagonal(dim1=-2, dim2=-1)
            .sum(dim=-1)
            / 2
        )

    def entropy(self):
        """Return the entropy of the distribution, one value per batch element."""
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            (p + 1)
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )

    @property
    def _natural_params(self):
        """Natural parameters: ``(-precision_matrix / 2, (df - p - 1) / 2)``."""
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        """Log normalizer as a function of the natural parameters ``(x, y)``."""
        p = self._event_shape[-1]
        return (y + (p + 1) / 2) * (
            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)
340

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.