pytorch
294 строки · 10.1 Кб
1# mypy: allow-untyped-defs
2r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. Both objects take
constraints as input and return transforms, but they have different guarantees
on bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform, but it is needed for
    algorithms like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method either as a
function on singleton constraints::

    transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
67"""
68
import numbers

from torch.distributions import constraints, transforms


# Public API of this module; controls what ``from ... import *`` exports.
__all__ = [
    "ConstraintRegistry",
    "biject_to",
    "transform_to",
]
79
80
class ConstraintRegistry:
    """
    Registry to link constraints to transforms.

    Factories are stored in a dict keyed by
    :class:`~torch.distributions.constraints.Constraint` subclass. Calling the
    registry with a constraint object looks up the factory stored under
    ``type(constraint)`` and returns ``factory(constraint)``.
    """

    def __init__(self):
        # Maps Constraint subclass -> factory callable.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (Callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object.

        Returns:
            ``factory`` itself when it is given, so the decorated function keeps
            its name; otherwise a one-argument decorator that performs the
            registration.

        Raises:
            TypeError: if ``constraint`` is neither a ``Constraint`` subclass
                nor a ``Constraint`` instance.
        """
        # Support use as decorator.
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Support calling on singleton instances.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        if not isinstance(constraint, type) or not issubclass(
            constraint, constraints.Constraint
        ):
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
            )

        # A later registration for the same class replaces the earlier one.
        self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object.

        Raises:
            `NotImplementedError` if no transform has been registered.
        """
        # Look up by Constraint subclass. The key is the exact type of the
        # object, so a subclass of a registered class is not matched unless it
        # has been registered itself.
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from None
        return factory(constraint)
153
# The two global registries. Each starts empty; entries are added through
# ``.register()``.
biject_to = ConstraintRegistry()
transform_to = ConstraintRegistry()
157
################################################################################
# Registration Table
################################################################################
161
162
@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """Return the shared identity transform; ``constraint`` itself is unused."""
    return transforms.identity_transform
168
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """
    Wrap the ``biject_to`` transform of ``constraint.base_constraint`` in an
    :class:`~torch.distributions.transforms.IndependentTransform` with the same
    ``reinterpreted_batch_ndims`` as the constraint.
    """
    base_transform = biject_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )
176
@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """
    Wrap the ``transform_to`` transform of ``constraint.base_constraint`` in an
    :class:`~torch.distributions.transforms.IndependentTransform` with the same
    ``reinterpreted_batch_ndims`` as the constraint.
    """
    base_transform = transform_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )
184
@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """
    Return a fresh :class:`~torch.distributions.transforms.ExpTransform`.

    The same factory serves both registries and both the ``positive`` and
    ``nonnegative`` constraints; ``constraint`` itself is unused.
    """
    return transforms.ExpTransform()
192
@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """
    Compose an exponential with a shift by ``constraint.lower_bound``.

    The result applies ``ExpTransform`` first and then
    ``AffineTransform(loc=lower_bound, scale=1)``, i.e. ``lower_bound + exp(x)``.
    """
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.lower_bound, 1),
        ]
    )
205
@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """
    Compose an exponential with a reflection about ``constraint.upper_bound``.

    The result applies ``ExpTransform`` first and then
    ``AffineTransform(loc=upper_bound, scale=-1)``, i.e. ``upper_bound - exp(x)``.
    """
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.upper_bound, -1),
        ]
    )
216
@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """
    Return a sigmoid-based transform onto the constraint's interval.

    When both bounds are plain Python numbers equal to 0 and 1, a bare
    ``SigmoidTransform`` is returned. Otherwise the sigmoid is followed by an
    ``AffineTransform`` with ``loc = lower_bound`` and
    ``scale = upper_bound - lower_bound``.
    """
    # Handle the special case of the unit interval. The isinstance checks mean
    # the shortcut is only taken for numbers.Number bounds, never for other
    # bound types such as tensors.
    lower_is_0 = (
        isinstance(constraint.lower_bound, numbers.Number)
        and constraint.lower_bound == 0
    )
    upper_is_1 = (
        isinstance(constraint.upper_bound, numbers.Number)
        and constraint.upper_bound == 1
    )
    if lower_is_0 and upper_is_1:
        return transforms.SigmoidTransform()

    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform(
        [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
    )
240
@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """
    Return a fresh
    :class:`~torch.distributions.transforms.StickBreakingTransform`;
    ``constraint`` itself is unused.
    """
    return transforms.StickBreakingTransform()
245
@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """
    Return a fresh :class:`~torch.distributions.transforms.SoftmaxTransform`;
    ``constraint`` itself is unused.
    """
    return transforms.SoftmaxTransform()
250
# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """
    Return a fresh
    :class:`~torch.distributions.transforms.LowerCholeskyTransform`;
    ``constraint`` itself is unused. Only ``transform_to`` is registered here.
    """
    return transforms.LowerCholeskyTransform()
256
@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """
    Return a fresh
    :class:`~torch.distributions.transforms.PositiveDefiniteTransform`.

    The same factory is registered in ``transform_to`` for both the
    ``positive_definite`` and ``positive_semidefinite`` constraints;
    ``constraint`` itself is unused.
    """
    return transforms.PositiveDefiniteTransform()
262
@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """
    Return a fresh
    :class:`~torch.distributions.transforms.CorrCholeskyTransform`.

    Registered in both ``biject_to`` and ``transform_to``; ``constraint``
    itself is unused.
    """
    return transforms.CorrCholeskyTransform()
268
@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """
    Build a :class:`~torch.distributions.transforms.CatTransform` from the
    ``biject_to`` transform of every constraint in ``constraint.cseq``,
    forwarding the constraint's ``dim`` and ``lengths``.
    """
    return transforms.CatTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )
275
@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """
    Build a :class:`~torch.distributions.transforms.CatTransform` from the
    ``transform_to`` transform of every constraint in ``constraint.cseq``,
    forwarding the constraint's ``dim`` and ``lengths``.
    """
    return transforms.CatTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )
282
@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """
    Build a :class:`~torch.distributions.transforms.StackTransform` from the
    ``biject_to`` transform of every constraint in ``constraint.cseq``,
    forwarding the constraint's ``dim``.
    """
    return transforms.StackTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim
    )
289
@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """
    Build a :class:`~torch.distributions.transforms.StackTransform` from the
    ``transform_to`` transform of every constraint in ``constraint.cseq``,
    forwarding the constraint's ``dim``.
    """
    return transforms.StackTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim
    )