from enum import Enum, auto

import torch
from torch import Tensor

from ..utils import parametrize
from ..modules import Module
from .. import functional as F

from typing import Optional

__all__ = ['orthogonal', 'spectral_norm', 'weight_norm']


def _is_orthogonal(Q, eps=None):
    n, k = Q.size(-2), Q.size(-1)
    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
    # A reasonable eps, but not too large
    eps = 10. * n * torch.finfo(Q.dtype).eps
    return torch.allclose(Q.mH @ Q, Id, atol=eps)
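
# For example (a sketch, not executed): _is_orthogonal(torch.eye(3)) is True, while a
# random square matrix will almost surely fail the check until it is orthogonalized,
# e.g. by _make_orthogonal below.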


def _make_orthogonal(A):
    """Assume that A is a tall matrix.

    Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative.
    """
    X, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(X, tau)
    # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
    return Q


class _OrthMaps(Enum):
    matrix_exp = auto()
    cayley = auto()
    householder = auto()


class _Orthogonal(Module):
    base: Tensor

    def __init__(self,
                 weight,
                 orthogonal_map: _OrthMaps,
                 *,
                 use_trivialization=True) -> None:
        super().__init__()

        # Note [Householder complex]
        # For complex tensors, it is not possible to compute the tensor `tau` necessary for
        # linalg.householder_product from the reflectors.
        # To see this, note that the reflectors have a shape like:
        # 0 0 0
        # * 0 0
        # * * 0
        # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
        # them as independent tensors we would not maintain the constraint.
        # An equivalent reasoning holds for rectangular matrices.
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError("The householder parametrization does not support complex tensors.")

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            X = X.mT
            n, k = k, n
        # Here n > k and X is a tall matrix
        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
            # We just need n x k - k(k-1)/2 parameters
            X = X.tril()
            if n != k:
                # Embed into a square matrix
                X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            A = X - X.mH
            # A is skew-symmetric (or skew-hermitian)
            if self.orthogonal_map == _OrthMaps.matrix_exp:
                Q = torch.matrix_exp(A)
            elif self.orthogonal_map == _OrthMaps.cayley:
                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
                Id = torch.eye(n, dtype=A.dtype, device=A.device)
                Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
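                # Why this Q is orthogonal (a sketch): since A.mH == -A, we have
                # (I + A/2).mH == I - A/2, and (I - A/2) commutes with (I + A/2), so
                # Q.mH @ Q == (I - A/2) @ (I + A/2)^{-1} @ (I - A/2)^{-1} @ (I + A/2) == I.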
            # Q is now orthogonal (or unitary) of size (..., n, n)
            if n != k:
                Q = Q[..., :k]
            # Q is now the size of the X (albeit perhaps transposed)
        else:
            # X is real here, as we do not support householder with complex numbers
            A = X.tril(diagonal=-1)
            tau = 2. / (1. + (A * A).sum(dim=-2))
            Q = torch.linalg.householder_product(A, tau)
            # The diagonal of X is 1's and -1's
            # We do not want to differentiate through this or update the diagonal of X hence the casting
            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

if hasattr(self, "base"):
108
return Q # type: ignore[possibly-undefined]
110

    @torch.autograd.no_grad()
    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
        if Q.shape != self.shape:
            raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
                             f"Got a tensor of shape {Q.shape}.")

        Q_init = Q
        n, k = Q.size(-2), Q.size(-1)
        transpose = n < k
        if transpose:
            Q = Q.mT
            n, k = k, n

        # We make sure to always copy Q in every path
if not hasattr(self, "base"):
125
# Note [right_inverse expm cayley]
126
# If we do not have use_trivialization=True, we just implement the inverse of the forward
127
# map for the Householder. To see why, think that for the Cayley map,
128
# we would need to find the matrix X \in R^{n x k} such that:
129
# Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
132
# gives the original tensor. It is not clear how to do this.
133
# Perhaps via some algebraic manipulation involving the QR like that of
134
# Corollary 2.2 in Edelman, Arias and Smith?
135
            if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
                raise NotImplementedError("It is not possible to assign to the matrix exponential "
                                          "or the Cayley parametrizations when use_trivialization=False.")

            # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
            # Here Q is always real because we do not support householder with complex matrices.
            # See note [Householder complex]
            A, tau = torch.geqrf(Q)
            # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
            # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
            # The diagonal of A is the diagonal of R from the QR decomposition
            A.diagonal(dim1=-2, dim2=-1).sign_()
            # Equality with zero is ok because LAPACK returns exactly zero when it does not want
            # to use a particular reflection
            A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1
            return A.mT if transpose else A
        else:
            if n == k:
                # We check whether Q is orthogonal
                if not _is_orthogonal(Q):
                    Q = _make_orthogonal(Q)
                else:  # Is orthogonal
                    Q = Q.clone()
            else:
                # Complete Q into a full n x n orthogonal matrix
                N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device)
                Q = torch.cat([Q, N], dim=-1)
                Q = _make_orthogonal(Q)
            self.base = Q

            # It is necessary to return the -Id, as we use the diagonal for the
            # Householder parametrization. Using -Id makes:
            # householder(torch.zeros(m,n)) == torch.eye(m,n)
            # Poor man's version of eye_like
            neg_Id = torch.zeros_like(Q_init)
            neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
            return neg_Id


def orthogonal(module: Module,
               name: str = 'weight',
               orthogonal_map: Optional[str] = None,
               *,
               use_trivialization: bool = True) -> Module:
r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices.
181
Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
182
matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as
187
Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
188
QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
191

    where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
    and the transpose when :math:`Q` is real-valued, and
    :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
    In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
    and orthonormal rows otherwise.

    If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.

    The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor:

    - ``"matrix_exp"``/``"cayley"``:
      the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
      :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
      :math:`A` to give an orthogonal matrix.
    - ``"householder"``: computes a product of Householder reflectors
      (:func:`~torch.linalg.householder_product`).

    ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
    ``"householder"``, but they are slower to compute for very thin or very wide matrices.
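
    For example, a specific map can be requested explicitly (a sketch; ``nn`` is assumed
    to be ``torch.nn``)::

        >>> # xdoctest: +SKIP
        >>> layer = orthogonal(nn.Linear(5, 5), orthogonal_map="cayley")
        >>> # layer.weight.T @ layer.weight is now the identity, up to rounding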

    If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
    where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
    ``module.parametrizations.weight[0].base``. This helps the
    convergence of the parametrized layer at the expense of some extra memory use.
    See `Trivializations for Gradient-Based Optimization on Manifolds`_ .
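
    For instance, the extra matrix can be inspected after registration (a sketch; the
    shape follows from the square ``5 x 5`` weight used here)::

        >>> # xdoctest: +SKIP
        >>> layer = orthogonal(nn.Linear(5, 5))
        >>> layer.parametrizations.weight[0].base.shape
        torch.Size([5, 5])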

    Initial value of :math:`Q`:
    If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
    of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
    and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
    The same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``.
    Otherwise, the initial value is the result of the composition of all the registered
    parametrizations applied to the original tensor.

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`.

    .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
    .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501

    Args:
        module (nn.Module): module on which to register the parametrization.
        name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
        orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
            Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
        use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
            Default: ``True``.

    Returns:
        The original module with an orthogonal parametrization registered to the specified
        weight

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> orth_linear = orthogonal(nn.Linear(20, 40))
        >>> orth_linear
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _Orthogonal()
            )
          )
        )
        >>> # xdoctest: +IGNORE_WANT
        >>> Q = orth_linear.weight
        >>> torch.dist(Q.T @ Q, torch.eye(20))
        tensor(4.9332e-07)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    # We could implement this for 1-dim tensors as the maps on the sphere
    # but I believe it'd bite more people than it'd help
    if weight.ndim < 2:
        raise ValueError("Expected a matrix or batch of matrices. "
                         f"Got a tensor of {weight.ndim} dimensions.")

    if orthogonal_map is None:
        orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"

    orth_enum = getattr(_OrthMaps, orthogonal_map, None)
    if orth_enum is None:
        raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
                         f'Got: {orthogonal_map}')
    orth = _Orthogonal(weight,
                       orth_enum,
                       use_trivialization=use_trivialization)
    parametrize.register_parametrization(module, name, orth, unsafe=True)
    return module


class _WeightNorm(Module):
    def __init__(self,
                 dim: Optional[int] = 0) -> None:
        super().__init__()
        if dim is None:
            dim = -1
        self.dim = dim

    def forward(self, weight_g, weight_v):
        return torch._weight_norm(weight_v, weight_g, self.dim)

    def right_inverse(self, weight):
        weight_g = torch.norm_except_dim(weight, 2, self.dim)
        weight_v = weight

        return weight_g, weight_v

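
# A round-trip sketch (not executed): with wn = _WeightNorm(dim=0) and a 2-D weight w,
#   g, v = wn.right_inverse(w)   # g holds the per-row norms of w; v is w itself
#   wn.forward(g, v)             # rebuilds w as g * v / ||v||, row by row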

def weight_norm(module: Module, name: str = 'weight', dim: int = 0):
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
        \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` with two parameters: one specifying the magnitude
    and one specifying the direction.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        >>> m.parametrizations.weight.original0.size()
        torch.Size([40, 1])
        >>> m.parametrizations.weight.original1.size()
        torch.Size([40, 20])

    """
    _weight_norm = _WeightNorm(dim)
    parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)

    def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        g_key = f"{prefix}{name}_g"
        v_key = f"{prefix}{name}_v"
        if g_key in state_dict and v_key in state_dict:
            original0 = state_dict.pop(g_key)
            original1 = state_dict.pop(v_key)
            state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
            state_dict[f"{prefix}parametrizations.{name}.original1"] = original1
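
    # For example (illustrative key names): a checkpoint saved with the legacy
    # torch.nn.utils.weight_norm API contains "linear.weight_g" / "linear.weight_v";
    # the hook above rewrites them to "linear.parametrizations.weight.original0" /
    # "linear.parametrizations.weight.original1" so old checkpoints keep loading.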

    module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
    return module


class _SpectralNorm(Module):
    def __init__(self,
                 weight: torch.Tensor,
                 n_power_iterations: int = 1,
                 dim: int = 0,
                 eps: float = 1e-12) -> None:
        super().__init__()
        ndim = weight.ndim
        if dim >= ndim or dim < -ndim:
            raise IndexError("Dimension out of range (expected to be in range of "
                             f"[-{ndim}, {ndim - 1}] but got {dim})")

        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             f'got n_power_iterations={n_power_iterations}')
        self.dim = dim if dim >= 0 else dim + ndim
        self.eps = eps
        if ndim > 1:
            # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
            self.n_power_iterations = n_power_iterations
            weight_mat = self._reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()

            u = weight_mat.new_empty(h).normal_(0, 1)
            v = weight_mat.new_empty(w).normal_(0, 1)
            self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
            self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))

            # Start with u, v initialized to some reasonable values by performing a number
            # of iterations of the power method
            self._power_method(weight_mat, 15)

    def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        # Precondition
        assert weight.ndim > 1

        if self.dim != 0:
            # permute dim to front
            weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))

        return weight.flatten(1)
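
    # For example (a sketch): a Conv2d weight of shape (64, 16, 3, 3) with dim=0 is
    # flattened into a 64 x 144 matrix, and the spectral norm is estimated on that matrix.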

    @torch.autograd.no_grad()
    def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
        # See original note at torch/nn/utils/spectral_norm.py
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!

        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        # 1. `DataParallel` doesn't clone storage if the broadcast tensor
        #    is already on correct device; and it makes sure that the
        #    parallelized module is already on `device[0]`.
        # 2. If the out tensor in `out=` kwarg has correct shape, it will
        #    just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).

        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backpropagating through two forward passes, e.g., the common pattern in
        # GAN training: loss = D(real) - D(fake). Otherwise, the autograd engine
        # will complain that variables needed to do backward for the first forward
        # (i.e., the `u` and `v` vectors) are changed in the second forward.

        # Precondition
        assert weight_mat.ndim > 1

        for _ in range(n_power_iterations):
            # Spectral norm of weight equals `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
            self._u = F.normalize(torch.mv(weight_mat, self._v),  # type: ignore[has-type]
                                  dim=0, eps=self.eps, out=self._u)  # type: ignore[has-type]
            self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
                                  dim=0, eps=self.eps, out=self._v)  # type: ignore[has-type]
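
        # In matrix form, one iteration reads (a sketch): u <- normalize(W @ v), then
        # v <- normalize(W.T @ u); at convergence, u^T W v approximates the largest
        # singular value sigma of W, which is what forward() divides by.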

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            return F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, self.n_power_iterations)
            # See above on why we need to clone
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.dot(u, torch.mv(weight_mat, v))
            return weight / sigma
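
    # A quick sanity sketch (not executed): for m = spectral_norm(nn.Linear(20, 40)),
    # torch.linalg.matrix_norm(m.weight, 2) should be close to 1, since forward()
    # divides the weight by its estimated largest singular value.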

    def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
        # we may want to assert here that the passed value already
        # satisfies constraints
        return value


def spectral_norm(module: Module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> Module:
r"""Apply spectral normalization to a parameter in the given module.
487
\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
488
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
490
When applied on a vector, it simplifies to
493
\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}
495

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
    of the model. :math:`\sigma` is approximated by performing one iteration of the
    `power method`_ every time the weight is accessed. If the dimension of the
    weight tensor is greater than 2, it is reshaped to 2D before the power
    iteration is run to compute the spectral norm.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
        reimplementation of :func:`torch.nn.utils.spectral_norm`.

    .. note::
        When this constraint is registered, the singular vectors associated to the largest
        singular value are estimated rather than sampled at random. These are then updated
        by performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
        is accessed with the module on `training` mode.

    .. note::
        If the `_SpectralNorm` module, i.e., ``module.parametrizations.weight[idx]``,
        is in training mode on removal, it will perform another power iteration.
        If you'd like to avoid this iteration, set the module to eval mode
        before its removal.
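
    For instance, to avoid that extra iteration when removing the parametrization
    (a sketch)::

        >>> # xdoctest: +SKIP
        >>> snm.eval()
        >>> torch.nn.utils.parametrize.remove_parametrizations(snm, 'weight')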

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter. Default: ``"weight"``.
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm. Default: ``1``.
        eps (float, optional): epsilon for numerical stability in
            calculating norms. Default: ``1e-12``.
        dim (int, optional): dimension corresponding to number of outputs.
            Default: ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with a new parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> snm = spectral_norm(nn.Linear(20, 40))
        >>> snm
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        >>> torch.linalg.matrix_norm(snm.weight, 2)
        tensor(1.0081, grad_fn=<AmaxBackward0>)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    if dim is None:
        # Determine the default dim
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps))
    return module