pytorch

Форк
0
/
constraint_registry.py 
294 строки · 10.1 Кб
# mypy: allow-untyped-defs
r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. Both objects take
constraints as input and return transforms, but they offer different
guarantees on bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform but is needed for algorithms
    like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method, either as a
function on singleton constraints (the second argument is a factory that takes
the constraint and returns a transform)::

    transform_to.register(my_constraint, my_factory)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
"""

import numbers

from torch.distributions import constraints, transforms


# Public API: the registry class and its two global instances.
__all__ = ["ConstraintRegistry", "biject_to", "transform_to"]


class ConstraintRegistry:
    """
    Registry that links :class:`~torch.distributions.constraints.Constraint`
    classes to factories producing
    :class:`~torch.distributions.transforms.Transform` objects.

    Factories are stored per constraint *class*; looking up a constraint
    instance uses its exact ``type()``.
    """

    def __init__(self):
        # Constraint subclass -> factory(constraint) returning a Transform.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.arg_constraints)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (Callable): A callable that takes a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object.

        Returns:
            ``factory`` unchanged. If ``factory`` is omitted, returns a decorator
            that registers the function it decorates and returns that function.

        Raises:
            TypeError: If ``constraint`` is neither a
                :class:`~torch.distributions.constraints.Constraint` subclass nor
                an instance of one.
        """
        if factory is None:
            # Decorator form: ``@registry.register(SomeConstraint)``.
            return lambda factory: self.register(constraint, factory)

        # A singleton instance (e.g. ``constraints.real``) is keyed by its class.
        if isinstance(constraint, constraints.Constraint):
            key = type(constraint)
        else:
            key = constraint

        is_constraint_class = isinstance(key, type) and issubclass(
            key, constraints.Constraint
        )
        if not is_constraint_class:
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {key}"
            )

        self._registry[key] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object, built by
            the factory registered for ``type(constraint)``.

        Raises:
            NotImplementedError: If no factory is registered for the exact class
                of ``constraint``; registrations are not inherited by subclasses.
        """
        constraint_cls = type(constraint)
        if constraint_cls not in self._registry:
            raise NotImplementedError(
                f"Cannot transform {constraint_cls.__name__} constraints"
            )
        return self._registry[constraint_cls](constraint)


# The two global registries: ``biject_to`` is meant for bijective transforms,
# ``transform_to`` for transforms that need not be bijective.
biject_to, transform_to = ConstraintRegistry(), ConstraintRegistry()


################################################################################
# Registration Table
################################################################################


@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """Return the shared identity transform; ``constraints.real`` needs no mapping.

    Registered in both registries. The ``constraint`` argument is unused.
    """
    identity = transforms.identity_transform
    return identity


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """Bijection for an independent constraint.

    Recursively looks up the bijection for ``constraint.base_constraint`` in
    ``biject_to`` and wraps it in an
    :class:`~torch.distributions.transforms.IndependentTransform` using the
    constraint's ``reinterpreted_batch_ndims``.
    """
    inner = biject_to(constraint.base_constraint)
    ndims = constraint.reinterpreted_batch_ndims
    return transforms.IndependentTransform(inner, ndims)


@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """Transform for an independent constraint.

    Recursively looks up the transform for ``constraint.base_constraint`` in
    ``transform_to`` and wraps it in an
    :class:`~torch.distributions.transforms.IndependentTransform` using the
    constraint's ``reinterpreted_batch_ndims``.
    """
    inner = transform_to(constraint.base_constraint)
    ndims = constraint.reinterpreted_batch_ndims
    return transforms.IndependentTransform(inner, ndims)


@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """Map reals onto positive values with ``exp``; ``constraint`` is unused.

    NOTE(review): the registries are keyed by ``type(constraint)``. If
    ``constraints.positive`` / ``constraints.nonnegative`` are instances of the
    ``greater_than`` / ``greater_than_eq`` classes, the later registration for
    those classes replaces this factory. Confirm which transform is actually
    returned before relying on it.
    """
    exp_transform = transforms.ExpTransform()
    return exp_transform


@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """Map reals above ``constraint.lower_bound`` via ``x -> lower_bound + exp(x)``.

    Built as ``ExpTransform`` followed by ``AffineTransform(lower_bound, 1)``.
    """
    steps = [transforms.ExpTransform()]
    steps.append(transforms.AffineTransform(constraint.lower_bound, 1))
    return transforms.ComposeTransform(steps)


@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """Map reals below ``constraint.upper_bound`` via ``x -> upper_bound - exp(x)``.

    Built as ``ExpTransform`` followed by ``AffineTransform(upper_bound, -1)``.
    """
    steps = [transforms.ExpTransform()]
    steps.append(transforms.AffineTransform(constraint.upper_bound, -1))
    return transforms.ComposeTransform(steps)


@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """Map reals into ``(lower_bound, upper_bound)`` via a scaled sigmoid.

    Computes ``lower_bound + (upper_bound - lower_bound) * sigmoid(x)``; when
    both bounds are the plain numbers 0 and 1, a bare ``SigmoidTransform`` is
    returned instead.
    """
    low, high = constraint.lower_bound, constraint.upper_bound

    # Unit-interval shortcut. Only ``numbers.Number`` bounds are compared; any
    # other bound type (e.g. a tensor) always takes the general path below.
    low_is_zero = isinstance(low, numbers.Number) and low == 0
    high_is_one = isinstance(high, numbers.Number) and high == 1
    if low_is_zero and high_is_one:
        return transforms.SigmoidTransform()

    squash = transforms.SigmoidTransform()
    stretch = transforms.AffineTransform(low, high - low)
    return transforms.ComposeTransform([squash, stretch])


@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """Bijection onto the simplex via stick-breaking; ``constraint`` is unused."""
    stick_breaking = transforms.StickBreakingTransform()
    return stick_breaking


@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """Non-bijective map onto the simplex via softmax; ``constraint`` is unused."""
    softmax = transforms.SoftmaxTransform()
    return softmax


# TODO: define a bijection for LowerCholeskyTransform (only transform_to is
# registered for lower_cholesky).
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """Transform onto lower-Cholesky factors; ``constraint`` is unused."""
    lower_cholesky = transforms.LowerCholeskyTransform()
    return lower_cholesky


@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """Transform onto positive-(semi)definite matrices; ``constraint`` is unused.

    Only registered in ``transform_to``.
    """
    pd_transform = transforms.PositiveDefiniteTransform()
    return pd_transform


@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """Map onto Cholesky factors of correlation matrices; ``constraint`` is unused.

    Registered in both registries.
    """
    corr_cholesky = transforms.CorrCholeskyTransform()
    return corr_cholesky


@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """Bijection for a concatenation of constraints.

    Looks up ``biject_to`` for each sub-constraint in ``constraint.cseq`` and
    combines them in a :class:`~torch.distributions.transforms.CatTransform`
    with the constraint's ``dim`` and ``lengths``.
    """
    parts = list(map(biject_to, constraint.cseq))
    return transforms.CatTransform(parts, constraint.dim, constraint.lengths)


@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """Transform for a concatenation of constraints.

    Looks up ``transform_to`` for each sub-constraint in ``constraint.cseq``
    and combines them in a :class:`~torch.distributions.transforms.CatTransform`
    with the constraint's ``dim`` and ``lengths``.
    """
    parts = list(map(transform_to, constraint.cseq))
    return transforms.CatTransform(parts, constraint.dim, constraint.lengths)


@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """Bijection for a stack of constraints.

    Looks up ``biject_to`` for each sub-constraint in ``constraint.cseq`` and
    combines them in a :class:`~torch.distributions.transforms.StackTransform`
    along ``constraint.dim``.
    """
    parts = list(map(biject_to, constraint.cseq))
    return transforms.StackTransform(parts, constraint.dim)


@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """Transform for a stack of constraints.

    Looks up ``transform_to`` for each sub-constraint in ``constraint.cseq``
    and combines them in a :class:`~torch.distributions.transforms.StackTransform`
    along ``constraint.dim``.
    """
    parts = list(map(transform_to, constraint.cseq))
    return transforms.StackTransform(parts, constraint.dim)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.