from typing import Dict

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.independent import Independent
from torch.distributions.transforms import ComposeTransform, Transform
from torch.distributions.utils import _sum_rightmost

__all__ = ["TransformedDistribution"]


class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution. Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log |det (dX/dY)|

    Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
    maximum shape of its base distribution and its transforms, since transforms
    can introduce correlations among events.

    An example for the usage of :class:`TransformedDistribution` would be::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y ~ f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)
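
    A log-normal distribution can be built the same way; a minimal sketch,
    assuming ``loc`` and ``scale`` tensors are already defined::

        # X ~ Normal(loc, scale)
        # Y = exp(X) ~ LogNormal(loc, scale)
        base_distribution = Normal(loc, scale)
        log_normal = TransformedDistribution(base_distribution, [ExpTransform()])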

    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
        if isinstance(transforms, Transform):
            self.transforms = [
                transforms,
            ]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError(
                    "transforms must be a Transform or a list of Transforms"
                )
            self.transforms = transforms
        else:
            raise ValueError(
                f"transforms must be a Transform or list, but was {transforms}"
            )

        # Reshape base_distribution according to transforms.
        base_shape = base_distribution.batch_shape + base_distribution.event_shape
        base_event_dim = len(base_distribution.event_shape)
        transform = ComposeTransform(self.transforms)
        if len(base_shape) < transform.domain.event_dim:
            raise ValueError(
                "base_distribution needs to have shape with size at least {}, but got {}.".format(
                    transform.domain.event_dim, base_shape
                )
            )
        forward_shape = transform.forward_shape(base_shape)
        expanded_base_shape = transform.inverse_shape(forward_shape)
        if base_shape != expanded_base_shape:
            base_batch_shape = expanded_base_shape[
                : len(expanded_base_shape) - base_event_dim
            ]
            base_distribution = base_distribution.expand(base_batch_shape)
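        # If the composed transform consumes more event dims than the base
        # distribution provides, reinterpret the extra rightmost batch dims as
        # event dims so that log_prob sums over them.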
        reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
        if reinterpreted_batch_ndims > 0:
            base_distribution = Independent(
                base_distribution, reinterpreted_batch_ndims
            )
        self.base_dist = base_distribution

        # Compute shapes.
        transform_change_in_event_dim = (
            transform.codomain.event_dim - transform.domain.event_dim
        )
        event_dim = max(
            transform.codomain.event_dim,  # the transform is coupled
            base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
        )
        assert len(forward_shape) >= event_dim
        cut = len(forward_shape) - event_dim
        batch_shape = forward_shape[:cut]
        event_shape = forward_shape[cut:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(TransformedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        shape = batch_shape + self.event_shape
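        # Push the requested shape backwards through the transforms to find
        # the batch shape the base distribution must be expanded to.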
        for t in reversed(self.transforms):
            shape = t.inverse_shape(shape)
        base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
        new.base_dist = self.base_dist.expand(base_batch_shape)
        new.transforms = self.transforms
        super(TransformedDistribution, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property(is_discrete=False)
    def support(self):
        if not self.transforms:
            return self.base_dist.support
        support = self.transforms[-1].codomain
        if len(self.event_shape) > support.event_dim:
            support = constraints.independent(
                support, len(self.event_shape) - support.event_dim
            )
        return support

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        base distribution and applies `transform()` for every transform in the
        list.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from base distribution and applies
        `transform()` for every transform in the list.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
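        # Walk the transforms in reverse order, inverting each one and
        # accumulating the log|det J| terms of the change-of-variables formula.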
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - _sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        log_prob = log_prob + _sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )
        return log_prob

    def _monotonize_cdf(self, value):
        """
        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
        monotone increasing.
        """
        sign = 1
        for transform in self.transforms:
            sign = sign * transform.sign
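        # For a monotone decreasing composition (sign == -1), the CDF of the
        # transformed variable is 1 - cdf, which the affine flip below applies.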
        if isinstance(sign, int) and sign == 1:
            return value
        return sign * (value - 0.5) + 0.5

    def cdf(self, value):
        """
        Computes the cumulative distribution function by inverting the
        transform(s) and computing the cdf of the base distribution.
        """
        for transform in self.transforms[::-1]:
            value = transform.inv(value)
        if self._validate_args:
            self.base_dist._validate_sample(value)
        value = self.base_dist.cdf(value)
        value = self._monotonize_cdf(value)
        return value

    def icdf(self, value):
        """
        Computes the inverse cumulative distribution function using the
        transform(s) and the inverse cdf of the base distribution.
        """
        value = self._monotonize_cdf(value)
        value = self.base_dist.icdf(value)
        for transform in self.transforms:
            value = transform(value)
        return value