# transforms.py

import functools
import math
import numbers
import operator
import weakref
from typing import List

import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.utils import (
    _sum_rightmost,
    broadcast_all,
    lazy_property,
    tril_matrix_to_vec,
    vec_to_tril_matrix,
)
from torch.nn.functional import pad, softplus

__all__ = [
    "AbsTransform",
    "AffineTransform",
    "CatTransform",
    "ComposeTransform",
    "CorrCholeskyTransform",
    "CumulativeDistributionTransform",
    "ExpTransform",
    "IndependentTransform",
    "LowerCholeskyTransform",
    "PositiveDefiniteTransform",
    "PowerTransform",
    "ReshapeTransform",
    "SigmoidTransform",
    "SoftplusTransform",
    "TanhTransform",
    "SoftmaxTransform",
    "StackTransform",
    "StickBreakingTransform",
    "Transform",
    "identity_transform",
]


class Transform:
    """
    Abstract class for invertible transformations with computable log
    det jacobians. They are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Note that care must be taken with memoized values
    since the autograd graph may be reversed. For example while the following
    works with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    However the following will error when caching due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
        sign (int or Tensor): For bijective univariate transforms, this
            should be +1 or -1 depending on whether transform is monotone
            increasing or decreasing.
    """

    bijective = False
    domain: constraints.Constraint
    codomain: constraints.Constraint

    def __init__(self, cache_size=0):
        self._cache_size = cache_size
        self._inv = None
        if cache_size == 0:
            pass  # default behavior
        elif cache_size == 1:
            self._cached_x_y = None, None
        else:
            raise ValueError("cache_size must be 0 or 1")
        super().__init__()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_inv"] = None
        return state

    @property
    def event_dim(self):
        if self.domain.event_dim == self.codomain.event_dim:
            return self.domain.event_dim
        raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")

    @property
    def inv(self):
        """
        Returns the inverse :class:`Transform` of this transform.
        This should satisfy ``t.inv.inv is t``.
        """
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = _InverseTransform(self)
            self._inv = weakref.ref(inv)
        return inv

    @property
    def sign(self):
        """
        Returns the sign of the determinant of the Jacobian, if applicable.
        In general this only makes sense for bijective transforms.
        """
        raise NotImplementedError

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        if type(self).__init__ is Transform.__init__:
            return type(self)(cache_size=cache_size)
        raise NotImplementedError(f"{type(self)}.with_cache is not implemented")

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # Necessary for Python2
        return not self.__eq__(other)

    def __call__(self, x):
        """
        Computes the transform `x => y`.
        """
        if self._cache_size == 0:
            return self._call(x)
        x_old, y_old = self._cached_x_y
        if x is x_old:
            return y_old
        y = self._call(x)
        self._cached_x_y = x, y
        return y

    def _inv_call(self, y):
        """
        Inverts the transform `y => x`.
        """
        if self._cache_size == 0:
            return self._inverse(y)
        x_old, y_old = self._cached_x_y
        if y is y_old:
            return x_old
        x = self._inverse(y)
        self._cached_x_y = x, y
        return x

    def _call(self, x):
        """
        Abstract method to compute forward transformation.
        """
        raise NotImplementedError

    def _inverse(self, y):
        """
        Abstract method to compute inverse transformation.
        """
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` given input and output.
        """
        raise NotImplementedError

    def __repr__(self):
        return self.__class__.__name__ + "()"

    def forward_shape(self, shape):
        """
        Infers the shape of the forward computation, given the input shape.
        Defaults to preserving shape.
        """
        return shape

    def inverse_shape(self, shape):
        """
        Infers the shapes of the inverse computation, given the output shape.
        Defaults to preserving shape.
        """
        return shape


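# Minimal usage sketch (hypothetical example, not referenced anywhere else in
# this module): a custom bijection only needs `_call` and `_inverse`, plus
# `log_abs_det_jacobian` because it declares ``bijective = True``.
class _CubeTransform(Transform):
    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def _call(self, x):
        return x**3

    def _inverse(self, y):
        return y.sign() * y.abs().pow(1.0 / 3.0)

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = 3 * x**2, so log|dy/dx| = log(3) + 2 * log|x|
        return math.log(3.0) + 2.0 * x.abs().log()

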
class _InverseTransform(Transform):
    """
    Inverts a single :class:`Transform`.
    This class is private; please instead use the ``Transform.inv`` property.
    """

    def __init__(self, transform: Transform):
        super().__init__(cache_size=transform._cache_size)
        self._inv: Transform = transform

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        assert self._inv is not None
        return self._inv.codomain

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        assert self._inv is not None
        return self._inv.domain

    @property
    def bijective(self):
        assert self._inv is not None
        return self._inv.bijective

    @property
    def sign(self):
        assert self._inv is not None
        return self._inv.sign

    @property
    def inv(self):
        return self._inv

    def with_cache(self, cache_size=1):
        assert self._inv is not None
        return self.inv.with_cache(cache_size).inv

    def __eq__(self, other):
        if not isinstance(other, _InverseTransform):
            return False
        assert self._inv is not None
        return self._inv == other._inv

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self._inv)})"

    def __call__(self, x):
        assert self._inv is not None
        return self._inv._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        assert self._inv is not None
        return -self._inv.log_abs_det_jacobian(y, x)

    def forward_shape(self, shape):
        return self._inv.inverse_shape(shape)

    def inverse_shape(self, shape):
        return self._inv.forward_shape(shape)


class ComposeTransform(Transform):
    """
    Composes multiple transforms in a chain.
    The transforms being composed are responsible for caching.

    Args:
        parts (list of :class:`Transform`): A list of transforms to compose.
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.
    """

    def __init__(self, parts: List[Transform], cache_size=0):
        if cache_size:
            parts = [part.with_cache(cache_size) for part in parts]
        super().__init__(cache_size=cache_size)
        self.parts = parts

    def __eq__(self, other):
        if not isinstance(other, ComposeTransform):
            return False
        return self.parts == other.parts

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        if not self.parts:
            return constraints.real
        domain = self.parts[0].domain
        # Adjust event_dim to be maximum among all parts.
        event_dim = self.parts[-1].codomain.event_dim
        for part in reversed(self.parts):
            event_dim += part.domain.event_dim - part.codomain.event_dim
            event_dim = max(event_dim, part.domain.event_dim)
        assert event_dim >= domain.event_dim
        if event_dim > domain.event_dim:
            domain = constraints.independent(domain, event_dim - domain.event_dim)
        return domain

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        if not self.parts:
            return constraints.real
        codomain = self.parts[-1].codomain
        # Adjust event_dim to be maximum among all parts.
        event_dim = self.parts[0].domain.event_dim
        for part in self.parts:
            event_dim += part.codomain.event_dim - part.domain.event_dim
            event_dim = max(event_dim, part.codomain.event_dim)
        assert event_dim >= codomain.event_dim
        if event_dim > codomain.event_dim:
            codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
        return codomain

    @lazy_property
    def bijective(self):
        return all(p.bijective for p in self.parts)

    @lazy_property
    def sign(self):
        sign = 1
        for p in self.parts:
            sign = sign * p.sign
        return sign

    @property
    def inv(self):
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        return inv

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return ComposeTransform(self.parts, cache_size=cache_size)

    def __call__(self, x):
        for part in self.parts:
            x = part(x)
        return x

    def log_abs_det_jacobian(self, x, y):
        if not self.parts:
            return torch.zeros_like(x)

        # Compute intermediates. This will be free if parts[:-1] are all cached.
        xs = [x]
        for part in self.parts[:-1]:
            xs.append(part(xs[-1]))
        xs.append(y)

        terms = []
        event_dim = self.domain.event_dim
        for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
            terms.append(
                _sum_rightmost(
                    part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim
                )
            )
            event_dim += part.codomain.event_dim - part.domain.event_dim
        return functools.reduce(operator.add, terms)

    def forward_shape(self, shape):
        for part in self.parts:
            shape = part.forward_shape(shape)
        return shape

    def inverse_shape(self, shape):
        for part in reversed(self.parts):
            shape = part.inverse_shape(shape)
        return shape

    def __repr__(self):
        fmt_string = self.__class__.__name__ + "(\n    "
        fmt_string += ",\n    ".join([p.__repr__() for p in self.parts])
        fmt_string += "\n)"
        return fmt_string


identity_transform = ComposeTransform([])


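# Usage sketch (hypothetical values): composing ExpTransform with an affine map
# gives the bijection x -> 1 + 2 * exp(x); ComposeTransform chains the parts and
# accumulates their log-det-Jacobian terms (here x + log(2) elementwise).
def _compose_example():
    t = ComposeTransform([ExpTransform(), AffineTransform(loc=1.0, scale=2.0)])
    x = torch.randn(3)
    y = t(x)
    x_roundtrip = t.inv(y)  # recovers x up to floating-point error
    return y, x_roundtrip, t.log_abs_det_jacobian(x, y)

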
class IndependentTransform(Transform):
    """
    Wrapper around another transform to treat
    ``reinterpreted_batch_ndims``-many extra of the rightmost dimensions as
    dependent. This has no effect on the forward or backward transforms, but
    does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions
    in :meth:`log_abs_det_jacobian`.

    Args:
        base_transform (:class:`Transform`): A base transform.
        reinterpreted_batch_ndims (int): The number of extra rightmost
            dimensions to treat as dependent.
    """

    def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.base_transform = base_transform.with_cache(cache_size)
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return IndependentTransform(
            self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size
        )

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        return constraints.independent(
            self.base_transform.domain, self.reinterpreted_batch_ndims
        )

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        return constraints.independent(
            self.base_transform.codomain, self.reinterpreted_batch_ndims
        )

    @property
    def bijective(self):
        return self.base_transform.bijective

    @property
    def sign(self):
        return self.base_transform.sign

    def _call(self, x):
        if x.dim() < self.domain.event_dim:
            raise ValueError("Too few dimensions on input")
        return self.base_transform(x)

    def _inverse(self, y):
        if y.dim() < self.codomain.event_dim:
            raise ValueError("Too few dimensions on input")
        return self.base_transform.inv(y)

    def log_abs_det_jacobian(self, x, y):
        result = self.base_transform.log_abs_det_jacobian(x, y)
        result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
        return result

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})"

    def forward_shape(self, shape):
        return self.base_transform.forward_shape(shape)

    def inverse_shape(self, shape):
        return self.base_transform.inverse_shape(shape)


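# Usage sketch (hypothetical values): wrapping ExpTransform with
# reinterpreted_batch_ndims=1 leaves the forward map unchanged but sums the
# log-det-Jacobian over the last dimension, so a (2, 5) input gives a (2,) result.
def _independent_example():
    t = IndependentTransform(ExpTransform(), reinterpreted_batch_ndims=1)
    x = torch.randn(2, 5)
    y = t(x)
    return t.log_abs_det_jacobian(x, y)  # shape (2,), equal to x.sum(-1)

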
class ReshapeTransform(Transform):
    """
    Unit Jacobian transform to reshape the rightmost part of a tensor.

    Note that ``in_shape`` and ``out_shape`` must have the same number of
    elements, just as for :meth:`torch.Tensor.reshape`.

    Arguments:
        in_shape (torch.Size): The input event shape.
        out_shape (torch.Size): The output event shape.
    """

    bijective = True

    def __init__(self, in_shape, out_shape, cache_size=0):
        self.in_shape = torch.Size(in_shape)
        self.out_shape = torch.Size(out_shape)
        if self.in_shape.numel() != self.out_shape.numel():
            raise ValueError("in_shape, out_shape have different numbers of elements")
        super().__init__(cache_size=cache_size)

    @constraints.dependent_property
    def domain(self):
        return constraints.independent(constraints.real, len(self.in_shape))

    @constraints.dependent_property
    def codomain(self):
        return constraints.independent(constraints.real, len(self.out_shape))

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size)

    def _call(self, x):
        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
        return x.reshape(batch_shape + self.out_shape)

    def _inverse(self, y):
        batch_shape = y.shape[: y.dim() - len(self.out_shape)]
        return y.reshape(batch_shape + self.in_shape)

    def log_abs_det_jacobian(self, x, y):
        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
        return x.new_zeros(batch_shape)

    def forward_shape(self, shape):
        if len(shape) < len(self.in_shape):
            raise ValueError("Too few dimensions on input")
        cut = len(shape) - len(self.in_shape)
        if shape[cut:] != self.in_shape:
            raise ValueError(
                f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}"
            )
        return shape[:cut] + self.out_shape

    def inverse_shape(self, shape):
        if len(shape) < len(self.out_shape):
            raise ValueError("Too few dimensions on input")
        cut = len(shape) - len(self.out_shape)
        if shape[cut:] != self.out_shape:
            raise ValueError(
                f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}"
            )
        return shape[:cut] + self.in_shape


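# Usage sketch (hypothetical values): reshaping a length-6 event into a 2 x 3
# event; the Jacobian is the identity, so the log-det is zero per batch element.
def _reshape_example():
    t = ReshapeTransform(in_shape=(6,), out_shape=(2, 3))
    x = torch.randn(10, 6)
    y = t(x)  # shape (10, 2, 3)
    return t.log_abs_det_jacobian(x, y)  # zeros of shape (10,)

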
class ExpTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \exp(x)`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, ExpTransform)

    def _call(self, x):
        return x.exp()

    def _inverse(self, y):
        return y.log()

    def log_abs_det_jacobian(self, x, y):
        return x


class PowerTransform(Transform):
    r"""
    Transform via the mapping :math:`y = x^{\text{exponent}}`.
    """
    domain = constraints.positive
    codomain = constraints.positive
    bijective = True

    def __init__(self, exponent, cache_size=0):
        super().__init__(cache_size=cache_size)
        (self.exponent,) = broadcast_all(exponent)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return PowerTransform(self.exponent, cache_size=cache_size)

    @lazy_property
    def sign(self):
        return self.exponent.sign()

    def __eq__(self, other):
        if not isinstance(other, PowerTransform):
            return False
        return self.exponent.eq(other.exponent).all().item()

    def _call(self, x):
        return x.pow(self.exponent)

    def _inverse(self, y):
        return y.pow(1 / self.exponent)

    def log_abs_det_jacobian(self, x, y):
        return (self.exponent * y / x).abs().log()

    def forward_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))


def _clipped_sigmoid(x):
    finfo = torch.finfo(x.dtype)
    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps)


class SigmoidTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
    """
    domain = constraints.real
    codomain = constraints.unit_interval
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SigmoidTransform)

    def _call(self, x):
        return _clipped_sigmoid(x)

    def _inverse(self, y):
        finfo = torch.finfo(y.dtype)
        y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
        return y.log() - (-y).log1p()

    def log_abs_det_jacobian(self, x, y):
        return -F.softplus(-x) - F.softplus(x)


class SoftplusTransform(Transform):
    r"""
    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
    The implementation reverts to the linear function when :math:`x > 20`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SoftplusTransform)

    def _call(self, x):
        return softplus(x)

    def _inverse(self, y):
        return (-y).expm1().neg().log() + y

    def log_abs_det_jacobian(self, x, y):
        return -softplus(-x)


class TanhTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \tanh(x)`.

    It is equivalent to
    ```
    ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
    ```
    However this might not be numerically stable, thus it is recommended to use `TanhTransform`
    instead.

    Note that one should use `cache_size=1` when the inverse may encounter `NaN`/`Inf` values.

    """
    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
        # one should use `cache_size=1` instead
        return torch.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable, see details in the following link
        # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
        return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


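# Usage sketch (hypothetical values): with cache_size=1 the inverse returns the
# cached input instead of calling `atanh`, which avoids inf when tanh saturates
# to 1.0 in finite precision.
def _tanh_cache_example():
    t = TanhTransform(cache_size=1)
    x = torch.tensor([0.5, 20.0])  # tanh(20.0) rounds to exactly 1.0
    y = t(x)
    return t.inv(y)  # returns the cached x; without caching this would hit atanh(1.0) == inf

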
class AbsTransform(Transform):
    r"""
    Transform via the mapping :math:`y = |x|`.
    """
    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        return isinstance(other, AbsTransform)

    def _call(self, x):
        return x.abs()

    def _inverse(self, y):
        return y


class AffineTransform(Transform):
    r"""
    Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.

    Args:
        loc (Tensor or float): Location parameter.
        scale (Tensor or float): Scale parameter.
        event_dim (int): Optional size of `event_shape`. This should be zero
            for univariate random variables, 1 for distributions over vectors,
            2 for distributions over matrices, etc.
    """
    bijective = True

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.loc = loc
        self.scale = scale
        self._event_dim = event_dim

    @property
    def event_dim(self):
        return self._event_dim

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return AffineTransform(
            self.loc, self.scale, self.event_dim, cache_size=cache_size
        )

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False

        if isinstance(self.loc, numbers.Number) and isinstance(
            other.loc, numbers.Number
        ):
            if self.loc != other.loc:
                return False
        else:
            if not (self.loc == other.loc).all().item():
                return False

        if isinstance(self.scale, numbers.Number) and isinstance(
            other.scale, numbers.Number
        ):
            if self.scale != other.scale:
                return False
        else:
            if not (self.scale == other.scale).all().item():
                return False

        return True

    @property
    def sign(self):
        if isinstance(self.scale, numbers.Real):
            return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
        return self.scale.sign()

    def _call(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        shape = x.shape
        scale = self.scale
        if isinstance(scale, numbers.Real):
            result = torch.full_like(x, math.log(abs(scale)))
        else:
            result = torch.abs(scale).log()
        if self.event_dim:
            result_size = result.size()[: -self.event_dim] + (-1,)
            result = result.view(result_size).sum(-1)
            shape = shape[: -self.event_dim]
        return result.expand(shape)

    def forward_shape(self, shape):
        return torch.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )


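# Usage sketch (hypothetical values): with event_dim=1 the last dimension is a
# single event, so the log-det-Jacobian is summed over it and one value is
# returned per batch element.
def _affine_event_dim_example():
    t = AffineTransform(
        loc=torch.zeros(3), scale=torch.tensor([1.0, 2.0, 4.0]), event_dim=1
    )
    x = torch.randn(5, 3)
    y = t(x)
    return t.log_abs_det_jacobian(x, y)  # shape (5,), each entry log(1 * 2 * 4)

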
class CorrCholeskyTransform(Transform):
    r"""
    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
    Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower
    triangular matrix with positive diagonals and unit Euclidean norm for each row.
    The transform is processed as follows:

        1. First we convert x into a lower triangular matrix in row order.
        2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
           class :class:`StickBreakingTransform` to transform :math:`X_i` into a
           unit Euclidean length vector using the following steps:
           - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
           - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
           - Applies :math:`s_i = StickBreakingTransform(z_i)`.
           - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`.
    """
    domain = constraints.real_vector
    codomain = constraints.corr_cholesky
    bijective = True

    def _call(self, x):
        x = torch.tanh(x)
        eps = torch.finfo(x.dtype).eps
        x = x.clamp(min=-1 + eps, max=1 - eps)
        r = vec_to_tril_matrix(x, diag=-1)
        # apply stick-breaking on the squared values
        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
        #             = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
        z = r**2
        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
        # Diagonal elements must be 1.
        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
        return y

    def _inverse(self, y):
        # inverse stick-breaking
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
        y_vec = tril_matrix_to_vec(y, diag=-1)
        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
        t = y_vec / (y_cumsum_vec).sqrt()
        # inverse of tanh
        x = (t.log1p() - t.neg().log1p()) / 2
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Because domain and codomain are two spaces with different dimensions, determinant of
        # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
        # flattened lower triangular part of `y`.

        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # also works for 2 x 2 matrix
        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1)
        return stick_breaking_logdet + tanh_logdet

    def forward_shape(self, shape):
        # Reshape from (..., N) to (..., D, D).
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        N = shape[-1]
        D = round((0.25 + 2 * N) ** 0.5 + 0.5)
        if D * (D - 1) // 2 != N:
            raise ValueError("Input is not a flattened lower-diagonal number")
        return shape[:-1] + (D, D)

    def inverse_shape(self, shape):
        # Reshape from (..., D, D) to (..., N).
        if len(shape) < 2:
            raise ValueError("Too few dimensions on input")
        if shape[-2] != shape[-1]:
            raise ValueError("Input is not square")
        D = shape[-1]
        N = D * (D - 1) // 2
        return shape[:-2] + (N,)


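# Usage sketch (hypothetical values): a 3-element unconstrained vector maps to
# the Cholesky factor of a 3 x 3 correlation matrix; every row has unit norm, so
# L @ L.T has ones on its diagonal.
def _corr_cholesky_example():
    t = CorrCholeskyTransform()
    x = torch.randn(3)  # D * (D - 1) / 2 entries for D = 3
    L = t(x)  # lower triangular with positive diagonal, shape (3, 3)
    return L, L @ L.T  # the product is a valid correlation matrix

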
class SoftmaxTransform(Transform):
    r"""
    Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
    normalizing.

    This is not bijective and cannot be used for HMC. However this acts mostly
    coordinate-wise (except for the final normalization), and thus is
    appropriate for coordinate-wise optimization algorithms.
    """
    domain = constraints.real_vector
    codomain = constraints.simplex

    def __eq__(self, other):
        return isinstance(other, SoftmaxTransform)

    def _call(self, x):
        logprobs = x
        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
        return probs / probs.sum(-1, True)

    def _inverse(self, y):
        probs = y
        return probs.log()

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape


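# Usage sketch (hypothetical values): softmax maps logits onto the simplex but is
# not bijective (shifting every logit by a constant yields the same probabilities),
# so the inverse below is only a pseudoinverse.
def _softmax_example():
    t = SoftmaxTransform()
    logits = torch.tensor([0.0, 1.0, 2.0])
    probs = t(logits)  # nonnegative, sums to 1
    return t.inv(probs)  # log-probabilities, not the original logits

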
class StickBreakingTransform(Transform):
    """
    Transform from unconstrained space to the simplex of one additional
    dimension via a stick-breaking process.

    This transform arises as an iterated sigmoid transform in a stick-breaking
    construction of the `Dirichlet` distribution: the first logit is
    transformed via sigmoid to the first probability and the probability of
    everything else, and then the process recurses.

    This is bijective and appropriate for use in HMC; however it mixes
    coordinates together and is less appropriate for optimization.
    """

    domain = constraints.real_vector
    codomain = constraints.simplex
    bijective = True

    def __eq__(self, other):
        return isinstance(other, StickBreakingTransform)

    def _call(self, x):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        z = _clipped_sigmoid(x - offset.log())
        z_cumprod = (1 - z).cumprod(-1)
        y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
        return y

    def _inverse(self, y):
        y_crop = y[..., :-1]
        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
        sf = 1 - y_crop.cumsum(-1)
        # we clamp to make sure that sf is positive which sometimes does not
        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
        x = y_crop.log() - sf.log() + offset.log()
        return x

    def log_abs_det_jacobian(self, x, y):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        x = x - offset.log()
        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
        return detJ

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] + 1,)

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] - 1,)


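# Usage sketch (hypothetical values): the transform adds one dimension, mapping
# an unconstrained vector in R^3 to a point on the 4-element probability simplex.
def _stick_breaking_example():
    t = StickBreakingTransform()
    x = torch.randn(3)
    y = t(x)  # shape (4,), nonnegative, sums to 1
    return y, t.inv(y)  # the inverse recovers x up to floating-point error

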
class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries.

    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
    """

    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.lower_cholesky

    def __eq__(self, other):
        return isinstance(other, LowerCholeskyTransform)

    def _call(self, x):
        return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()

    def _inverse(self, y):
        return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()


class PositiveDefiniteTransform(Transform):
    """
    Transform from unconstrained matrices to positive-definite matrices.
    """

    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.positive_definite  # type: ignore[assignment]

    def __eq__(self, other):
        return isinstance(other, PositiveDefiniteTransform)

    def _call(self, x):
        x = LowerCholeskyTransform()(x)
        return x @ x.mT

    def _inverse(self, y):
        y = torch.linalg.cholesky(y)
        return LowerCholeskyTransform().inv(y)


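# Usage sketch (hypothetical values): an arbitrary square matrix is mapped to a
# positive-definite matrix by building a Cholesky factor with an exponentiated
# diagonal and forming L @ L.T.
def _positive_definite_example():
    t = PositiveDefiniteTransform()
    x = torch.randn(4, 4)
    p = t(x)  # symmetric positive-definite, shape (4, 4)
    return p, t.inv(p)  # the inverse recovers x on and below the diagonal

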
class CatTransform(Transform):
    """
    Transform functor that applies a sequence of transforms `tseq`
    component-wise to each submatrix at `dim`, of length `lengths[dim]`,
    in a way compatible with :func:`torch.cat`.

    Example::

       x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)
       x = torch.cat([x0, x0], dim=0)
       t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
       t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
       y = t(x)
    """

    transforms: List[Transform]

    def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
        assert all(isinstance(t, Transform) for t in tseq)
        if cache_size:
            tseq = [t.with_cache(cache_size) for t in tseq]
        super().__init__(cache_size=cache_size)
        self.transforms = list(tseq)
        if lengths is None:
            lengths = [1] * len(self.transforms)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.transforms)
        self.dim = dim

    @lazy_property
    def event_dim(self):
        return max(t.event_dim for t in self.transforms)

    @lazy_property
    def length(self):
        return sum(self.lengths)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return CatTransform(self.transforms, self.dim, self.lengths, cache_size)

    def _call(self, x):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == self.length
        yslices = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            xslice = x.narrow(self.dim, start, length)
            yslices.append(trans(xslice))
            start = start + length  # avoid += for jit compat
        return torch.cat(yslices, dim=self.dim)

    def _inverse(self, y):
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == self.length
        xslices = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            yslice = y.narrow(self.dim, start, length)
            xslices.append(trans.inv(yslice))
            start = start + length  # avoid += for jit compat
        return torch.cat(xslices, dim=self.dim)

    def log_abs_det_jacobian(self, x, y):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == self.length
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == self.length
        logdetjacs = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            xslice = x.narrow(self.dim, start, length)
            yslice = y.narrow(self.dim, start, length)
            logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
            if trans.event_dim < self.event_dim:
                logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim)
            logdetjacs.append(logdetjac)
            start = start + length  # avoid += for jit compat
        # Decide whether to concatenate or sum.
        dim = self.dim
        if dim >= 0:
            dim = dim - x.dim()
        dim = dim + self.event_dim
        if dim < 0:
            return torch.cat(logdetjacs, dim=dim)
        else:
            return sum(logdetjacs)

    @property
    def bijective(self):
        return all(t.bijective for t in self.transforms)

    @constraints.dependent_property
    def domain(self):
        return constraints.cat(
            [t.domain for t in self.transforms], self.dim, self.lengths
        )

    @constraints.dependent_property
    def codomain(self):
        return constraints.cat(
            [t.codomain for t in self.transforms], self.dim, self.lengths
        )


class StackTransform(Transform):
    """
    Transform functor that applies a sequence of transforms `tseq`
    component-wise to each submatrix at `dim`
    in a way compatible with :func:`torch.stack`.

    Example::

       x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)
       t = StackTransform([ExpTransform(), identity_transform], dim=1)
       y = t(x)
    """

    transforms: List[Transform]

    def __init__(self, tseq, dim=0, cache_size=0):
        assert all(isinstance(t, Transform) for t in tseq)
        if cache_size:
            tseq = [t.with_cache(cache_size) for t in tseq]
        super().__init__(cache_size=cache_size)
        self.transforms = list(tseq)
        self.dim = dim

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return StackTransform(self.transforms, self.dim, cache_size)

    def _slice(self, z):
        return [z.select(self.dim, i) for i in range(z.size(self.dim))]

    def _call(self, x):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == len(self.transforms)
        yslices = []
        for xslice, trans in zip(self._slice(x), self.transforms):
            yslices.append(trans(xslice))
        return torch.stack(yslices, dim=self.dim)

    def _inverse(self, y):
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == len(self.transforms)
        xslices = []
        for yslice, trans in zip(self._slice(y), self.transforms):
            xslices.append(trans.inv(yslice))
        return torch.stack(xslices, dim=self.dim)

    def log_abs_det_jacobian(self, x, y):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == len(self.transforms)
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == len(self.transforms)
        logdetjacs = []
        yslices = self._slice(y)
        xslices = self._slice(x)
        for xslice, yslice, trans in zip(xslices, yslices, self.transforms):
            logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice))
        return torch.stack(logdetjacs, dim=self.dim)

    @property
    def bijective(self):
        return all(t.bijective for t in self.transforms)

    @constraints.dependent_property
    def domain(self):
        return constraints.stack([t.domain for t in self.transforms], self.dim)

    @constraints.dependent_property
    def codomain(self):
        return constraints.stack([t.codomain for t in self.transforms], self.dim)


class CumulativeDistributionTransform(Transform):
    """
    Transform via the cumulative distribution function of a probability distribution.

    Args:
        distribution (Distribution): Distribution whose cumulative distribution
            function is used for the transformation.

    Example::

        # Construct a Gaussian copula from a multivariate normal.
        base_dist = MultivariateNormal(
            loc=torch.zeros(2),
            scale_tril=LKJCholesky(2).sample(),
        )
        transform = CumulativeDistributionTransform(Normal(0, 1))
        copula = TransformedDistribution(base_dist, [transform])
    """

    bijective = True
    codomain = constraints.unit_interval
    sign = +1

    def __init__(self, distribution, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.distribution = distribution

    @property
    def domain(self):
        return self.distribution.support

    def _call(self, x):
        return self.distribution.cdf(x)

    def _inverse(self, y):
        return self.distribution.icdf(y)

    def log_abs_det_jacobian(self, x, y):
        return self.distribution.log_prob(x)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return CumulativeDistributionTransform(self.distribution, cache_size=cache_size)
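

# Usage sketch (hypothetical values): the probability integral transform.
# Pushing a Normal through its own CDF yields values that are uniform on (0, 1).
def _cdf_transform_example():
    from torch.distributions import Normal, TransformedDistribution

    base = Normal(0.0, 1.0)
    uniform_like = TransformedDistribution(
        base, [CumulativeDistributionTransform(base)]
    )
    return uniform_like.sample((5,))  # samples lie in (0, 1)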
