# mypy: allow-untyped-defs
import math
import warnings
from numbers import Number
from typing import Optional, Union

import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.multivariate_normal import _precision_to_scale_tril
from torch.distributions.utils import lazy_property
from torch.types import _size


__all__ = ["Wishart"]

_log_2 = math.log(2)
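

# _mvdigamma below evaluates the multivariate digamma function
#     psi_p(x) = sum_{i=0}^{p-1} digamma(x - i / 2),
# the derivative of log Gamma_p(x); entropy() relies on it. A quick sanity
# check (illustrative only; for p = 1 it reduces to torch.digamma):
#     >>> x = torch.tensor([3.0])
#     >>> bool(torch.allclose(_mvdigamma(x, 1), torch.digamma(x)))
#     True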
def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
    return torch.digamma(
        x.unsqueeze(-1)
        - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
    ).sum(-1)


def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
    # We assume positive input for this function
    return x.clamp(min=torch.finfo(x.dtype).eps)


class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`.

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of the square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix`, :attr:`precision_matrix`, or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        :class:`torch.distributions.LKJCholesky` is a restricted Wishart distribution. [1]

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """

    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "df": constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(
        self,
        df: Union[torch.Tensor, Number],
        covariance_matrix: Optional[torch.Tensor] = None,
        precision_matrix: Optional[torch.Tensor] = None,
        scale_tril: Optional[torch.Tensor] = None,
        validate_args=None,
    ):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        param = next(
            p
            for p in (covariance_matrix, precision_matrix, scale_tril)
            if p is not None
        )

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
            )

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        self.arg_constraints["df"] = constraints.greater_than(event_shape[-1] - 1)
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        return (
            self._unbroadcasted_scale_tril
            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape
        )

    @property
    def mean(self):
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def mode(self):
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        factor[factor <= 0] = nan
        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
        )
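
    # Bartlett decomposition (cf. [2] and [4] in the class docstring): with
    # Sigma = L L^T and A lower triangular such that A_ii = sqrt(chi2(df - i))
    # for i = 0, ..., p - 1 and A_ij ~ N(0, 1) for i > j, the product
    # (L A)(L A)^T is Wishart(df, Sigma) distributed; the method below
    # implements this construction.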
    def _bartlett_sampling(self, sample_shape=torch.Size()):
        p = self._event_shape[-1]  # has singleton shape

        # Sampling is implemented via the Bartlett decomposition
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def rsample(
        self, sample_shape: _size = torch.Size(), max_try_correction=None
    ) -> torch.Tensor:
        r"""
        .. warning::
            In some cases, the sampling algorithm based on the Bartlett decomposition may return
            singular matrix samples. Several attempts to correct singular samples are performed by
            default, but singular samples may still be returned; they can produce `-inf` values in
            `.log_prob()`. In those cases, the user should validate the samples and either fix the
            value of `df` or adjust the `max_try_correction` argument of `.rsample` accordingly.
        """
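
        # Usage sketch (illustrative, not from the upstream docs): for
        # ndim - 1 < df < ndim, singular samples are likely, and a larger
        # retry budget can be requested explicitly, e.g.
        #     d = Wishart(df=torch.tensor(1.5), covariance_matrix=torch.eye(2))
        #     w = d.rsample(max_try_correction=20)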

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # The block below is a temporary workaround to improve numerical stability
        # and should be removed in the future
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                sample = torch.where(is_singular, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(self._batch_dims)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample
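
    # The expression in log_prob below is the standard Wishart log density,
    #     log p(X) = (nu - p - 1) / 2 * logdet(X) - tr(Sigma^{-1} X) / 2
    #                - nu * p / 2 * log(2) - nu / 2 * logdet(Sigma) - log Gamma_p(nu / 2),
    # computed via logdet(Sigma) = 2 * sum(log(diag(L))) and
    # tr(Sigma^{-1} X) = sum(diag(cholesky_solve(X, L))), where Sigma = L L^T.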
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            -nu
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
            .diagonal(dim1=-2, dim2=-1)
            .sum(dim=-1)
            / 2
        )

    def entropy(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        return (
            (p + 1)
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )
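
    # As an exponential-family distribution, the Wishart has natural parameters
    # (eta_1, eta_2) = (-Lambda / 2, (nu - p - 1) / 2) paired with sufficient
    # statistics (X, logdet X), where Lambda = Sigma^{-1} is the precision
    # matrix; _log_normalizer is written in terms of (x, y) = (eta_1, eta_2).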
    @property
    def _natural_params(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        p = self._event_shape[-1]
        return (y + (p + 1) / 2) * (
            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)