parametrizations.py
from enum import Enum, auto

import torch
from torch import Tensor
from ..utils import parametrize
from ..modules import Module
from .. import functional as F

from typing import Optional

__all__ = ['orthogonal', 'spectral_norm', 'weight_norm']


def _is_orthogonal(Q, eps=None):
    n, k = Q.size(-2), Q.size(-1)
    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
    if eps is None:
        # A reasonable eps, but not too large
        eps = 10. * n * torch.finfo(Q.dtype).eps
    return torch.allclose(Q.mH @ Q, Id, atol=eps)


def _make_orthogonal(A):
    """Assume that A is a tall matrix.

    Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative.
    """
    X, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(X, tau)
    # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
    return Q
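
# Not part of the original file: a minimal usage sketch of the helpers above,
# assuming torch is available. _make_orthogonal returns a Q factor with orthonormal
# columns and the same (tall) shape as its input, and _is_orthogonal verifies it.
#
#   A = torch.randn(5, 3)
#   Q = _make_orthogonal(A)
#   assert Q.shape == A.shape
#   assert _is_orthogonal(Q)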


class _OrthMaps(Enum):
    matrix_exp = auto()
    cayley = auto()
    householder = auto()


class _Orthogonal(Module):
    base: Tensor

    def __init__(self,
                 weight,
                 orthogonal_map: _OrthMaps,
                 *,
                 use_trivialization=True) -> None:
        super().__init__()

        # Note [Householder complex]
        # For complex tensors, it is not possible to compute the tensor `tau` necessary for
        # linalg.householder_product from the reflectors.
        # To see this, note that the reflectors have a shape like:
        # 0 0 0
        # * 0 0
        # * * 0
        # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
        # them as independent tensors we would not maintain the constraint.
        # An equivalent reasoning holds for rectangular matrices.
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError("The householder parametrization does not support complex tensors.")

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            X = X.mT
            n, k = k, n
        # Here n >= k and X is a tall matrix
        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
            # We just need n x k - k(k-1)/2 parameters
            X = X.tril()
            if n != k:
                # Embed into a square matrix
                X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            A = X - X.mH
            # A is skew-symmetric (or skew-hermitian)
            if self.orthogonal_map == _OrthMaps.matrix_exp:
                Q = torch.matrix_exp(A)
            elif self.orthogonal_map == _OrthMaps.cayley:
                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
                Id = torch.eye(n, dtype=A.dtype, device=A.device)
                Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
            # Q is now orthogonal (or unitary) of size (..., n, n)
            if n != k:
                Q = Q[..., :k]
            # Q now has the size of X (albeit perhaps transposed)
        else:
            # X is real here, as we do not support householder with complex numbers
            A = X.tril(diagonal=-1)
            tau = 2. / (1. + (A * A).sum(dim=-2))
            Q = torch.linalg.householder_product(A, tau)
            # The diagonal of X is 1's and -1's
            # We do not want to differentiate through this or update the diagonal of X, hence the casting
            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

        if hasattr(self, "base"):
            Q = self.base @ Q
        if transposed:
            Q = Q.mT
        return Q  # type: ignore[possibly-undefined]
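
# Not part of the original file: a worked sketch (assuming torch is available) of the
# Cayley map used above. For a skew-symmetric A, (I - A/2)^{-1}(I + A/2) is orthogonal.
#
#   X = torch.randn(4, 4)
#   A = X - X.mT                                      # skew-symmetric
#   Id = torch.eye(4)
#   Q = torch.linalg.solve(Id - A / 2, Id + A / 2)    # Cayley retraction of A
#   assert torch.allclose(Q.mT @ Q, Id, atol=1e-5)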

    @torch.autograd.no_grad()
    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
        if Q.shape != self.shape:
            raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
                             f"Got a tensor of shape {Q.shape}.")

        Q_init = Q
        n, k = Q.size(-2), Q.size(-1)
        transpose = n < k
        if transpose:
            Q = Q.mT
            n, k = k, n

        # We make sure to always copy Q in every path
        if not hasattr(self, "base"):
            # Note [right_inverse expm cayley]
            # If we do not have use_trivialization=True, we just implement the inverse of the forward
            # map for the Householder. To see why, think that for the Cayley map,
            # we would need to find the matrix X \in R^{n x k} such that:
            # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            # A = Y - Y.mH
            # cayley(A)[:, :k]
            # gives the original tensor. It is not clear how to do this.
            # Perhaps via some algebraic manipulation involving the QR like that of
            # Corollary 2.2 in Edelman, Arias and Smith?
            if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
                raise NotImplementedError("It is not possible to assign to the matrix exponential "
                                          "or the Cayley parametrizations when use_trivialization=False.")

            # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
            # Here Q is always real because we do not support householder and complex matrices.
            # See note [Householder complex]
            A, tau = torch.geqrf(Q)
            # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
            # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
            # The diagonal of A is the diagonal of R from the QR decomposition
            A.diagonal(dim1=-2, dim2=-1).sign_()
            # Equality with zero is ok because LAPACK returns exactly zero when it does not want
            # to use a particular reflection
            A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1
            return A.mT if transpose else A
        else:
            if n == k:
                # We check whether Q is orthogonal
                if not _is_orthogonal(Q):
                    Q = _make_orthogonal(Q)
                else:  # Is orthogonal
                    Q = Q.clone()
            else:
                # Complete Q into a full n x n orthogonal matrix
                N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device)
                Q = torch.cat([Q, N], dim=-1)
                Q = _make_orthogonal(Q)
            self.base = Q

            # It is necessary to return the -Id, as we use the diagonal for the
            # Householder parametrization. Using -Id makes:
            # householder(torch.zeros(m,n)) == torch.eye(m,n)
            # Poor man's version of eye_like
            neg_Id = torch.zeros_like(Q_init)
            neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
            return neg_Id


def orthogonal(module: Module,
               name: str = 'weight',
               orthogonal_map: Optional[str] = None,
               *,
               use_trivialization: bool = True) -> Module:
    r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices.

    Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
    matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as

    .. math::

        \begin{align*}
            Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
            QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
        \end{align*}

    where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
    and the transpose when :math:`Q` is real-valued, and
    :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
    In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
    and orthonormal rows otherwise.

    If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.

    The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` options in terms of the original tensor:

    - ``"matrix_exp"``/``"cayley"``:
      the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
      :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
      :math:`A` to give an orthogonal matrix.
    - ``"householder"``: computes a product of Householder reflectors
      (:func:`~torch.linalg.householder_product`).

    ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
    ``"householder"``, but they are slower to compute for very thin or very wide matrices.

    If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
    where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
    ``module.parametrizations.weight[0].base``. This helps the
    convergence of the parametrized layer at the expense of some extra memory use.
    See `Trivializations for Gradient-Based Optimization on Manifolds`_ .

    Initial value of :math:`Q`:
    If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
    of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
    and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
    The same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``.
    Otherwise, the initial value is the result of the composition of all the registered
    parametrizations applied to the original tensor.

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`.


    .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
    .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501

    Args:
        module (nn.Module): module on which to register the parametrization.
        name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
        orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
            Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
        use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
            Default: ``True``.

    Returns:
        The original module with an orthogonal parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> orth_linear = orthogonal(nn.Linear(20, 40))
        >>> orth_linear
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _Orthogonal()
            )
          )
        )
        >>> # xdoctest: +IGNORE_WANT
        >>> Q = orth_linear.weight
        >>> torch.dist(Q.T @ Q, torch.eye(20))
        tensor(4.9332e-07)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    # We could implement this for 1-dim tensors as the maps on the sphere
    # but I believe it'd bite more people than it'd help
    if weight.ndim < 2:
        raise ValueError("Expected a matrix or batch of matrices. "
                         f"Got a tensor of {weight.ndim} dimensions.")

    if orthogonal_map is None:
        orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"

    orth_enum = getattr(_OrthMaps, orthogonal_map, None)
    if orth_enum is None:
        raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
                         f'Got: {orthogonal_map}')
    orth = _Orthogonal(weight,
                       orth_enum,
                       use_trivialization=use_trivialization)
    parametrize.register_parametrization(module, name, orth, unsafe=True)
    return module
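
# Not part of the original file: an illustrative sketch of assigning to the
# parametrized weight. Assignment goes through _Orthogonal.right_inverse, so an
# already-orthogonal value is stored via the `base` buffer and recovered
# (up to numerical precision) on the next access.
#
#   layer = orthogonal(torch.nn.Linear(3, 3))
#   with torch.no_grad():
#       layer.weight = torch.eye(3)
#   assert torch.allclose(layer.weight, torch.eye(3), atol=1e-6)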


class _WeightNorm(Module):
    def __init__(
        self,
        dim: Optional[int] = 0,
    ) -> None:
        super().__init__()
        if dim is None:
            dim = -1
        self.dim = dim

    def forward(self, weight_g, weight_v):
        return torch._weight_norm(weight_v, weight_g, self.dim)

    def right_inverse(self, weight):
        weight_g = torch.norm_except_dim(weight, 2, self.dim)
        weight_v = weight

        return weight_g, weight_v
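
# Not part of the original file: a minimal round-trip sketch of the decomposition
# w = g * v / ||v||. right_inverse picks g = ||w|| along `dim` and v = w, so forward
# reconstructs the original weight.
#
#   wn = _WeightNorm(dim=0)
#   w = torch.randn(4, 3)
#   g, v = wn.right_inverse(w)     # g has shape (4, 1); v is the weight itself
#   assert torch.allclose(wn(g, v), w, atol=1e-6)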


def weight_norm(module: Module, name: str = 'weight', dim: int = 0):
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` with two parameters: one specifying the magnitude
    and one specifying the direction.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm parametrization registered

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        >>> m.parametrizations.weight.original0.size()
        torch.Size([40, 1])
        >>> m.parametrizations.weight.original1.size()
        torch.Size([40, 20])

    """
    _weight_norm = _WeightNorm(dim)
    parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)

    def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        g_key = f"{prefix}{name}_g"
        v_key = f"{prefix}{name}_v"
        if g_key in state_dict and v_key in state_dict:
            original0 = state_dict.pop(g_key)
            original1 = state_dict.pop(v_key)
            state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
            state_dict[f"{prefix}parametrizations.{name}.original1"] = original1
    module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
    return module
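
# Not part of the original file: an illustrative sketch of the compatibility hook
# above. A state_dict saved with the old-style "<name>_g"/"<name>_v" keys still loads
# into the parametrized module, because the pre-hook remaps them to original0/original1.
#
#   m = weight_norm(torch.nn.Linear(20, 40))
#   old_style = {"weight_g": torch.ones(40, 1),
#                "weight_v": torch.randn(40, 20),
#                "bias": torch.zeros(40)}
#   m.load_state_dict(old_style)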


class _SpectralNorm(Module):
    def __init__(
        self,
        weight: torch.Tensor,
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12
    ) -> None:
        super().__init__()
        ndim = weight.ndim
        if dim >= ndim or dim < -ndim:
            raise IndexError("Dimension out of range (expected to be in range of "
                             f"[-{ndim}, {ndim - 1}] but got {dim})")

        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             f'got n_power_iterations={n_power_iterations}')
        self.dim = dim if dim >= 0 else dim + ndim
        self.eps = eps
        if ndim > 1:
            # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
            self.n_power_iterations = n_power_iterations
            weight_mat = self._reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()

            u = weight_mat.new_empty(h).normal_(0, 1)
            v = weight_mat.new_empty(w).normal_(0, 1)
            self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
            self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))

            # Start with u, v initialized to some reasonable values by performing a number
            # of iterations of the power method
            self._power_method(weight_mat, 15)

    def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        # Precondition
        assert weight.ndim > 1

        if self.dim != 0:
            # permute dim to front
            weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))

        return weight.flatten(1)

    @torch.autograd.no_grad()
    def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
        # See original note at torch/nn/utils/spectral_norm.py
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on the correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in `out=` kwarg has correct shape, it will
        #          just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #     However, after we update `u` and `v` in-place, we need to **clone**
        #     them before using them to normalize the weight. This is to support
        #     backpropagating through two forward passes, e.g., the common pattern in
        #     GAN training: loss = D(real) - D(fake). Otherwise, the autograd engine will
        #     complain that variables needed to do backward for the first forward
        #     (i.e., the `u` and `v` vectors) are changed in the second forward.

        # Precondition
        assert weight_mat.ndim > 1

        for _ in range(n_power_iterations):
            # The spectral norm of the weight equals `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
            self._u = F.normalize(torch.mv(weight_mat, self._v),      # type: ignore[has-type]
                                  dim=0, eps=self.eps, out=self._u)   # type: ignore[has-type]
            self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
                                  dim=0, eps=self.eps, out=self._v)   # type: ignore[has-type]

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            return F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, self.n_power_iterations)
            # See above on why we need to clone
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.dot(u, torch.mv(weight_mat, v))
            return weight / sigma

    def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
        # We may want to assert here that the passed value already
        # satisfies the constraints
        return value
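
# Not part of the original file: a worked sketch of the power iteration above, assuming
# torch is available. Alternating u <- normalize(W v), v <- normalize(W^T u) makes
# u^T W v approach the largest singular value of W, so dividing W by it yields a matrix
# with spectral norm close to 1.
#
#   W = torch.randn(40, 20)
#   u = torch.nn.functional.normalize(torch.randn(40), dim=0)
#   v = torch.nn.functional.normalize(torch.randn(20), dim=0)
#   for _ in range(50):
#       u = torch.nn.functional.normalize(torch.mv(W, v), dim=0)
#       v = torch.nn.functional.normalize(torch.mv(W.t(), u), dim=0)
#   sigma = torch.dot(u, torch.mv(W, v))
#   assert torch.allclose(sigma, torch.linalg.matrix_norm(W, 2), rtol=1e-3)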


def spectral_norm(module: Module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> Module:
    r"""Apply spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    When applied to a vector, it simplifies to

    .. math::
        \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
    of the model. :math:`\sigma` is approximated by performing one iteration of the
    `power method`_ every time the weight is accessed. If the dimension of the
    weight tensor is greater than 2, it is reshaped to 2D in the power iteration
    to compute the spectral norm.


    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
        reimplementation of :func:`torch.nn.utils.spectral_norm`.

    .. note::
        When this constraint is registered, the singular vectors associated with the largest
        singular value are estimated rather than sampled at random. These are then updated
        by performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
        is accessed while the module is in `training` mode.

    .. note::
        If the `_SpectralNorm` module, i.e., `module.parametrizations.weight[idx]`,
        is in training mode on removal, it will perform another power iteration.
        If you'd like to avoid this iteration, set the module to eval mode
        before its removal.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter. Default: ``"weight"``.
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm. Default: ``1``.
        eps (float, optional): epsilon for numerical stability in
            calculating norms. Default: ``1e-12``.
        dim (int, optional): dimension corresponding to number of outputs.
            Default: ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with a new parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> snm = spectral_norm(nn.Linear(20, 40))
        >>> snm
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        >>> torch.linalg.matrix_norm(snm.weight, 2)
        tensor(1.0081, grad_fn=<AmaxBackward0>)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    if dim is None:
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps))
    return module
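
# Not part of the original file: an illustrative sketch of the removal note above.
# Switching to eval mode first avoids the extra power iteration that would otherwise
# run when the parametrization is removed.
#
#   snm = spectral_norm(torch.nn.Linear(20, 40))
#   snm.eval()
#   torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")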