import functools
import math
import numbers
import operator
import weakref
from typing import List

import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.utils import (
    _sum_rightmost,
    broadcast_all,
    lazy_property,
    tril_matrix_to_vec,
    vec_to_tril_matrix,
)
from torch.nn.functional import pad, softplus

__all__ = [
    "AbsTransform",
    "AffineTransform",
    "CatTransform",
    "ComposeTransform",
    "CorrCholeskyTransform",
    "CumulativeDistributionTransform",
    "ExpTransform",
    "IndependentTransform",
    "LowerCholeskyTransform",
    "PositiveDefiniteTransform",
    "PowerTransform",
    "ReshapeTransform",
    "SigmoidTransform",
    "SoftplusTransform",
    "TanhTransform",
    "SoftmaxTransform",
    "StackTransform",
    "StickBreakingTransform",
    "Transform",
    "identity_transform",
]


class Transform:
    """
    Abstract class for invertible transformations with computable log
    det jacobians. They are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Note that care must be taken with memoized values
    since the autograd graph may be reversed. For example while the following
    works with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    However the following will error when caching due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
        sign (int or Tensor): For bijective univariate transforms, this
            should be +1 or -1 depending on whether the transform is monotone
            increasing or decreasing.
    """

    bijective = False
    domain: constraints.Constraint
    codomain: constraints.Constraint

    def __init__(self, cache_size=0):
        self._cache_size = cache_size
        self._inv = None
        if cache_size == 0:
            pass  # default behavior
        elif cache_size == 1:
            self._cached_x_y = None, None
        else:
            raise ValueError("cache_size must be 0 or 1")
        super().__init__()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_inv"] = None
        return state

    @property
    def event_dim(self):
        if self.domain.event_dim == self.codomain.event_dim:
            return self.domain.event_dim
        raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")

    @property
    def inv(self):
        """
        Returns the inverse :class:`Transform` of this transform.
        This should satisfy ``t.inv.inv is t``.
        """
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = _InverseTransform(self)
            self._inv = weakref.ref(inv)
        return inv

    @property
    def sign(self):
        """
        Returns the sign of the determinant of the Jacobian, if applicable.
        In general this only makes sense for bijective transforms.
        """
        raise NotImplementedError

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        if type(self).__init__ is Transform.__init__:
            return type(self)(cache_size=cache_size)
        raise NotImplementedError(f"{type(self)}.with_cache is not implemented")

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # Necessary for Python2
        return not self.__eq__(other)

    def __call__(self, x):
        """
        Computes the transform `x => y`.
        """
        if self._cache_size == 0:
            return self._call(x)
        x_old, y_old = self._cached_x_y
        if x is x_old:
            return y_old
        y = self._call(x)
        self._cached_x_y = x, y
        return y

    def _inv_call(self, y):
        """
        Inverts the transform `y => x`.
        """
        if self._cache_size == 0:
            return self._inverse(y)
        x_old, y_old = self._cached_x_y
        if y is y_old:
            return x_old
        x = self._inverse(y)
        self._cached_x_y = x, y
        return x

    def _call(self, x):
        """
        Abstract method to compute forward transformation.
        """
        raise NotImplementedError

    def _inverse(self, y):
        """
        Abstract method to compute inverse transformation.
        """
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` given input and output.
        """
        raise NotImplementedError

    def __repr__(self):
        return self.__class__.__name__ + "()"

    def forward_shape(self, shape):
        """
        Infers the shape of the forward computation, given the input shape.
        Defaults to preserving shape.
        """
        return shape

    def inverse_shape(self, shape):
        """
        Infers the shapes of the inverse computation, given the output shape.
        Defaults to preserving shape.
        """
        return shape


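# Illustrative sketch (not part of the upstream module): a minimal bijective
# subclass only needs ``_call``, ``_inverse`` and ``log_abs_det_jacobian``;
# the class name ``_ShiftTransformExample`` is hypothetical and just shows the
# pattern described in the docstring above.
#
#     class _ShiftTransformExample(Transform):
#         domain = constraints.real
#         codomain = constraints.real
#         bijective = True
#         sign = +1
#
#         def _call(self, x):
#             return x + 1.0
#
#         def _inverse(self, y):
#             return y - 1.0
#
#         def log_abs_det_jacobian(self, x, y):
#             return torch.zeros_like(x)  # dy/dx == 1, so log|dy/dx| == 0
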
class _InverseTransform(Transform):
    """
    Inverts a single :class:`Transform`.
    This class is private; please instead use the ``Transform.inv`` property.
    """

    def __init__(self, transform: Transform):
        super().__init__(cache_size=transform._cache_size)
        self._inv: Transform = transform

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        assert self._inv is not None
        return self._inv.codomain

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        assert self._inv is not None
        return self._inv.domain

    @property
    def bijective(self):
        assert self._inv is not None
        return self._inv.bijective

    @property
    def sign(self):
        assert self._inv is not None
        return self._inv.sign

    @property
    def inv(self):
        return self._inv

    def with_cache(self, cache_size=1):
        assert self._inv is not None
        return self.inv.with_cache(cache_size).inv

    def __eq__(self, other):
        if not isinstance(other, _InverseTransform):
            return False
        assert self._inv is not None
        return self._inv == other._inv

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self._inv)})"

    def __call__(self, x):
        assert self._inv is not None
        return self._inv._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        assert self._inv is not None
        return -self._inv.log_abs_det_jacobian(y, x)

    def forward_shape(self, shape):
        return self._inv.inverse_shape(shape)

    def inverse_shape(self, shape):
        return self._inv.forward_shape(shape)


class ComposeTransform(Transform):
    """
    Composes multiple transforms in a chain.
    The transforms being composed are responsible for caching.

    Args:
        parts (list of :class:`Transform`): A list of transforms to compose.
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.
    """

    def __init__(self, parts: List[Transform], cache_size=0):
        if cache_size:
            parts = [part.with_cache(cache_size) for part in parts]
        super().__init__(cache_size=cache_size)
        self.parts = parts

    def __eq__(self, other):
        if not isinstance(other, ComposeTransform):
            return False
        return self.parts == other.parts

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        if not self.parts:
            return constraints.real
        domain = self.parts[0].domain
        # Adjust event_dim to be maximum among all parts.
        event_dim = self.parts[-1].codomain.event_dim
        for part in reversed(self.parts):
            event_dim += part.domain.event_dim - part.codomain.event_dim
            event_dim = max(event_dim, part.domain.event_dim)
        assert event_dim >= domain.event_dim
        if event_dim > domain.event_dim:
            domain = constraints.independent(domain, event_dim - domain.event_dim)
        return domain

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        if not self.parts:
            return constraints.real
        codomain = self.parts[-1].codomain
        # Adjust event_dim to be maximum among all parts.
        event_dim = self.parts[0].domain.event_dim
        for part in self.parts:
            event_dim += part.codomain.event_dim - part.domain.event_dim
            event_dim = max(event_dim, part.codomain.event_dim)
        assert event_dim >= codomain.event_dim
        if event_dim > codomain.event_dim:
            codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
        return codomain

    @lazy_property
    def bijective(self):
        return all(p.bijective for p in self.parts)

    @lazy_property
    def sign(self):
        sign = 1
        for p in self.parts:
            sign = sign * p.sign
        return sign

    @property
    def inv(self):
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        return inv

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return ComposeTransform(self.parts, cache_size=cache_size)

    def __call__(self, x):
        for part in self.parts:
            x = part(x)
        return x

    def log_abs_det_jacobian(self, x, y):
        if not self.parts:
            return torch.zeros_like(x)

        # Compute intermediates. This will be free if parts[:-1] are all cached.
        xs = [x]
        for part in self.parts[:-1]:
            xs.append(part(xs[-1]))
        xs.append(y)

        terms = []
        event_dim = self.domain.event_dim
        for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
            terms.append(
                _sum_rightmost(
                    part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim
                )
            )
            event_dim += part.codomain.event_dim - part.domain.event_dim
        return functools.reduce(operator.add, terms)

    def forward_shape(self, shape):
        for part in self.parts:
            shape = part.forward_shape(shape)
        return shape

    def inverse_shape(self, shape):
        for part in reversed(self.parts):
            shape = part.inverse_shape(shape)
        return shape

    def __repr__(self):
        fmt_string = self.__class__.__name__ + "(\n    "
        fmt_string += ",\n    ".join([p.__repr__() for p in self.parts])
        fmt_string += "\n)"
        return fmt_string


identity_transform = ComposeTransform([])


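# Illustrative sketch (not part of the upstream module): composing an affine
# map with an exponential; the chained log-det-Jacobian is the sum of the
# parts' terms evaluated at the intermediate values.
#
#     >>> t = ComposeTransform([AffineTransform(loc=1.0, scale=2.0), ExpTransform()])
#     >>> x = torch.randn(3)
#     >>> y = t(x)                                 # same as (1.0 + 2.0 * x).exp()
#     >>> ldj = t.log_abs_det_jacobian(x, y)
#     >>> expected = math.log(2.0) + (1.0 + 2.0 * x)  # log|d/dx exp(1 + 2x)|
#     >>> torch.allclose(ldj, expected)
#     True
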
class IndependentTransform(Transform):
    """
    Wrapper around another transform to treat
    ``reinterpreted_batch_ndims``-many extra of the rightmost dimensions as
    dependent. This has no effect on the forward or backward transforms, but
    does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions
    in :meth:`log_abs_det_jacobian`.

    Args:
        base_transform (:class:`Transform`): A base transform.
        reinterpreted_batch_ndims (int): The number of extra rightmost
            dimensions to treat as dependent.
    """

    def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.base_transform = base_transform.with_cache(cache_size)
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return IndependentTransform(
            self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size
        )

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        return constraints.independent(
            self.base_transform.domain, self.reinterpreted_batch_ndims
        )

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        return constraints.independent(
            self.base_transform.codomain, self.reinterpreted_batch_ndims
        )

    @property
    def bijective(self):
        return self.base_transform.bijective

    @property
    def sign(self):
        return self.base_transform.sign

    def _call(self, x):
        if x.dim() < self.domain.event_dim:
            raise ValueError("Too few dimensions on input")
        return self.base_transform(x)

    def _inverse(self, y):
        if y.dim() < self.codomain.event_dim:
            raise ValueError("Too few dimensions on input")
        return self.base_transform.inv(y)

    def log_abs_det_jacobian(self, x, y):
        result = self.base_transform.log_abs_det_jacobian(x, y)
        result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
        return result

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})"

    def forward_shape(self, shape):
        return self.base_transform.forward_shape(shape)

    def inverse_shape(self, shape):
        return self.base_transform.inverse_shape(shape)


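# Illustrative sketch (not part of the upstream module): wrapping ExpTransform
# reinterprets the last dimension as an event dimension, so the per-element
# log-det-Jacobian terms are summed over that dimension.
#
#     >>> t = IndependentTransform(ExpTransform(), reinterpreted_batch_ndims=1)
#     >>> x = torch.randn(2, 3)
#     >>> y = t(x)                            # same values as x.exp()
#     >>> t.log_abs_det_jacobian(x, y).shape  # last dim summed out
#     torch.Size([2])
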
class ReshapeTransform(Transform):
    """
    Unit Jacobian transform to reshape the rightmost part of a tensor.

    Note that ``in_shape`` and ``out_shape`` must have the same number of
    elements, just as for :meth:`torch.Tensor.reshape`.

    Arguments:
        in_shape (torch.Size): The input event shape.
        out_shape (torch.Size): The output event shape.
    """

    bijective = True

    def __init__(self, in_shape, out_shape, cache_size=0):
        self.in_shape = torch.Size(in_shape)
        self.out_shape = torch.Size(out_shape)
        if self.in_shape.numel() != self.out_shape.numel():
            raise ValueError("in_shape, out_shape have different numbers of elements")
        super().__init__(cache_size=cache_size)

    @constraints.dependent_property
    def domain(self):
        return constraints.independent(constraints.real, len(self.in_shape))

    @constraints.dependent_property
    def codomain(self):
        return constraints.independent(constraints.real, len(self.out_shape))

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size)

    def _call(self, x):
        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
        return x.reshape(batch_shape + self.out_shape)

    def _inverse(self, y):
        batch_shape = y.shape[: y.dim() - len(self.out_shape)]
        return y.reshape(batch_shape + self.in_shape)

    def log_abs_det_jacobian(self, x, y):
        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
        return x.new_zeros(batch_shape)

    def forward_shape(self, shape):
        if len(shape) < len(self.in_shape):
            raise ValueError("Too few dimensions on input")
        cut = len(shape) - len(self.in_shape)
        if shape[cut:] != self.in_shape:
            raise ValueError(
                f"Shape mismatch: expected {self.in_shape} but got {shape[cut:]}"
            )
        return shape[:cut] + self.out_shape

    def inverse_shape(self, shape):
        if len(shape) < len(self.out_shape):
            raise ValueError("Too few dimensions on input")
        cut = len(shape) - len(self.out_shape)
        if shape[cut:] != self.out_shape:
            raise ValueError(
                f"Shape mismatch: expected {self.out_shape} but got {shape[cut:]}"
            )
        return shape[:cut] + self.in_shape


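# Illustrative sketch (not part of the upstream module): reshaping the two
# rightmost dimensions is volume preserving, so the log-det-Jacobian is zero
# per batch element.
#
#     >>> t = ReshapeTransform(in_shape=(2, 3), out_shape=(6,))
#     >>> x = torch.randn(5, 2, 3)
#     >>> t(x).shape
#     torch.Size([5, 6])
#     >>> t.log_abs_det_jacobian(x, t(x)).shape
#     torch.Size([5])
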
class ExpTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \exp(x)`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, ExpTransform)

    def _call(self, x):
        return x.exp()

    def _inverse(self, y):
        return y.log()

    def log_abs_det_jacobian(self, x, y):
        return x


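# Illustrative sketch (not part of the upstream module): for y = exp(x) the
# derivative is exp(x) itself, so log|dy/dx| reduces to x.
#
#     >>> t = ExpTransform()
#     >>> x = torch.randn(4)
#     >>> y = t(x)
#     >>> torch.allclose(t.inv(y), x)
#     True
#     >>> torch.allclose(t.log_abs_det_jacobian(x, y), x)
#     True
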
class PowerTransform(Transform):
    r"""
    Transform via the mapping :math:`y = x^{\text{exponent}}`.
    """
    domain = constraints.positive
    codomain = constraints.positive
    bijective = True

    def __init__(self, exponent, cache_size=0):
        super().__init__(cache_size=cache_size)
        (self.exponent,) = broadcast_all(exponent)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return PowerTransform(self.exponent, cache_size=cache_size)

    @lazy_property
    def sign(self):
        return self.exponent.sign()

    def __eq__(self, other):
        if not isinstance(other, PowerTransform):
            return False
        return self.exponent.eq(other.exponent).all().item()

    def _call(self, x):
        return x.pow(self.exponent)

    def _inverse(self, y):
        return y.pow(1 / self.exponent)

    def log_abs_det_jacobian(self, x, y):
        return (self.exponent * y / x).abs().log()

    def forward_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))


def _clipped_sigmoid(x):
    finfo = torch.finfo(x.dtype)
    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps)


class SigmoidTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
    """
    domain = constraints.real
    codomain = constraints.unit_interval
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SigmoidTransform)

    def _call(self, x):
        return _clipped_sigmoid(x)

    def _inverse(self, y):
        finfo = torch.finfo(y.dtype)
        y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
        return y.log() - (-y).log1p()

    def log_abs_det_jacobian(self, x, y):
        return -F.softplus(-x) - F.softplus(x)


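# Illustrative sketch (not part of the upstream module): for the logistic
# sigmoid, dy/dx = sigmoid(x) * (1 - sigmoid(x)), whose log equals
# -softplus(-x) - softplus(x).
#
#     >>> t = SigmoidTransform()
#     >>> x = torch.randn(4)
#     >>> y = t(x)
#     >>> torch.allclose(t.inv(y), x, atol=1e-6)
#     True
#     >>> torch.allclose(t.log_abs_det_jacobian(x, y), (y * (1 - y)).log(), atol=1e-6)
#     True
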
class SoftplusTransform(Transform):
    r"""
    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
    The implementation reverts to the linear function when :math:`x > 20`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SoftplusTransform)

    def _call(self, x):
        return softplus(x)

    def _inverse(self, y):
        return (-y).expm1().neg().log() + y

    def log_abs_det_jacobian(self, x, y):
        return -softplus(-x)


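# Illustrative sketch (not part of the upstream module): the inverse used above
# follows from y = log(1 + e^x)  =>  x = log(e^y - 1) = log(-expm1(-y)) + y,
# a form that stays numerically stable for large y.
#
#     >>> t = SoftplusTransform()
#     >>> x = torch.randn(4)
#     >>> torch.allclose(t.inv(t(x)), x, atol=1e-6)
#     True
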
class TanhTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \tanh(x)`.

    It is equivalent to::

        ComposeTransform(
            [AffineTransform(0.0, 2.0), SigmoidTransform(), AffineTransform(-1.0, 2.0)]
        )

    However this might not be numerically stable, thus it is recommended to use
    `TanhTransform` instead.

    Note that one should use `cache_size=1` when `NaN`/`Inf` values can arise,
    e.g. when inverting values that have saturated to the boundary.
    """
    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
        # One should use `cache_size=1` instead.
        return torch.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable, see details in the following link
        # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
        return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


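# Illustrative sketch (not part of the upstream module): the stable formula
# above equals log(1 - tanh(x)^2), i.e. the log-derivative of tanh.
#
#     >>> t = TanhTransform()
#     >>> x = torch.randn(4)
#     >>> y = t(x)
#     >>> torch.allclose(t.log_abs_det_jacobian(x, y), torch.log1p(-y**2), atol=1e-6)
#     True
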
class AbsTransform(Transform):
    r"""
    Transform via the mapping :math:`y = |x|`.
    """
    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        return isinstance(other, AbsTransform)

    def _call(self, x):
        return x.abs()

    def _inverse(self, y):
        return y


class AffineTransform(Transform):
    r"""
    Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.

    Args:
        loc (Tensor or float): Location parameter.
        scale (Tensor or float): Scale parameter.
        event_dim (int): Optional size of `event_shape`. This should be zero
            for univariate random variables, 1 for distributions over vectors,
            2 for distributions over matrices, etc.
    """
    bijective = True

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.loc = loc
        self.scale = scale
        self._event_dim = event_dim

    @property
    def event_dim(self):
        return self._event_dim

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return AffineTransform(
            self.loc, self.scale, self.event_dim, cache_size=cache_size
        )

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False

        if isinstance(self.loc, numbers.Number) and isinstance(
            other.loc, numbers.Number
        ):
            if self.loc != other.loc:
                return False
        else:
            if not (self.loc == other.loc).all().item():
                return False

        if isinstance(self.scale, numbers.Number) and isinstance(
            other.scale, numbers.Number
        ):
            if self.scale != other.scale:
                return False
        else:
            if not (self.scale == other.scale).all().item():
                return False

        return True

    @property
    def sign(self):
        if isinstance(self.scale, numbers.Real):
            return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
        return self.scale.sign()

    def _call(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        shape = x.shape
        scale = self.scale
        if isinstance(scale, numbers.Real):
            result = torch.full_like(x, math.log(abs(scale)))
        else:
            result = torch.abs(scale).log()
        if self.event_dim:
            result_size = result.size()[: -self.event_dim] + (-1,)
            result = result.view(result_size).sum(-1)
            shape = shape[: -self.event_dim]
        return result.expand(shape)

    def forward_shape(self, shape):
        return torch.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )


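# Illustrative sketch (not part of the upstream module): with event_dim=1 the
# per-coordinate log|scale| terms are summed over the last dimension, matching
# the log-determinant of the diagonal Jacobian of an elementwise affine map.
#
#     >>> scale = torch.tensor([1.0, 2.0, 4.0])
#     >>> t = AffineTransform(loc=torch.zeros(3), scale=scale, event_dim=1)
#     >>> x = torch.randn(5, 3)
#     >>> ldj = t.log_abs_det_jacobian(x, t(x))
#     >>> ldj.shape
#     torch.Size([5])
#     >>> torch.allclose(ldj, torch.full((5,), math.log(8.0)))  # log(1 * 2 * 4)
#     True
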
class CorrCholeskyTransform(Transform):
    r"""
    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
    Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower
    triangular matrix with positive diagonals and unit Euclidean norm for each row.
    The transform is processed as follows:

        1. First we convert x into a lower triangular matrix in row order.
        2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
           :class:`StickBreakingTransform` to transform :math:`X_i` into a
           unit Euclidean length vector using the following steps:

           - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
           - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
           - Applies :math:`s_i = StickBreakingTransform(z_i)`.
           - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`.
    """
    domain = constraints.real_vector
    codomain = constraints.corr_cholesky
    bijective = True

    def _call(self, x):
        x = torch.tanh(x)
        eps = torch.finfo(x.dtype).eps
        x = x.clamp(min=-1 + eps, max=1 - eps)
        r = vec_to_tril_matrix(x, diag=-1)
        # apply stick-breaking on the squared values
        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
        #             = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
        z = r**2
        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
        # Diagonal elements must be 1.
        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
        return y

    def _inverse(self, y):
        # inverse stick-breaking
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
        y_vec = tril_matrix_to_vec(y, diag=-1)
        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
        t = y_vec / (y_cumsum_vec).sqrt()
        # inverse of tanh
        x = (t.log1p() - t.neg().log1p()) / 2
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Because domain and codomain are two spaces with different dimensions, determinant of
        # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
        # flattened lower triangular part of `y`.

        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # also works for 2 x 2 matrix
        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1)
        return stick_breaking_logdet + tanh_logdet

    def forward_shape(self, shape):
        # Reshape from (..., N) to (..., D, D).
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        N = shape[-1]
        D = round((0.25 + 2 * N) ** 0.5 + 0.5)
        if D * (D - 1) // 2 != N:
            raise ValueError("Input is not a flattened lower-diagonal number")
        return shape[:-1] + (D, D)

    def inverse_shape(self, shape):
        # Reshape from (..., D, D) to (..., N).
        if len(shape) < 2:
            raise ValueError("Too few dimensions on input")
        if shape[-2] != shape[-1]:
            raise ValueError("Input is not square")
        D = shape[-1]
        N = D * (D - 1) // 2
        return shape[:-2] + (N,)


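# Illustrative sketch (not part of the upstream module): a vector of length
# D*(D-1)/2 maps to a D x D lower-triangular factor whose rows have unit norm,
# so L @ L.T is a valid correlation matrix; the round trip recovers the input
# up to floating-point error.
#
#     >>> t = CorrCholeskyTransform()
#     >>> x = torch.randn(3)                  # D = 3, so 3 * 2 / 2 = 3 inputs
#     >>> L = t(x)
#     >>> torch.allclose(L.pow(2).sum(-1), torch.ones(3), atol=1e-6)
#     True
#     >>> torch.allclose(t.inv(L), x, atol=1e-5)
#     True
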
class SoftmaxTransform(Transform):
    r"""
    Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
    normalizing.

    This is not bijective and cannot be used for HMC. However this acts mostly
    coordinate-wise (except for the final normalization), and thus is
    appropriate for coordinate-wise optimization algorithms.
    """
    domain = constraints.real_vector
    codomain = constraints.simplex

    def __eq__(self, other):
        return isinstance(other, SoftmaxTransform)

    def _call(self, x):
        logprobs = x
        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
        return probs / probs.sum(-1, True)

    def _inverse(self, y):
        probs = y
        return probs.log()

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape


class StickBreakingTransform(Transform):
    """
    Transform from unconstrained space to the simplex of one additional
    dimension via a stick-breaking process.

    This transform arises as an iterated sigmoid transform in a stick-breaking
    construction of the `Dirichlet` distribution: the first logit is
    transformed via sigmoid to the first probability and the probability of
    everything else, and then the process recurses.

    This is bijective and appropriate for use in HMC; however it mixes
    coordinates together and is less appropriate for optimization.
    """

    domain = constraints.real_vector
    codomain = constraints.simplex
    bijective = True

    def __eq__(self, other):
        return isinstance(other, StickBreakingTransform)

    def _call(self, x):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        z = _clipped_sigmoid(x - offset.log())
        z_cumprod = (1 - z).cumprod(-1)
        y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
        return y

    def _inverse(self, y):
        y_crop = y[..., :-1]
        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
        sf = 1 - y_crop.cumsum(-1)
        # we clamp to make sure that sf is positive which sometimes does not
        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
        x = y_crop.log() - sf.log() + offset.log()
        return x

    def log_abs_det_jacobian(self, x, y):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        x = x - offset.log()
        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
        return detJ

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] + 1,)

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] - 1,)


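# Illustrative sketch (not part of the upstream module): a length-K vector of
# unconstrained values maps to a point on the simplex with K + 1 coordinates.
#
#     >>> t = StickBreakingTransform()
#     >>> x = torch.randn(4)
#     >>> y = t(x)
#     >>> y.shape
#     torch.Size([5])
#     >>> torch.allclose(y.sum(), torch.tensor(1.0))
#     True
#     >>> torch.allclose(t.inv(y), x, atol=1e-5)
#     True
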
class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries.

    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
    """

    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.lower_cholesky

    def __eq__(self, other):
        return isinstance(other, LowerCholeskyTransform)

    def _call(self, x):
        return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()

    def _inverse(self, y):
        return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()


class PositiveDefiniteTransform(Transform):
    """
    Transform from unconstrained matrices to positive-definite matrices.
    """

    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.positive_definite  # type: ignore[assignment]

    def __eq__(self, other):
        return isinstance(other, PositiveDefiniteTransform)

    def _call(self, x):
        x = LowerCholeskyTransform()(x)
        return x @ x.mT

    def _inverse(self, y):
        y = torch.linalg.cholesky(y)
        return LowerCholeskyTransform().inv(y)


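# Illustrative sketch (not part of the upstream module): an arbitrary square
# matrix maps to a positive-definite matrix via L @ L.T of its
# LowerCholeskyTransform image, and the round trip recovers the input up to
# floating-point error.
#
#     >>> t = PositiveDefiniteTransform()
#     >>> x = torch.randn(3, 3)
#     >>> y = t(x)
#     >>> bool((torch.linalg.eigvalsh(y) > 0).all())
#     True
#     >>> torch.allclose(t.inv(y), x, atol=1e-5)
#     True
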
class CatTransform(Transform):
    """
    Transform functor that applies a sequence of transforms `tseq`
    component-wise to each submatrix at `dim`, of length `lengths[dim]`,
    in a way compatible with :func:`torch.cat`.

    Example::

        x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)
        x = torch.cat([x0, x0], dim=0)
        t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
        t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
        y = t(x)
    """

    transforms: List[Transform]

    def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
        assert all(isinstance(t, Transform) for t in tseq)
        if cache_size:
            tseq = [t.with_cache(cache_size) for t in tseq]
        super().__init__(cache_size=cache_size)
        self.transforms = list(tseq)
        if lengths is None:
            lengths = [1] * len(self.transforms)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.transforms)
        self.dim = dim

    @lazy_property
    def event_dim(self):
        return max(t.event_dim for t in self.transforms)

    @lazy_property
    def length(self):
        return sum(self.lengths)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return CatTransform(self.transforms, self.dim, self.lengths, cache_size)

    def _call(self, x):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == self.length
        yslices = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            xslice = x.narrow(self.dim, start, length)
            yslices.append(trans(xslice))
            start = start + length  # avoid += for jit compat
        return torch.cat(yslices, dim=self.dim)

    def _inverse(self, y):
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == self.length
        xslices = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            yslice = y.narrow(self.dim, start, length)
            xslices.append(trans.inv(yslice))
            start = start + length  # avoid += for jit compat
        return torch.cat(xslices, dim=self.dim)

    def log_abs_det_jacobian(self, x, y):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == self.length
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == self.length
        logdetjacs = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            xslice = x.narrow(self.dim, start, length)
            yslice = y.narrow(self.dim, start, length)
            logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
            if trans.event_dim < self.event_dim:
                logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim)
            logdetjacs.append(logdetjac)
            start = start + length  # avoid += for jit compat
        # Decide whether to concatenate or sum.
        dim = self.dim
        if dim >= 0:
            dim = dim - x.dim()
        dim = dim + self.event_dim
        if dim < 0:
            return torch.cat(logdetjacs, dim=dim)
        else:
            return sum(logdetjacs)

    @property
    def bijective(self):
        return all(t.bijective for t in self.transforms)

    @constraints.dependent_property
    def domain(self):
        return constraints.cat(
            [t.domain for t in self.transforms], self.dim, self.lengths
        )

    @constraints.dependent_property
    def codomain(self):
        return constraints.cat(
            [t.codomain for t in self.transforms], self.dim, self.lengths
        )


class StackTransform(Transform):
    """
    Transform functor that applies a sequence of transforms `tseq`
    component-wise to each submatrix at `dim`
    in a way compatible with :func:`torch.stack`.

    Example::

        x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)
        t = StackTransform([ExpTransform(), identity_transform], dim=1)
        y = t(x)
    """

    transforms: List[Transform]

    def __init__(self, tseq, dim=0, cache_size=0):
        assert all(isinstance(t, Transform) for t in tseq)
        if cache_size:
            tseq = [t.with_cache(cache_size) for t in tseq]
        super().__init__(cache_size=cache_size)
        self.transforms = list(tseq)
        self.dim = dim

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return StackTransform(self.transforms, self.dim, cache_size)

    def _slice(self, z):
        return [z.select(self.dim, i) for i in range(z.size(self.dim))]

    def _call(self, x):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == len(self.transforms)
        yslices = []
        for xslice, trans in zip(self._slice(x), self.transforms):
            yslices.append(trans(xslice))
        return torch.stack(yslices, dim=self.dim)

    def _inverse(self, y):
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == len(self.transforms)
        xslices = []
        for yslice, trans in zip(self._slice(y), self.transforms):
            xslices.append(trans.inv(yslice))
        return torch.stack(xslices, dim=self.dim)

    def log_abs_det_jacobian(self, x, y):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == len(self.transforms)
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == len(self.transforms)
        logdetjacs = []
        yslices = self._slice(y)
        xslices = self._slice(x)
        for xslice, yslice, trans in zip(xslices, yslices, self.transforms):
            logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice))
        return torch.stack(logdetjacs, dim=self.dim)

    @property
    def bijective(self):
        return all(t.bijective for t in self.transforms)

    @constraints.dependent_property
    def domain(self):
        return constraints.stack([t.domain for t in self.transforms], self.dim)

    @constraints.dependent_property
    def codomain(self):
        return constraints.stack([t.codomain for t in self.transforms], self.dim)


class CumulativeDistributionTransform(Transform):
    """
    Transform via the cumulative distribution function of a probability distribution.

    Args:
        distribution (Distribution): Distribution whose cumulative distribution function to use for
            the transformation.

    Example::

        # Construct a Gaussian copula from a multivariate normal.
        base_dist = MultivariateNormal(
            loc=torch.zeros(2),
            scale_tril=LKJCholesky(2).sample(),
        )
        transform = CumulativeDistributionTransform(Normal(0, 1))
        copula = TransformedDistribution(base_dist, [transform])
    """

    bijective = True
    codomain = constraints.unit_interval
    sign = +1

    def __init__(self, distribution, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.distribution = distribution

    @property
    def domain(self):
        return self.distribution.support

    def _call(self, x):
        return self.distribution.cdf(x)

    def _inverse(self, y):
        return self.distribution.icdf(y)

    def log_abs_det_jacobian(self, x, y):
        return self.distribution.log_prob(x)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return CumulativeDistributionTransform(self.distribution, cache_size=cache_size)