pytorch
265 lines · 10.6 KB
1# mypy: allow-untyped-defs
2import math3
4import torch5from torch.distributions import constraints6from torch.distributions.distribution import Distribution7from torch.distributions.utils import _standard_normal, lazy_property8from torch.types import _size9
10
# Public API of this module: only the distribution class is exported.
__all__ = ["MultivariateNormal"]
13
14def _batch_mv(bmat, bvec):15r"""16Performs a batched matrix-vector product, with compatible but different batch shapes.
17
18This function takes as input `bmat`, containing :math:`n \times n` matrices, and
19`bvec`, containing length :math:`n` vectors.
20
21Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
22to a batch shape. They are not necessarily assumed to have the same batch shape,
23just ones which can be broadcasted.
24"""
25return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)26
27
def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
    shape, but `bL` one should be able to broadcasted to `bx` one.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    # Extra leading batch dims that bx has beyond bL's batch dims.
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx with the shape (..., 1, i, j, 1, n):
    # each bx batch dim aligned with a bL batch dim of size sL is split
    # into (sx // sL, sL) so the sL part can later line up with bL.
    bx_new_shape = bx.shape[:outer_batch_dims]
    for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n):
    # gather the "quotient" halves first, then the "bL-aligned" halves.
    permute_dims = (
        list(range(outer_batch_dims))
        + list(range(outer_batch_dims, new_batch_dims, 2))
        + list(range(outer_batch_dims + 1, new_batch_dims, 2))
        + [new_batch_dims]
    )
    bx = bx.permute(permute_dims)

    # Flatten both operands so a single batched triangular solve covers
    # every (matrix, vector) pairing.
    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    # Solving L z = x and summing z**2 over the n-dim gives ||L^{-1} x||^2,
    # i.e. the squared Mahalanobis distance for M = L L^T.
    M_swap = (
        torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
    )  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        # Re-interleave quotient and bL-aligned dims back to original order.
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)
77
78def _precision_to_scale_tril(P):79# Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril80Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))81L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)82Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)83L = torch.linalg.solve_triangular(L_inv, Id, upper=False)84return L85
86
class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
    can be obtained via e.g. Cholesky decomposition of the covariance.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.

        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(
        self,
        loc,
        covariance_matrix=None,
        precision_matrix=None,
        scale_tril=None,
        validate_args=None,
    ):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        # Booleans sum to the number of parameterizations supplied;
        # exactly one of the three matrix arguments must be given.
        if (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) != 1:
            raise ValueError(
                "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
            )

        # Broadcast the batch dims of the supplied matrix against loc's
        # batch dims, and expand both to the common batch shape.
        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError(
                    "scale_tril matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
            self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError(
                    "covariance_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                covariance_matrix.shape[:-2], loc.shape[:-1]
            )
            self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
        else:
            if precision_matrix.dim() < 2:
                raise ValueError(
                    "precision_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                precision_matrix.shape[:-2], loc.shape[:-1]
            )
            self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
        self.loc = loc.expand(batch_shape + (-1,))

        event_shape = self.loc.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

        # All internal computations use this un-broadcast Cholesky factor;
        # the lazy properties below expand it to the full batch shape on demand.
        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

    def expand(self, batch_shape, _instance=None):
        """Return a new instance with parameters expanded to `batch_shape`."""
        new = self._get_checked_instance(MultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        cov_shape = batch_shape + self.event_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
        # lazy_property stores computed values in __dict__, so this check
        # only copies (and re-expands) attributes that were actually set
        # or already materialized on self.
        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)
        super(MultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        # Lower-triangular factor, broadcast to the full batch shape.
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        # Sigma = L @ L^T, broadcast to the full batch shape.
        return torch.matmul(
            self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
        ).expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        # inv(Sigma) computed from the Cholesky factor, then broadcast.
        return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        # For a Gaussian the mode coincides with the mean.
        return self.loc

    @property
    def variance(self):
        # diag(L @ L^T) equals the row-wise sum of squares of L.
        return (
            self._unbroadcasted_scale_tril.pow(2)
            .sum(-1)
            .expand(self._batch_shape + self._event_shape)
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        # Reparameterized sample: loc + L @ eps with eps ~ N(0, I).
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)

    def log_prob(self, value):
        # log N(value; loc, Sigma)
        #   = -0.5 * (d*log(2*pi) + Mahalanobis^2) - log|det L|
        # where log|det L| is the sum of the log-diagonal of L.
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
        half_log_det = (
            self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det

    def entropy(self):
        # H = 0.5 * d * (1 + log(2*pi)) + log|det L|
        half_log_det = (
            self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        )
        H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)