pytorch
125 lines · 4.5 KB
from typing import Dict

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost

__all__ = ["Independent"]
10
class Independent(Distribution):
    r"""
    Reinterprets some of the batch dims of a distribution as event dims.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`. For example to create a diagonal Normal distribution with
    the same shape as a Multivariate Normal distribution (so they are
    interchangeable), you can::

        >>> from torch.distributions.multivariate_normal import MultivariateNormal
        >>> from torch.distributions.normal import Normal
        >>> loc = torch.zeros(3)
        >>> scale = torch.ones(3)
        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
        >>> [mvn.batch_shape, mvn.event_shape]
        [torch.Size([]), torch.Size([3])]
        >>> normal = Normal(loc, scale)
        >>> [normal.batch_shape, normal.event_shape]
        [torch.Size([3]), torch.Size([])]
        >>> diagn = Independent(normal, 1)
        >>> [diagn.batch_shape, diagn.event_shape]
        [torch.Size([]), torch.Size([3])]

    Args:
        base_distribution (torch.distributions.distribution.Distribution): a
            base distribution
        reinterpreted_batch_ndims (int): the number of batch dims to
            reinterpret as event dims
    """

    # This wrapper has no tensor parameters of its own; all parameter
    # constraints live on the wrapped base distribution.
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(
        self, base_distribution, reinterpreted_batch_ndims, validate_args=None
    ):
        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
            raise ValueError(
                "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
            )
        # Concatenate the base shapes, then move the rightmost
        # `reinterpreted_batch_ndims` batch dims into the event shape.
        full_shape = base_distribution.batch_shape + base_distribution.event_shape
        event_ndims = reinterpreted_batch_ndims + len(base_distribution.event_shape)
        split = len(full_shape) - event_ndims
        self.base_dist = base_distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__(
            full_shape[:split], full_shape[split:], validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new :class:`Independent` expanded to ``batch_shape``."""
        new = self._get_checked_instance(Independent, _instance)
        batch_shape = torch.Size(batch_shape)
        # The base distribution still owns the reinterpreted dims as batch
        # dims, so they are appended back when expanding it.
        base_batch_shape = (
            batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
        )
        new.base_dist = self.base_dist.expand(base_batch_shape)
        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
        super(Independent, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        # Inherit validation behavior from the source instance.
        new._validate_args = self._validate_args
        return new

    @property
    def has_rsample(self):
        # Reparameterized sampling is available iff the base supports it.
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self):
        # Enumerating the cartesian product over reinterpreted dims is not
        # implemented, so support requires zero reinterpreted dims.
        return self.reinterpreted_batch_ndims == 0 and self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        base_support = self.base_dist.support
        if not self.reinterpreted_batch_ndims:
            return base_support
        return constraints.independent(base_support, self.reinterpreted_batch_ndims)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def mode(self):
        return self.base_dist.mode

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        # Shapes already agree; sampling is delegated unchanged.
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        # Sum the base log-density over the dims reinterpreted as event dims.
        return _sum_rightmost(
            self.base_dist.log_prob(value), self.reinterpreted_batch_ndims
        )

    def entropy(self):
        # Entropy of independent components is the sum of their entropies.
        return _sum_rightmost(
            self.base_dist.entropy(), self.reinterpreted_batch_ndims
        )

    def enumerate_support(self, expand=True):
        if self.reinterpreted_batch_ndims > 0:
            raise NotImplementedError(
                "Enumeration over cartesian product is not implemented"
            )
        return self.base_dist.enumerate_support(expand=expand)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.base_dist}, {self.reinterpreted_batch_ndims})"