pytorch
237 lines · 9.6 KB
1import math
2
3import torch
4from torch.distributions import constraints
5from torch.distributions.distribution import Distribution
6from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
7from torch.distributions.utils import _standard_normal, lazy_property
8
9__all__ = ["LowRankMultivariateNormal"]
10
11
12def _batch_capacitance_tril(W, D):
13r"""
14Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
15and a batch of vectors :math:`D`.
16"""
17m = W.size(-1)
18Wt_Dinv = W.mT / D.unsqueeze(-2)
19K = torch.matmul(Wt_Dinv, W).contiguous()
20K.view(-1, m * m)[:, :: m + 1] += 1 # add identity matrix to K
21return torch.linalg.cholesky(K)
22
23
24def _batch_lowrank_logdet(W, D, capacitance_tril):
25r"""
26Uses "matrix determinant lemma"::
27log|W @ W.T + D| = log|C| + log|D|,
28where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
29the log determinant.
30"""
31return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(
32-1
33)
34
35
36def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
37r"""
38Uses "Woodbury matrix identity"::
39inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
40where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
41Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
42"""
43Wt_Dinv = W.mT / D.unsqueeze(-2)
44Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
45mahalanobis_term1 = (x.pow(2) / D).sum(-1)
46mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
47return mahalanobis_term1 - mahalanobis_term2
48
49
class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """
    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError(
                "cov_factor must be at least two-dimensional, "
                "with optional leading batch dimensions"
            )
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
            )
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                f"cov_diag must be a batch of vectors with shape {event_shape}"
            )

        # Unsqueeze a trailing dim so the vector-shaped loc/cov_diag broadcast
        # against the matrix-shaped cov_factor when computing the batch shape.
        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
                loc_, cov_factor, cov_diag_
            )
        except RuntimeError as e:
            raise ValueError(
                f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
            ) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        # Internal linear algebra operates on the original (un-broadcast)
        # parameters; results are only expanded to the full batch shape at the
        # end, which avoids materializing broadcast copies.
        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        # The unbroadcasted tensors and capacitance Cholesky are shared, not
        # expanded: the computations built on them are batch-shape agnostic.
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        # The density of a multivariate normal is maximized at its mean.
        return self.loc

    @lazy_property
    def variance(self):
        # diag(W @ W.T + D) = rowwise squared norm of W plus D.
        return (
            self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        # The following identity is used to increase the numerically computation stability
        # for Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
        # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
        # hence it is well-conditioned and safe to take Cholesky decomposition.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        K.view(-1, n * n)[:, :: n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
        return scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        # Materializes the full (event_size x event_size) covariance; only used
        # when a caller explicitly asks for it.
        covariance_matrix = torch.matmul(
            self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
        ) + torch.diag_embed(self._unbroadcasted_cov_diag)
        return covariance_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def precision_matrix(self):
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (
            self._unbroadcasted_cov_factor.mT
            / self._unbroadcasted_cov_diag.unsqueeze(-2)
        )
        # A = inv(L) @ W.T @ inv(D) with C = L @ L.T, so A.T @ A is the
        # low-rank correction term of the Woodbury identity.
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
        precision_matrix = (
            torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        )
        return precision_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    def rsample(self, sample_shape=torch.Size()):
        # Reparameterized sample: loc + W @ eps_W + sqrt(D) * eps_D with
        # independent standard-normal noise, giving covariance W @ W.T + D.
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return (
            self.loc
            + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
            + self._unbroadcasted_cov_diag.sqrt() * eps_D
        )

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        # Squared Mahalanobis distance and log-determinant are both computed via
        # the small capacitance matrix (Woodbury / matrix determinant lemma).
        M = _batch_lowrank_mahalanobis(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            diff,
            self._capacitance_tril,
        )
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        # Gaussian entropy: 0.5 * (k * (1 + log(2*pi)) + log|Sigma|).
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
238