pytorch

Форк
0
/
multivariate_normal.py 
265 строк · 10.6 Кб
1
# mypy: allow-untyped-defs
2
import math
3

4
import torch
5
from torch.distributions import constraints
6
from torch.distributions.distribution import Distribution
7
from torch.distributions.utils import _standard_normal, lazy_property
8
from torch.types import _size
9

10

11
# Public API: only the distribution class is exported; the underscore-prefixed
# helper functions defined in this module are private.
__all__ = ["MultivariateNormal"]
12

13

14
def _batch_mv(bmat, bvec):
    r"""
    Performs a batched matrix-vector product, with compatible but different batch shapes.

    This function takes as input `bmat`, containing :math:`n \times n` matrices in its
    last two dimensions, and `bvec`, containing length :math:`n` vectors in its last
    dimension.

    Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
    to a batch shape. They are not necessarily assumed to have the same batch shape,
    just ones which can be broadcasted (standard ``torch.matmul`` broadcasting).

    Returns:
        Tensor: the products, with the broadcast batch shape followed by ``n``.
    """
    # Turn each vector into an (n, 1) column so matmul sees a matrix-matrix
    # product (which broadcasts batch dims), then drop the singleton column.
    return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
26

27

28
def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Args:
        bL (Tensor): lower-triangular factors :math:`\mathbf{L}`, shape ``(*L_batch, n, n)``.
        bx (Tensor): vectors :math:`\mathbf{x}`, shape ``(*x_batch, n)``.

    Accepts batches for both `bL` and `bx`. They need not have the same batch shape,
    but the batch shape of `bL` must be broadcastable to the batch shape of `bx`:
    `bL` may not have more batch dimensions than `bx`, and each batch size of `bL`
    must equal the aligned (trailing) batch size of `bx` or be 1.

    Returns:
        Tensor: squared distances with shape ``bx.shape[:-1]``.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    # Leading batch dims of bx that have no counterpart in bL.
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims  # == bx_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx with the shape (..., 1, i, j, 1, n): every batch dim shared with
    # bL is split into (number of repeats, size of that bL dim).
    bx_new_shape = bx.shape[:outer_batch_dims]
    for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n): all repeat dims first,
    # then all bL-sized dims, so the trailing batch dims line up with bL exactly.
    permute_dims = (
        list(range(outer_batch_dims))
        + list(range(outer_batch_dims, new_batch_dims, 2))
        + list(range(outer_batch_dims + 1, new_batch_dims, 2))
        + [new_batch_dims]
    )
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    # x^T (L L^T)^{-1} x == ||L^{-1} x||^2; one triangular solve per distinct L
    # handles all c vectors paired with it as columns of the right-hand side.
    M_swap = (
        torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
    )  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operations.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)
76

77

78
def _precision_to_scale_tril(P):
    r"""
    Convert a positive-definite precision matrix ``P`` into the lower-triangular
    scale factor ``L`` of the covariance, i.e. ``L @ L.mT == inverse(P)``.

    Instead of inverting ``P`` and factoring the resulting covariance, ``P`` is
    factored with its rows and columns reversed. Flipping that Cholesky factor back
    and transposing it yields ``L^{-1}`` (lower-triangular) directly. A single
    triangular solve against the identity then recovers ``L``. Leading batch
    dimensions of ``P`` are carried through.

    Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
    """
    matrix_dims = (-2, -1)
    # With J the reversal permutation: J P J = Lf Lf^T, and (J Lf J)^T is lower
    # triangular with ((J Lf J)^T)^T (J Lf J)^T == P, i.e. it is L^{-1}.
    reversed_chol = torch.linalg.cholesky(P.flip(matrix_dims))
    inv_scale = reversed_chol.flip(matrix_dims).mT
    identity = torch.eye(P.size(-1), dtype=P.dtype, device=P.device)
    return torch.linalg.solve_triangular(inv_scale, identity, upper=False)
85

86

87
class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
    can be obtained via e.g. Cholesky decomposition of the covariance.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution, with shape ``(*batch, n)``
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
        validate_args (bool, optional): whether to validate arguments and samples

    Note:
        Exactly one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` must be specified. Matrix arguments have shape
        ``(*batch, n, n)``; their batch dimensions are broadcast with those of
        :attr:`loc`.

        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(
        self,
        loc,
        covariance_matrix=None,
        precision_matrix=None,
        scale_tril=None,
        validate_args=None,
    ):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        if (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) != 1:
            raise ValueError(
                "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
            )

        # Store the supplied matrix broadcast to the joint batch shape of the
        # matrix and ``loc``. The assignment lands in the instance ``__dict__``
        # (shadowing the lazy property of the same name), which is also what
        # ``expand`` checks to decide which matrices to carry over.
        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError(
                    "scale_tril matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
            self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError(
                    "covariance_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                covariance_matrix.shape[:-2], loc.shape[:-1]
            )
            self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
        else:
            if precision_matrix.dim() < 2:
                raise ValueError(
                    "precision_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                precision_matrix.shape[:-2], loc.shape[:-1]
            )
            self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
        self.loc = loc.expand(batch_shape + (-1,))

        event_shape = self.loc.shape[-1:]
        # NOTE(review): presumably the base-class __init__ validates the supplied
        # matrix against arg_constraints here, before the factorization below
        # runs; keep this ordering.
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

        # All internal computations use a lower-triangular factor kept at the
        # argument's original (unbroadcast) batch shape, so the factorization and
        # later triangular solves are not repeated for broadcast copies.
        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

    def expand(self, batch_shape, _instance=None):
        """Return this distribution with ``loc`` (and any matrices already stored on
        the instance) expanded to ``batch_shape``; the unbroadcast scale factor is
        shared with the new instance unchanged."""
        new = self._get_checked_instance(MultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        cov_shape = batch_shape + self.event_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)
        super(MultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        """Lower-triangular factor L of the covariance, broadcast to the full batch shape."""
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        """Covariance L @ L^T, broadcast to the full batch shape."""
        return torch.matmul(
            self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
        ).expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        """Inverse covariance computed from the Cholesky factor, broadcast to the full batch shape."""
        return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @property
    def mean(self) -> torch.Tensor:
        return self.loc

    @property
    def mode(self) -> torch.Tensor:
        return self.loc

    @property
    def variance(self) -> torch.Tensor:
        # Sigma_ii = sum_k L_ik^2: the diagonal of L @ L^T without forming it.
        return (
            self._unbroadcasted_scale_tril.pow(2)
            .sum(-1)
            .expand(self._batch_shape + self._event_shape)
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Reparameterized sample ``loc + L @ eps`` with ``eps`` standard normal."""
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)

    def log_prob(self, value) -> torch.Tensor:
        """Log density: -0.5 * (n * log(2*pi) + Mahalanobis^2) - 0.5 * log|Sigma|."""
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
        half_log_det = self._half_log_det()
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det

    def entropy(self) -> torch.Tensor:
        """Differential entropy: 0.5 * n * (1 + log(2*pi)) + 0.5 * log|Sigma|."""
        half_log_det = self._half_log_det()
        H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)

    def _half_log_det(self) -> torch.Tensor:
        # 0.5 * log|Sigma| == log|L| == sum(log(diag(L))) because Sigma = L @ L^T
        # and L is triangular with a positive diagonal. The result has the batch
        # shape of the unbroadcast factor.
        return self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
266

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.