from typing import Dict

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.independent import Independent
from torch.distributions.transforms import ComposeTransform, Transform
from torch.distributions.utils import _sum_rightmost

__all__ = ["TransformedDistribution"]


class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution.  Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log |det (dX/dY)|

    Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
    maximum shape of its base distribution and its transforms, since transforms
    can introduce correlations among events.

    An example for the usage of :class:`TransformedDistribution` would be::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y = f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)

    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
        if isinstance(transforms, Transform):
            self.transforms = [
                transforms,
            ]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError(
                    "transforms must be a Transform or a list of Transforms"
                )
            self.transforms = transforms
        else:
            raise ValueError(
                f"transforms must be a Transform or list, but was {transforms}"
            )

        # Reshape base_distribution according to transforms.
        base_shape = base_distribution.batch_shape + base_distribution.event_shape
        base_event_dim = len(base_distribution.event_shape)
        transform = ComposeTransform(self.transforms)
        if len(base_shape) < transform.domain.event_dim:
            raise ValueError(
                "base_distribution needs to have shape with size at least {}, but got {}.".format(
                    transform.domain.event_dim, base_shape
                )
            )
        forward_shape = transform.forward_shape(base_shape)
        expanded_base_shape = transform.inverse_shape(forward_shape)
        if base_shape != expanded_base_shape:
            base_batch_shape = expanded_base_shape[
                : len(expanded_base_shape) - base_event_dim
            ]
            base_distribution = base_distribution.expand(base_batch_shape)
        reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
        if reinterpreted_batch_ndims > 0:
            base_distribution = Independent(
                base_distribution, reinterpreted_batch_ndims
            )
        self.base_dist = base_distribution

        # Compute shapes.
        transform_change_in_event_dim = (
            transform.codomain.event_dim - transform.domain.event_dim
        )
        event_dim = max(
            transform.codomain.event_dim,  # the transform is coupled
            base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
        )
        assert len(forward_shape) >= event_dim
        cut = len(forward_shape) - event_dim
        batch_shape = forward_shape[:cut]
        event_shape = forward_shape[cut:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
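
    # Shape-handling sketch (hypothetical example, not from this file): with a
    # base of Normal(torch.zeros(3, 4), 1.0) (batch_shape [3, 4], empty
    # event_shape) and a shape-preserving transform whose domain.event_dim is
    # 1, e.g. IndependentTransform(ExpTransform(), 1), the code above wraps
    # the base in Independent(base, 1), so the resulting distribution has
    # batch_shape [3] and event_shape [4].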

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(TransformedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        shape = batch_shape + self.event_shape
        for t in reversed(self.transforms):
            shape = t.inverse_shape(shape)
        base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
        new.base_dist = self.base_dist.expand(base_batch_shape)
        new.transforms = self.transforms
        super(TransformedDistribution, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property(is_discrete=False)
    def support(self):
        if not self.transforms:
            return self.base_dist.support
        support = self.transforms[-1].codomain
        if len(self.event_shape) > support.event_dim:
            support = constraints.independent(
                support, len(self.event_shape) - support.event_dim
            )
        return support
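
    # Support sketch (illustrative): for the Logistic example in the class
    # docstring, the last transform is an AffineTransform with codomain
    # constraints.real, so `.support` is `real`; any extra event dims are
    # reinterpreted through constraints.independent.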

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        the base distribution, then applies `transform()` for every transform
        in the list.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from the base distribution, then applies
        `transform()` for every transform in the list.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - _sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        log_prob = log_prob + _sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )
        return log_prob
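
    # Consistency sketch (assumes Normal, LogNormal, and ExpTransform from
    # torch.distributions): exp-transforming a standard Normal reproduces the
    # analytic LogNormal density via log p(Y) = log p(X) + log |det (dX/dY)|:
    #
    #     base = Normal(0.0, 1.0)
    #     td = TransformedDistribution(base, [ExpTransform()])
    #     y = torch.tensor([0.5, 1.0, 2.0])
    #     assert torch.allclose(td.log_prob(y), LogNormal(0.0, 1.0).log_prob(y))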

    def _monotonize_cdf(self, value):
        """
        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
        monotone increasing.
        """
        sign = 1
        for transform in self.transforms:
            sign = sign * transform.sign
        if isinstance(sign, int) and sign == 1:
            return value
        return sign * (value - 0.5) + 0.5
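
    # Sign sketch (illustrative): a single decreasing transform such as
    # AffineTransform(0.0, -1.0) has sign == -1, so the fold above yields
    # sign == -1 and the return value becomes 1 - value (since
    # sign * (value - 0.5) + 0.5 == 1 - value when sign == -1), keeping
    # :meth:`cdf` monotone increasing.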

    def cdf(self, value):
        """
        Computes the cumulative distribution function by inverting the
        transform(s) and computing the cdf of the base distribution.
        """
        for transform in self.transforms[::-1]:
            value = transform.inv(value)
        if self._validate_args:
            self.base_dist._validate_sample(value)
        value = self.base_dist.cdf(value)
        value = self._monotonize_cdf(value)
        return value

    def icdf(self, value):
        """
        Computes the inverse cumulative distribution function using the
        transform(s) and the inverse cdf of the base distribution.
        """
        value = self._monotonize_cdf(value)
        value = self.base_dist.icdf(value)
        for transform in self.transforms:
            value = transform(value)
        return value
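

# Usage sketch mirroring the Logistic example from the class docstring; the
# concrete values a=2.0, b=0.5 and the local imports are illustrative
# assumptions, not part of this module.
if __name__ == "__main__":
    from torch.distributions import AffineTransform, SigmoidTransform, Uniform

    a, b = 2.0, 0.5
    base_distribution = Uniform(0, 1)
    transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
    logistic = TransformedDistribution(base_distribution, transforms)

    # Draw samples and score them under the transformed density.
    y = logistic.sample((5,))
    print(y, logistic.log_prob(y))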