lowrank_multivariate_normal.py
import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property

__all__ = ["LowRankMultivariateNormal"]


def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.
    """
    m = W.size(-1)
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    K.view(-1, m * m)[:, :: m + 1] += 1  # add identity matrix to K
    return torch.linalg.cholesky(K)
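

# A minimal illustrative check (not part of the upstream file): the helper above
# should agree with a dense construction of the capacitance matrix
# I + W.T @ inv(D) @ W; the shapes and seed below are arbitrary demo values.
def _demo_capacitance_tril():
    torch.manual_seed(0)
    W = torch.randn(5, 2)
    D = torch.rand(5) + 0.5
    # Build the capacitance matrix explicitly and compare Cholesky factors.
    expected = torch.linalg.cholesky(
        torch.eye(2) + W.mT @ torch.diag(D.reciprocal()) @ W
    )
    assert torch.allclose(_batch_capacitance_tril(W, D), expected, atol=1e-5)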


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses the "matrix determinant lemma"::
        log|W @ W.T + D| = log|C| + log|D|,
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
    """
    return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(
        -1
    )
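

# A minimal illustrative check (not part of the upstream file): the low-rank log
# determinant should match torch.logdet of the dense covariance; demo shapes only.
def _demo_lowrank_logdet():
    torch.manual_seed(0)
    W = torch.randn(5, 2)
    D = torch.rand(5) + 0.5
    capacitance_tril = _batch_capacitance_tril(W, D)
    dense = W @ W.mT + torch.diag(D)
    assert torch.allclose(
        _batch_lowrank_logdet(W, D, capacitance_tril), torch.logdet(dense), atol=1e-5
    )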


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses the "Woodbury matrix identity"::
        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the squared Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2
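

# A minimal illustrative check (not part of the upstream file): the Woodbury-based
# Mahalanobis term should match one computed with an explicit dense inverse;
# demo shapes only.
def _demo_lowrank_mahalanobis():
    torch.manual_seed(0)
    W = torch.randn(5, 2)
    D = torch.rand(5) + 0.5
    x = torch.randn(5)
    capacitance_tril = _batch_capacitance_tril(W, D)
    expected = x @ torch.linalg.inv(W @ W.mT + torch.diag(D)) @ x
    assert torch.allclose(
        _batch_lowrank_mahalanobis(W, D, x, capacitance_tril), expected, atol=1e-5
    )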


class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        Computing the determinant and inverse of the covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]`, thanks to the `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and the
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        With these formulas, we only need to compute the determinant and inverse of
        the small "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """
    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError(
                "cov_factor must be at least two-dimensional, "
                "with optional leading batch dimensions"
            )
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
            )
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                f"cov_diag must be a batch of vectors with shape {event_shape}"
            )

        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
                loc_, cov_factor, cov_diag_
            )
        except RuntimeError as e:
            raise ValueError(
                f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
            ) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        return (
            self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        # The following identity is used to improve the numerical stability of the
        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D^{1/2} @ (I + D^{-1/2} @ W @ W.T @ D^{-1/2}) @ D^{1/2}
        # The matrix "I + D^{-1/2} @ W @ W.T @ D^{-1/2}" has eigenvalues bounded from
        # below by 1, hence it is well-conditioned and safe to Cholesky-decompose.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        K.view(-1, n * n)[:, :: n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
        return scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        covariance_matrix = torch.matmul(
            self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
        ) + torch.diag_embed(self._unbroadcasted_cov_diag)
        return covariance_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def precision_matrix(self):
        # We use the "Woodbury matrix identity" to take advantage of the low-rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (
            self._unbroadcasted_cov_factor.mT
            / self._unbroadcasted_cov_diag.unsqueeze(-2)
        )
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
        precision_matrix = (
            torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        )
        return precision_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        # Reparameterized sample: loc + W @ eps_W + sqrt(D) * eps_D has covariance
        # W @ W.T + D, since eps_W and eps_D are independent standard normal.
        return (
            self.loc
            + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
            + self._unbroadcasted_cov_diag.sqrt() * eps_D
        )

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            diff,
            self._capacitance_tril,
        )
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
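

# A minimal usage sketch (not part of the upstream file): the low-rank
# distribution should agree with a full-covariance MultivariateNormal built
# from the same covariance matrix. Shapes and the seed are arbitrary demo
# values; the _demo_* helpers defined above are also exercised here.
if __name__ == "__main__":
    from torch.distributions.multivariate_normal import MultivariateNormal

    torch.manual_seed(0)
    loc = torch.zeros(5)
    cov_factor = torch.randn(5, 2)
    cov_diag = torch.rand(5) + 0.5
    lowrank = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
    full = MultivariateNormal(loc, covariance_matrix=lowrank.covariance_matrix)

    # log_prob and entropy should match the dense parameterization.
    x = lowrank.rsample()
    assert torch.allclose(lowrank.log_prob(x), full.log_prob(x), atol=1e-5)
    assert torch.allclose(lowrank.entropy(), full.entropy(), atol=1e-5)

    # scale_tril and precision_matrix should be consistent with covariance_matrix.
    assert torch.allclose(
        lowrank.scale_tril @ lowrank.scale_tril.mT,
        lowrank.covariance_matrix,
        atol=1e-5,
    )
    assert torch.allclose(
        lowrank.precision_matrix @ lowrank.covariance_matrix,
        torch.eye(5),
        atol=1e-4,
    )

    _demo_capacitance_tril()
    _demo_lowrank_logdet()
    _demo_lowrank_mahalanobis()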