# pytorch — Fork: 0 — test_parametrization.py
# 1905 lines · 81.2 KB
1
# Owner(s): ["module: nn"]
2
import pickle
3
from copy import deepcopy
4
from itertools import product
5

6
import torch
7
import torch.nn as nn
8
import torch.nn.functional as F
9
import torch.nn.init as init
10
import torch.nn.utils.parametrize as parametrize
11
from torch import Tensor
12
from torch.__future__ import get_swap_module_params_on_conversion
13
from torch.nn import Buffer, Parameter
14
from torch.testing._internal.common_cuda import TEST_MULTIGPU
15
from torch.testing._internal.common_device_type import instantiate_device_type_tests
16
from torch.testing._internal.common_nn import NNTestCase
17
from torch.testing._internal.common_utils import (
18
    gradcheck,
19
    instantiate_parametrized_tests,
20
    run_tests,
21
    set_default_dtype,
22
    skipIfNoLapack,
23
    skipIfTorchDynamo,
24
    swap,
25
    TemporaryFileName,
26
)
27
from torch.testing._internal.two_tensor import TwoTensor
28

29

30
class TestNNParametrization(NNTestCase):
31
    _do_cuda_memory_leak_check = True
32
    _do_cuda_non_default_stream = True
33

34
    # FIXME: Rewrite this test using functions not depending on LAPACK
35
    #        and remove the `@skipIfNoLapack` (see #70995)
36
    # torch/nn/utils/parametrize
37
    @skipIfNoLapack
38
    @swap([True, False])
39
    def test_register_and_remove_parametrization(self):
        r"""Test that it is possible to add a few parametrizations
        on a parameter or a buffer and that removing them restores the initial state.
        It also tests that backpropagating through them works as expected.
        """

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                # Cayley map
                # If X is skew-symmetric it returns an orthogonal matrix
                Id = torch.eye(X.size(0), device=X.device)
                # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous
                # and autograd raises a performance warning.
                # This happens when we remove the parametrization with leave_parametrized=True,
                # which does a set_ with a non-contiguous tensor while the gradient is contiguous
                return torch.linalg.solve(Id + X, Id - X).contiguous()

        class Resize(nn.Module):
            def forward(self, X):
                return X[[0]]

        class NoResize(nn.Module):
            def forward(self, X):
                return X

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)
        initial_weight_id = id(model.weight)
        initial_bias_id = id(model.bias)
        initial_model = deepcopy(model)

        # Test unsafe flag
        with self.assertRaisesRegex(
            ValueError,
            "Registering a parametrization may not change the shape of the tensor",
        ):
            parametrize.register_parametrization(
                model, "weight", Resize()
            )  # default unsafe = False
            model(torch.ones(8, 8))

        # One parametrization with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertEqual(model.weight.shape[0], 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Two parametrizations with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertEqual(model.weight.shape[0], 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test unsafe flag doesn't change expected behavior
        parametrize.register_parametrization(model, "weight", Skew(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test one parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test two parametrizations at the same time and removing them
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        # Result should be orthogonal
        X = model.weight
        Id = torch.eye(X.size(0), device=X.device)
        self.assertEqual(X.T @ X, Id)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del X
        # Structure tests
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertIn("weight", model.parametrizations)
        self.assertNotIn("weight", model._parameters)
        # Remove
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

        # Add everything
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())

        # Basic tests
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertEqual(
            len(list(model.parameters())), 2
        )  # Nothing weird has happened
        # Should not throw

        sgd = torch.optim.SGD(model.parameters(), lr=0.01)

        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove first parametrization.
        # Check that the model is still parametrized and so is the second parameter
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertTrue(parametrize.is_parametrized(model))  # Still parametrized
        self.assertFalse(
            parametrize.is_parametrized(model, "weight")
        )  # Parametrization removed
        self.assertTrue(
            parametrize.is_parametrized(model, "bias")
        )  # Still parametrized
        self.assertEqual(model.bias[0].item(), 0.0)  # Still parametrized
        self.assertEqual(model.bias[-1].item(), 0.0)  # Still parametrized
        self.assertNotEqual(model.weight, initial_model.weight)  # Has been updated
        self.assertEqual(id(model.weight), initial_weight_id)  # Keeps the same id
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened
        # Should not throw
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove the second parametrization.
        # Check that the module is not parametrized
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=False)
        self.assertFalse(parametrize.is_parametrized(model))  # Not parametrized
        self.assertNotEqual(model.bias, initial_model.bias)  # Has been updated
        self.assertNotEqual(model.bias[0].item(), 0.0)  # Not parametrized
        self.assertNotEqual(model.bias[-1].item(), 0.0)  # Not parametrized
        self.assertEqual(id(model.bias), initial_bias_id)  # Keeps the same id
        self.assertFalse(
            hasattr(model, "parametrizations")
        )  # The module is no longer parametrized
        self.assertEqual(model.__class__, nn.Linear)  # Restores the previous class
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened

        # Should not throw. Things are updated
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del weight_copy, bias_copy

        # Test leave_parametrized=True
        for _ in range(2):
            parametrize.register_parametrization(model, "weight", Skew())
            parametrize.register_parametrization(model, "weight", Orthogonal())
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=True
            )
            # We didn't change the dtype nor had multiple inputs, so the id should be the same
            self.assertEqual(id(model.weight), initial_weight_id)
            self.assertEqual(id(model.bias), initial_bias_id)

            # Should not throw. Things are updated
            weight_copy = model.weight.clone()
            bias_copy = model.bias.clone()
            sgd.zero_grad()
            (model.weight.T @ model.bias).sum().backward()
            sgd.step()
            self.assertNotEqual(model.weight, weight_copy)
            self.assertNotEqual(model.bias, bias_copy)
            if get_swap_module_params_on_conversion():
                # When using the swap_tensors path, this is needed so that the autograd
                # graph is not alive anymore.
                del weight_copy, bias_copy
291

292
    @swap([True, False])
293
    def test_register_and_remove_nested_parametrization(self):
        r"""Test that it is possible to nest parametrizations,
        meaning that the original param is parametrized again.
        """

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        model = nn.Linear(8, 8)
        # Add top level parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A

        # Add nested parametrization on the ParametrizationList's `original` tensor
        param_mod = model.parametrizations.weight
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertFalse(parametrize.is_parametrized(param_mod))
        self.assertFalse(parametrize.is_parametrized(param_mod, "original"))

        parametrize.register_parametrization(param_mod, "original", Skew())
        self.assertTrue(hasattr(param_mod, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(param_mod))
        self.assertTrue(parametrize.is_parametrized(param_mod, "original"))
        self.assertNotIn("original", param_mod._parameters)
        # Result should be skew-symmetric
        A = param_mod.original
        self.assertEqual(A, -A.T)

        # Remove nested param and check consistency
        parametrize.remove_parametrizations(
            param_mod, "original", leave_parametrized=False
        )
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertEqual(param_mod.__class__, parametrize.ParametrizationList)

        # Remove top level and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
345

346
    @swap([True, False])
347
    def test_register_and_remove_buffer_parametrization(self):
        r"""Test that it is possible to add and remove parametrizations on buffers"""

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)

        # Instantiate parametrizations on buffers. It should work as expected
        delattr(model, "bias")
        model.bias = Buffer(torch.ones(8))
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        # Only the weight is a Parameter; the bias is a buffer
        self.assertEqual(len(list(model.parameters())), 1)

        # Remove parametrizations on buffers. It should work as expected
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=True)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)
381

382
    # FIXME: Rewrite this test using functions not depending on LAPACK
383
    #        and remove the `@skipIfNoLapack` (see #70995)
384
    @skipIfNoLapack
385
    @swap([True, False])
386
    def test_serialization_parametrization(self):
        r"""Test that it is possible to serialize a parametrized model via state_dict"""

        # A stateful parametrization
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.id = Buffer(torch.eye(n))
                self.B = Buffer(torch.empty(n, n))
                init.orthogonal_(self.B)

            def forward(self, X):
                A = X.triu(1)
                A = A - A.T
                return self.B @ torch.linalg.solve(self.id + A, self.id - A)

        def get_model():
            model = torch.nn.Sequential(
                torch.nn.Linear(5, 5),
                torch.nn.ReLU(),
                torch.nn.Linear(5, 1),
            )

            parametrize.register_parametrization(model[0], "weight", Orthogonal(5))
            return model

        model = get_model()

        prev_weight = model[0].weight
        prev_B = model[0].parametrizations.weight[0].B

        new_model = get_model()
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            # The state_dict holds only tensors, so the restricted (safe) unpickler suffices
            new_model.load_state_dict(torch.load(fname, weights_only=True))

        # Integrity tests
        self.assertTrue(parametrize.is_parametrized(new_model[0], "weight"))
        self.assertEqual(prev_weight, new_model[0].weight)
        self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B)

        # Trying to save the whole parametrized model raises
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            with TemporaryFileName() as fname:
                torch.save(model, fname)
431

432
    # FIXME: Rewrite this test using functions not depending on LAPACK
433
    #        and remove the `@skipIfNoLapack` (see #70995)
434
    @skipIfNoLapack
435
    @swap([True, False])
436
    def test_initialization_parametrization(self):
        r"""Test that it is possible to initialize a parametrization when it
        implements a `right_inverse` method.
        """

        class Skew(nn.Module):
            def forward(self, X):
                A = X.triu(1)
                return A - A.T

            def is_skew(self, A):
                return torch.allclose(A, -A.T, atol=1e-6)

            def right_inverse(self, X):
                if not self.is_skew(X):
                    raise ValueError("The matrix is not skew-symmetric.")
                return X.triu(1)

        # Implements a Cayley map where right_inverse is not quite the inverse of forward
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.B = Buffer(torch.eye(n))

            def forward(self, X):
                Id = torch.eye(X.size(0))
                return self.B @ torch.linalg.solve(Id + X, Id - X)

            def is_orthogonal(self, X):
                Id = torch.eye(X.size(0))
                return torch.allclose(X.T @ X, Id, atol=1e-4)

            def right_inverse(self, X):
                if not self.is_orthogonal(X):
                    raise ValueError("The input is not orthogonal.")
                # cayley(0) == Id, so B @ cayley(0) == B
                self.B = X
                return torch.zeros_like(X)

        N = 5
        model = nn.Linear(N, N)
        skew = Skew()
        # Make the weight skew-symmetric before registering the parametrization
        with torch.no_grad():
            model.weight.set_(skew(model.weight))
        # Register the skew-symmetric constraint. The result is now skew-symmetric
        parametrize.register_parametrization(model, "weight", skew)
        X = torch.rand(N, N)
        # X is not skew-symmetric, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        # Make X skew-symmetric
        X = X - X.T
        model.weight = X
        self.assertEqual(model.parametrizations.weight.original, X.triu(1))
        self.assertEqual(model.weight, X)

        # Having several parametrizations registered should work in the same way
        # Register now the Cayley map. The result is now orthogonal
        parametrize.register_parametrization(model, "weight", Orthogonal(N))
        X = torch.rand(N, N)
        # X is not orthogonal, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        init.orthogonal_(X)
        model.weight = X
        self.assertEqual(model.weight, X)
        self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))
504

505
    @swap([True, False])
506
    def test_errors_unparametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on an unparametrized tensor
        module = nn.Linear(3, 4)
        weight_init = module.weight.clone()

        class Identity(nn.Module):
            def forward(self, x):
                return x

        # Register a parametrization on a non-existing parameter throws
        with self.assertRaisesRegex(ValueError, "does not have a parameter"):
            parametrize.register_parametrization(module, "foo", Identity())
        self.assertFalse(parametrize.is_parametrized(module))

        # Removing parametrizations from an unparametrized tensor throws
        with self.assertRaisesRegex(ValueError, "does not have a parametrization"):
            parametrize.remove_parametrizations(module, "bias")
        self.assertFalse(parametrize.is_parametrized(module))

        # A correct parametrization with several outputs
        class Sum(nn.Module):
            def forward(self, x, y):
                return x + y

            def right_inverse(self, z):
                return z, torch.zeros_like(z)

        parametrize.register_parametrization(module, "weight", Sum())
        # Cannot remove a parametrization with several outputs with `leave_parametrized=False`
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            parametrize.remove_parametrizations(
                module, "weight", leave_parametrized=False
            )
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # A parametrization with an incorrect number of outputs
        class WrongNumberParams(nn.Module):
            def forward(self, x, y, z):
                return x + y + z

            def right_inverse(self, w):
                return w, torch.zeros_like(w)

        # Makes param(*param.right_inverse(X)) fail
        with self.assertRaisesRegex(TypeError, "positional argument"):
            parametrize.register_parametrization(module, "weight", WrongNumberParams())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor]
        class WrongRightInverse(Identity):
            def right_inverse(self, z):
                return None

        # right_inverse should return a Tensor or a Sequence[Tensor]
        with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"):
            parametrize.register_parametrization(module, "weight", WrongRightInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # If it's a sequence, it must be a sequence of tensors
        class WrongRightInverseSequence(nn.Module):
            def forward(self, x, y):
                return x

            def right_inverse(self, z):
                return None, z

        with self.assertRaisesRegex(ValueError, "of the sequence with type"):
            parametrize.register_parametrization(
                module, "weight", WrongRightInverseSequence()
            )
        self.assertFalse(parametrize.is_parametrized(module))

        # A one-tensor-to-one-tensor parametrization whose right_inverse changes the dtype
        class ChangeDtypeInverse(nn.Module):
            def forward(self, x):
                return x.float()

            def right_inverse(self, w):
                return w.bool()

        # For parametrizations that return one tensor, right_inverse may not change the dtype
        with self.assertRaisesRegex(
            ValueError, "outputs one tensor, it may not change the dtype"
        ):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # Doesn't return a tensor
        class NotTensor(nn.Module):
            def forward(self, x):
                return 2

        # Forward must return a tensor
        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", NotTensor())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization from one tensor to one tensor that changes the dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        # forward should not change the initial dtype
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertFalse(parametrize.is_parametrized(module))

        # Change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        # forward should not change the original shape
        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertFalse(parametrize.is_parametrized(module))

        # Many to one that changes dtype
        class ChangeDtypeMulti(nn.Module):
            def forward(self, x, y):
                return (x + y).bool()

            def right_inverse(self, w):
                return w, w + 1

        # forward should not change the original dtype even for parametrizations with many inputs
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeMulti())
        self.assertFalse(parametrize.is_parametrized(module))

        # Returning a sequence of size one is unusual, but it is correct
        class SequenceLen1(nn.Module):
            def forward(self, x):
                return x

            def right_inverse(self, w):
                return (w,)

        parametrize.register_parametrization(module, "weight", SequenceLen1())
        self.assertTrue(hasattr(module.parametrizations.weight, "original0"))
        self.assertFalse(hasattr(module.parametrizations.weight, "original1"))
        _ = module.weight  # Does not throw
        self.assertTrue(parametrize.is_parametrized(module))
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # None of the operations above should have altered the weight
        self.assertFalse(parametrize.is_parametrized(module))
        self.assertEqual(module.weight, weight_init)
654

655
    @swap([True, False])
    def test_errors_parametrized_tensor_parametrization(self):
        r"""Test errors when registering a parametrization on a parametrized tensor.

        Every invalid registration below must raise ``ValueError`` and leave the
        already-registered parametrization list (a single ``Identity``) untouched.
        """

        class Identity(nn.Module):
            def forward(self, x):
                return x

        module = nn.Linear(3, 4)
        parametrize.register_parametrization(module, "weight", Identity())

        def assert_unchanged():
            # A failed registration must not append to (or otherwise alter)
            # the parametrization list that was already in place.
            self.assertTrue(parametrize.is_parametrized(module))
            self.assertEqual(len(module.parametrizations.weight), 1)
            self.assertIsInstance(module.parametrizations.weight[0], Identity)

        # forward has to return a tensor
        class WrongReturn(nn.Module):
            def forward(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturn())
        assert_unchanged()

        # forward cannot change dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        assert_unchanged()

        # forward cannot change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        assert_unchanged()

        # The following checks are mostly due to bugs in the code of the parametrization

        # right_inverse has to return a tensor
        class WrongReturnInverse(Identity):
            def right_inverse(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturnInverse())
        assert_unchanged()

        # right_inverse cannot change dtype
        class ChangeDtypeInverse(Identity):
            def right_inverse(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "must have the same dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        assert_unchanged()

        # right_inverse cannot change shape
        class ChangeShapeInverse(Identity):
            def right_inverse(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "must have the same shape"):
            parametrize.register_parametrization(module, "weight", ChangeShapeInverse())
        assert_unchanged()

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_multiple_inputs_parametrization(self):
        r"""Test a parametrization whose ``forward`` takes several tensors.

        ``RankOne.right_inverse`` returns two tensors, which the parametrized
        module stores as ``original0`` and ``original1``. The test covers
        registration, removal (only allowed with ``leave_parametrized=True``),
        stacking a single-input parametrization on top, and backpropagation.
        """

        # A parametrization with several inputs (right_inverse has several outputs)
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # We project the given matrix onto the rank 1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is ordered in a decreasing way. Index the last dimension so
                # batched inputs also work: `S[0]` would select the first batch
                # element instead of the largest singular value.
                s0_sqrt = S[..., :1].sqrt()
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        # Simple parametrization
        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, w):
                return 0.5 * w

        model = nn.Linear(3, 3)

        def remove_and_check():
            # A multi-input parametrization cannot be removed with
            # leave_parametrized=False; it must be left parametrized.
            with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
                parametrize.remove_parametrizations(
                    model, "weight", leave_parametrized=False
                )
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=True
            )
            self.assertFalse(hasattr(model, "parametrizations"))
            self.assertEqual(model.__class__, nn.Linear)
            self.assertFalse(parametrize.is_parametrized(model))
            self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
            self.assertIn("weight", model._parameters)

        def train_two_steps():
            # Test backward through model.weight. Should not throw.
            sgd = torch.optim.SGD(model.parameters(), lr=0.1)
            for _ in range(2):
                sgd.zero_grad()
                loss = (model.weight.T @ model.bias).sum()
                loss.backward()
                sgd.step()

        # Test one parametrization
        parametrize.register_parametrization(model, "weight", RankOne())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(hasattr(model.parametrizations.weight, "original0"))
        self.assertIn("original0", model.parametrizations.weight._parameters)
        self.assertTrue(hasattr(model.parametrizations.weight, "original1"))
        self.assertIn("original1", model.parametrizations.weight._parameters)
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be rank 1
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        # Remove parametrization and check consistency
        remove_and_check()

        # Registering parametrizations with one input on top of one with multiple inputs should work
        init_weight = model.weight.clone()
        parametrize.register_parametrization(model, "weight", RankOne())
        # Projecting a rank 1 matrix onto the matrices of rank one does not change the matrix
        self.assertEqual(init_weight, model.weight)
        parametrize.register_parametrization(model, "weight", Double())
        # The matrix now is twice the initial matrix
        self.assertEqual(2.0 * init_weight, model.weight)
        # Multiplying by a scalar does not change the rank
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        # The model now has three parameters (original0, original1 and bias)
        self.assertEqual(len(list(model.parameters())), 3)

        train_two_steps()

        # Same drill as before, removing should work as expected
        remove_and_check()

        # The model now has two parameters (weight and bias)
        self.assertEqual(len(list(model.parameters())), 2)

        train_two_steps()

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization(self):
        r"""Test the caching system of a parametrization.

        Inside ``parametrize.cached()`` repeated accesses to a parametrized
        tensor must return the very same object; outside it, each access
        recomputes the parametrization.
        """

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                # Cayley map: for skew-symmetric X returns an orthogonal matrix.
                # Match X's dtype as well as its device so solve() does not hit
                # a dtype mismatch for non-default dtypes.
                Id = torch.eye(X.size(0), dtype=X.dtype, device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        # Without caching every access recomputes, yielding distinct tensors
        self.assertIsNot(model.weight, model.weight)

        # Test that the caching system works
        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertIs(X, Y)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization_with_transfer_parametrizations_and_params(self):
        r"""Test that transferring parametrizations doesn't cause issues with caching.

        After the transfer, each module must cache its own result: repeated
        accesses on one module are identical, but the two modules never share
        a cached tensor.
        """

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                # Cayley map: for skew-symmetric X returns an orthogonal matrix.
                # Match X's dtype as well as its device so solve() does not hit
                # a dtype mismatch for non-default dtypes.
                Id = torch.eye(X.size(0), dtype=X.dtype, device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        to_model = nn.Linear(5, 5)
        parametrize.transfer_parametrizations_and_params(model, to_model)

        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertIs(X, Y)

            A = to_model.weight
            B = to_model.weight
            self.assertIs(A, B)

            # test that the results are distinct objects for each module
            self.assertIsNot(A, X)

    @swap([True, False])
    def test_parametrization_same_training_mode(self):
        r"""Registering a parametrization aligns its training flag with the module's."""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        module = nn.Linear(4, 4)

        # Module in eval mode: the newly registered parametrization (created in
        # training mode) must end up in eval mode.
        module.eval()
        parametrize.register_parametrization(module, "weight", Identity())
        self.assertFalse(module.parametrizations.weight[0].training)

        # Module back in training mode: registering an eval-mode
        # parametrization must leave both entries of the list in training mode.
        module.train()
        eval_identity = Identity().eval()
        parametrize.register_parametrization(module, "weight", eval_identity)
        for idx in (0, 1):
            self.assertTrue(module.parametrizations.weight[idx].training)

    @swap([True, False])
    def test_type_before_parametrizations(self):
        r"""Test that type_before_parametrizations always retrieves original type"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        model = nn.Linear(5, 5)
        original_type = type(model)
        self.assertIs(parametrize.type_before_parametrizations(model), original_type)

        parametrize.register_parametrization(model, "weight", Identity())
        # Registration swaps the module's class, so type(model) no longer is the
        # original class; type_before_parametrizations must still recover it.
        self.assertIsNot(type(model), original_type)
        self.assertIs(parametrize.type_before_parametrizations(model), original_type)

    @swap([True, False])
    def test_deepcopy_after_parametrization(self):
        r"""A parametrized module can be deep-copied, with or without a custom
        ``__deepcopy__``, and the copy shares no tensors or attributes with it."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class ModelWithoutDeepcopy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.tensor([1.0, 1.0, 1.0, 1.0]), requires_grad=True
                )
                self.bias = nn.Parameter(
                    torch.tensor([0.0, 0.0, 0.0, 0.0]), requires_grad=True
                )
                self.attr = [1.0, 2.0, 3.0, 4.0]

        class ActualModel(ModelWithoutDeepcopy):
            # Emulate custom implementation of the deepcopying.
            def __deepcopy__(self, memo):
                result = self.__new__(self.__class__)
                memo[id(self)] = result
                result.__dict__ = deepcopy(self.__dict__, memo)
                return result

        def check_deepcopy(m1: nn.Module, m2: nn.Module):
            def stored_bias(m):
                # The tensor actually stored for "bias": the parametrization's
                # original when bias is parametrized, the plain attribute otherwise.
                if parametrize.is_parametrized(m, "bias"):
                    return m.parametrizations.bias.original
                return m.bias

            w1 = m1.parametrizations.weight.original
            w2 = m2.parametrizations.weight.original
            b1, b2 = stored_bias(m1), stored_bias(m2)

            # Same attribute layout, but each pair must be equal yet distinct objects.
            self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys())
            self.assertIsNot(m1, m2)
            for lhs, rhs in ((w1, w2), (b1, b2), (m1.attr, m2.attr)):
                self.assertEqual(lhs, rhs)
                self.assertIsNot(lhs, rhs)

        for model in (ModelWithoutDeepcopy(), ActualModel()):
            # Successively: one parametrized tensor, several parametrized tensors,
            # and a tensor with more than one parametrization stacked on it.
            for name in ("weight", "bias", "weight"):
                parametrize.register_parametrization(model, name, AddOne())
                check_deepcopy(model, deepcopy(model))

    @swap([True, False])
    def test_transfer_parametrizations_and_params(self):
        r"""Test that all parametrizations and their associated parameters are transferred.

        Also checks that the source module is left untouched, that the two
        modules' parametrizations are independent afterwards, and that a
        parametrized tensor missing from ``to_module`` is created there.
        """

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # checks that final and original value are correct and the to_model is parametrized
        self.assertTrue(parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del hold_weight

        # testing that changes to one set of parametrizations do not affect the other
        parametrize.remove_parametrizations(to_model, "weight")
        self.assertFalse(parametrize.is_parametrized(to_model, "weight"))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))

        # also test that parameters that don't exist in to_model get transferred
        model.test_param = Parameter(torch.randn(5, 5))

        self.assertFalse(hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", Double())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original,
            to_model.parametrizations.test_param.original,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_right_inverse(self):
        r"""Test transferring a parametrization that defines ``right_inverse``.

        The transferred ``original`` must match the source module's, and the
        source module's parametrized value must be unaffected.
        """

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that transfer occurs successfully
        self.assertTrue(parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that transfer doesn't affect the from_model weight
        self.assertEqual(hold_weight, model.weight)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_single_param(self):
        r"""Test that naming a tensor transfers only that tensor's parametrizations.

        Both ``weight`` and ``bias`` are parametrized on the source module, but
        only ``weight`` is requested, so ``bias`` must not be parametrized on
        the destination module.
        """

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5, bias=True)
        for name in ("weight", "bias"):
            parametrize.register_parametrization(model, name, AddOne())
            parametrize.register_parametrization(model, name, Double())
            parametrize.register_parametrization(model, name, MinusOne())

        to_model = torch.ao.nn.qat.Linear(
            5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model, "weight")

        # check that weight and only weight was transferred
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )
        self.assertNotIn("bias", to_model.parametrizations)
        self.assertFalse(parametrize.is_parametrized(to_model, "bias"))

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_transfer_parametrizations_and_params_many_to_one(self):
        r"""Test transferring a parametrization whose ``right_inverse`` returns
        several tensors (stored as ``original0`` and ``original1``)."""

        # A parametrization with several inputs (right_inverse has several outputs)
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # We project the given matrix onto the rank 1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is ordered in a decreasing way. Index the last dimension so
                # batched inputs also work: `S[0]` would select the first batch
                # element instead of the largest singular value.
                s0_sqrt = S[..., :1].sqrt()
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        model = nn.Linear(3, 3)
        parametrize.register_parametrization(model, "weight", RankOne())
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            3, 3, qconfig=torch.ao.quantization.get_default_qconfig()
        )

        parametrize.transfer_parametrizations_and_params(model, to_model)

        # checks that final and original values are correct and the to_model is parametrized
        self.assertTrue(parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        for original in ("original0", "original1"):
            self.assertEqual(
                getattr(model.parametrizations.weight, original),
                getattr(to_model.parametrizations.weight, original),
            )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)

        # also test that parametrized tensors that don't exist in to_model get transferred
        model.test_param = Parameter(torch.randn(3, 3))

        self.assertFalse(hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", RankOne())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        for original in ("original0", "original1"):
            self.assertEqual(
                getattr(model.parametrizations.test_param, original),
                getattr(to_model.parametrizations.test_param, original),
            )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
1202
    def test_new_spectral_norm(self):
1203
        with set_default_dtype(torch.double):
1204
            input = torch.randn(3, 5)
1205
            m = nn.Linear(5, 7)
1206
            m = torch.nn.utils.parametrizations.spectral_norm(m)
1207
            spectral_norm_m = m.parametrizations.weight[0]
1208

1209
            self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)]))
1210

1211
            # .parametrizations.weight.original should be trainable
1212
            self.assertTrue(hasattr(m.parametrizations.weight, "original"))
1213
            self.assertTrue("original" in m.parametrizations.weight._parameters)
1214

1215
            # u should be just a reused buffer
1216
            self.assertTrue(hasattr(spectral_norm_m, "_u"))
1217
            self.assertTrue("_u" in spectral_norm_m._buffers)
1218
            self.assertTrue("_v" in spectral_norm_m._buffers)
1219

1220
            # weight should be a plain attribute, not counted as a buffer or a param
1221
            self.assertIsNotNone(m.weight)
1222
            self.assertFalse("weight" in m._buffers)
1223
            self.assertFalse("weight" in m._parameters)
1224

1225
            # it should also be sharing storage as `weight_orig`
1226
            # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage())
1227
            self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size())
1228
            self.assertEqual(
1229
                m.parametrizations.weight.original.stride(), m.weight.stride()
1230
            )
1231

1232
            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1233

1234
            # spectral_norm is the only parametrization
1235
            self.assertFalse(hasattr(m, "parametrizations"))
1236
            self.assertTrue("weight" in m._parameters)
1237

1238
            # We can register spectral_norm multiple times on the same parameter
1239
            # and on multiple parameters in the same module
1240
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
1241
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
1242
            m = torch.nn.utils.parametrizations.spectral_norm(m, "bias")
1243

1244
            # If we remove the parametrization on bias, weight is still parametrized
1245
            # Removing a parametrization runs forward in eval mode if leave_parametrized=True
1246
            m = torch.nn.utils.parametrize.remove_parametrizations(m, "bias")
1247
            self.assertTrue("bias" in m._parameters)
1248
            self.assertTrue(hasattr(m, "parametrizations"))
1249
            self.assertFalse("weight" in m._parameters)
1250

1251
            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1252
            # Neither weight and bias are parametrized
1253
            self.assertFalse(hasattr(m, "parametrizations"))
1254
            self.assertTrue("weight" in m._parameters)
1255
            self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m))
1256

1257
            # test correctness in training/eval modes and cpu/multi-gpu settings
1258
            for apply_dp in (True, False):
1259
                if apply_dp:
1260
                    if not TEST_MULTIGPU:
1261
                        continue
1262
                    device = torch.device("cuda:0")
1263

1264
                    def maybe_wrap(m):
1265
                        return torch.nn.DataParallel(m, [0, 1])
1266

1267
                else:
1268
                    device = torch.device("cpu")
1269

1270
                    def maybe_wrap(m):
1271
                        return m
1272

1273
                for requires_grad in (True, False):
1274

1275
                    def get_modules():
1276
                        m = nn.Linear(3, 4).to(device)
1277
                        m.weight.requires_grad_(requires_grad)
1278
                        m = torch.nn.utils.parametrizations.spectral_norm(m)
1279
                        wrapped_m = maybe_wrap(m)
1280
                        spectral_norm_m = m.parametrizations.weight[0]
1281
                        return m, wrapped_m, spectral_norm_m
1282

1283
                    input = torch.randn(2, 3, device=device)
1284

1285
                    m, wrapped_m, spectral_norm_m = get_modules()
1286

1287
                    self.assertTrue(hasattr(spectral_norm_m, "_u"))
1288
                    u0 = spectral_norm_m._u.clone()
1289
                    v0 = spectral_norm_m._v.clone()
1290

1291
                    # TEST TRAINING BEHAVIOR
1292

1293
                    # We perform GD first to modify the initial matrix
1294
                    opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1)
1295

1296
                    opt.zero_grad()
1297
                    wrapped_m(input).sum().backward()
1298
                    opt.step()
1299

1300
                    out = wrapped_m(input)
1301
                    if requires_grad:
1302
                        # run forward again and assert that u and v are updated
1303
                        self.assertNotEqual(u0, spectral_norm_m._u)
1304
                        self.assertNotEqual(v0, spectral_norm_m._v)
1305

1306
                    # assert that backprop reaches original weight
1307
                    # can't use gradcheck because the function changes as we
1308
                    # activate through it in training mode
1309
                    if requires_grad:
1310
                        torch.autograd.grad(
1311
                            out.sum(), m.parametrizations.weight.original
1312
                        )
1313

1314
                    # test backward works with multiple forwards
1315
                    # it uses training mode so we need to reset `u` and `v` vectors
1316
                    # to same value at beginning for finite difference test to pass
1317
                    saved_u = spectral_norm_m._u.clone()
1318
                    saved_v = spectral_norm_m._v.clone()
1319

1320
                    def fn(input):
1321
                        spectral_norm_m._u.data.copy_(saved_u)
1322
                        spectral_norm_m._v.data.copy_(saved_v)
1323
                        out0 = wrapped_m(input)
1324
                        out1 = wrapped_m(input)
1325
                        return out0 + out1
1326

1327
                    # Make sure we can compute gradients wrt to all the parameters in the case
1328
                    # of double forward
1329
                    fn(input.clone().requires_grad_()).sum().backward()
1330
                    gradcheck(
1331
                        fn, (input.clone().requires_grad_(),), check_batched_grad=False
1332
                    )
1333

1334
                    # test removing
1335
                    # spectral norm module needs to be in eval mode if we'd like to
1336
                    # avoid doing another power iteration
1337
                    m, wrapped_m, _ = get_modules()
1338
                    pre_remove_out = wrapped_m(input)
1339
                    if get_swap_module_params_on_conversion():
1340
                        # When using the swap_tensors path, this is needed so that the autograd
1341
                        # graph is not alive anymore.
1342
                        pre_remove_out_ref = pre_remove_out.detach()
1343
                        del pre_remove_out
1344
                    else:
1345
                        pre_remove_out_ref = pre_remove_out
1346
                    m.eval()
1347
                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1348
                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)
1349

1350
                    torch.nn.utils.parametrizations.spectral_norm(m)
1351
                    for _ in range(3):
1352
                        pre_remove_out = wrapped_m(input)
1353
                    if get_swap_module_params_on_conversion():
1354
                        # When using the swap_tensors path, this is needed so that the autograd
1355
                        # graph is not alive anymore.
1356
                        pre_remove_out_ref = pre_remove_out.detach()
1357
                        del pre_remove_out
1358
                    else:
1359
                        pre_remove_out_ref = pre_remove_out
1360
                    m.eval()
1361
                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
1362
                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)
1363

1364
                    # TEST EVAL BEHAVIOR
1365
                    m, wrapped_m, spectral_norm_m = get_modules()
1366
                    wrapped_m(input)
1367
                    last_train_out = wrapped_m(input)
1368
                    last_train_u = spectral_norm_m._u.clone()
1369
                    last_train_v = spectral_norm_m._v.clone()
1370
                    wrapped_m.zero_grad()
1371
                    wrapped_m.eval()
1372

1373
                    eval_out0 = wrapped_m(input)
1374
                    # assert eval gives same result as last training iteration
1375
                    self.assertEqual(eval_out0, last_train_out)
1376
                    # assert doing more iteartion in eval don't change things
1377
                    self.assertEqual(eval_out0, wrapped_m(input))
1378
                    self.assertEqual(last_train_u, spectral_norm_m._u)
1379
                    self.assertEqual(last_train_v, spectral_norm_m._v)
1380

1381
                    # FIXME: the code below is flaky when executed with DataParallel
1382
                    # see https://github.com/pytorch/pytorch/issues/13818
1383
                    if apply_dp:
1384
                        continue
1385

1386
                    # test backward works with multiple forwards in mixed training
1387
                    # and eval modes
1388
                    # it uses training mode so we need to reset `u` and `v` vectors
1389
                    # to same value at beginning for finite difference test to pass
1390
                    saved_u = spectral_norm_m._u.clone()
1391
                    saved_v = spectral_norm_m._v.clone()
1392

1393
                    def fn(input):
1394
                        spectral_norm_m._u.data.copy_(saved_u)
1395
                        spectral_norm_m._v.data.copy_(saved_v)
1396
                        wrapped_m.train()
1397
                        out0 = wrapped_m(input)
1398
                        wrapped_m.eval()
1399
                        out1 = wrapped_m(input)
1400
                        wrapped_m.train()
1401
                        out2 = wrapped_m(input)
1402
                        wrapped_m.eval()
1403
                        out3 = wrapped_m(input)
1404
                        return out0 + out1 + out2 + out3
1405

1406
                    gradcheck(fn, (input.clone().requires_grad_(),))
1407

1408
                    # assert that backprop reaches weight_orig in eval
1409
                    if requires_grad:
1410

1411
                        def fn(weight):
1412
                            return wrapped_m(input)
1413

1414
                        gradcheck(fn, (m.parametrizations.weight.original,))
1415

1416
    def test_register_parametrization_no_grad(self):
1417
        r"""Test that it is possible to register a parametrization without gradient"""
1418

1419
        class SplitAndCat(nn.Module):
1420
            def right_inverse(self, x):
1421
                # split the tensor in two halfs
1422
                return torch.split(x, x.shape[1] // 2)
1423

1424
            def forward(self, x0, x1):
1425
                return torch.cat([x0, x1])
1426

1427
        model = nn.Linear(8, 8)
1428

1429
        model.weight.requires_grad = False
1430
        parametrize.register_parametrization(model, "weight", SplitAndCat())
1431
        # making sure the parameterized and decomposed Tensors both have requires_grad == False
1432
        self.assertFalse(model.weight.requires_grad)
1433
        self.assertFalse(model.parametrizations.weight.original0.requires_grad)
1434
        self.assertFalse(model.parametrizations.weight.original1.requires_grad)
1435

1436
    @swap([True, False])
    def test_new_spectral_norm_load_state_dict(self):
        r"""Spectral-norm modules round-trip through ``state_dict``.

        Covers a freshly wrapped module and one that already ran a few forwards.
        The saved state dict must have the expected keys. It must load non-strictly
        while entries are removed or added one by one. Loaded normally into a
        re-wrapped module, it must reproduce the same eval/train outputs.
        """
        original_key = "parametrizations.weight.original"
        u_key = "parametrizations.weight.0._u"
        v_key = "parametrizations.weight.0._v"

        def run_train_eval_sequence(module, x):
            # eval -> train -> train -> eval, recording every output
            with torch.no_grad():
                module.eval()
                first_eval = module(x)
                module.train()
                first_train = module(x)
                second_train = module(x)
                module.eval()
                last_eval = module(x)
            return first_eval, first_train, second_train, last_eval

        for activate_times in (0, 3):
            inp = torch.randn(2, 3)
            snm = torch.nn.utils.parametrizations.spectral_norm(nn.Linear(3, 5))
            snm.train()

            for _ in range(activate_times):
                snm(inp)

            state_dict = deepcopy(snm.state_dict())
            self.assertEqual(
                {original_key, "bias", v_key, u_key},
                set(state_dict.keys()),
            )

            # an unexpected key is rejected in strict mode and ignored otherwise
            non_strict_state_dict = deepcopy(state_dict)
            non_strict_state_dict["nonsense"] = "nonsense"
            with self.assertRaisesRegex(
                RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'
            ):
                snm.load_state_dict(non_strict_state_dict, strict=True)
            snm.load_state_dict(non_strict_state_dict, strict=False)

            # drop the parametrization tensors one at a time
            for missing_key in (original_key, u_key, v_key):
                del non_strict_state_dict[missing_key]
                snm.load_state_dict(non_strict_state_dict, strict=False)

            # store W itself, as a buffer-like entry
            non_strict_state_dict["weight"] = snm.weight.detach().clone()
            snm.load_state_dict(non_strict_state_dict, strict=False)

            # remove the metadata info of the parametrization
            del non_strict_state_dict._metadata["parametrizations.weight.0"]
            snm.load_state_dict(non_strict_state_dict, strict=False)

            # remove the W entry, then the bias
            for missing_key in ("weight", "bias"):
                del non_strict_state_dict[missing_key]
                snm.load_state_dict(non_strict_state_dict, strict=False)

            # re-wrapping must not matter: record reference outputs after loading
            snm = torch.nn.utils.parametrizations.spectral_norm(
                torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
            )
            snm.load_state_dict(state_dict)
            reference_outputs = run_train_eval_sequence(snm, inp)

            # re-wrap once more; normal loading must reproduce the same outputs
            snm = torch.nn.utils.parametrizations.spectral_norm(
                torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
            )
            snm.load_state_dict(state_dict)
            reloaded_outputs = run_train_eval_sequence(snm, inp)
            for expected, actual in zip(reference_outputs, reloaded_outputs):
                self.assertEqual(expected, actual)

1516
    @swap([True, False])
    def test_new_spectral_norm_dim(self):
        r"""Spectral norm on a ConvTranspose2d weight uses the weight's dim 1.

        The ``_u`` vector must have the size of dim 1 of the (in, out, kH, kW)
        ConvTranspose2d weight, and a forward pass must not hit a shape mismatch.
        """
        inp = torch.randn(2, 3, 10, 12)
        m = nn.ConvTranspose2d(3, 4, (5, 6))
        m = torch.nn.utils.parametrizations.spectral_norm(m)
        snm = m.parametrizations.weight[0]
        # this should not run into incompatible shapes; the output itself is unused
        m(inp)
        # check that u refers to the same dimension
        self.assertEqual(
            snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape
        )

1529
    @swap([True, False])
    def test_new_spectral_norm_forward(self):
        r"""Compare the spectral-norm forward with a hand-written reference.

        The reference runs one power-iteration step from the module's stored ``_v``.
        It then divides the original weight by the resulting sigma out-of-place, so
        the module's parameter is left untouched. Rescaling the parameter in place
        would make the check pass even if the parametrization skipped the division
        by sigma, because normalizing an already-normalized weight is a no-op.
        """
        input = torch.randn(3, 5)
        m = nn.Linear(5, 7)
        m = torch.nn.utils.parametrizations.spectral_norm(m)
        snm = m.parametrizations.weight[0]
        # naive forward
        weight = m.parametrizations.weight.original
        bias, v = m.bias, snm._v
        weight_mat = weight.view(weight.size(0), -1)
        u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=1e-12)
        v = F.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=1e-12)
        sigma = torch.dot(u, torch.matmul(weight_mat, v))
        out_hat = F.linear(input, weight / sigma, bias)
        expect_out = m(input)
        self.assertEqual(expect_out, out_hat)

1548
    @swap([True, False])
    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def test_new_spectral_norm_value(self):
        # The spectral norm (top singular value) of a diagonal matrix is the largest
        # absolute value on its diagonal, which gives a closed-form expected weight.
        for dtype in (torch.float, torch.cfloat):
            linear = nn.Linear(2, 2, dtype=dtype)
            with torch.no_grad():
                diag_values = torch.diagonal(linear.weight)
                # replace the weight by a purely diagonal matrix
                linear.weight = nn.Parameter(torch.diag(diag_values))
                torch.nn.utils.parametrizations.spectral_norm(linear)
                # the normalized weight is the diagonal divided by its largest magnitude
                top_singular_value = diag_values.abs().max()
                expected = torch.diag(diag_values / top_singular_value)
                self.assertEqual(linear.weight.data, expected)

1564
    @skipIfNoLapack
    @swap([True, False])
    def test_orthogonal_parametrization(self):
        r"""Test ``orthogonal`` on ``nn.Linear`` weights.

        Orthogonal implements 6 algorithms (3 parametrizations times 2 options of
        ``use_trivialization``). Each is run on square, tall and wide weights in
        float32 and complex64, and the test checks that:

        - the parametrized weight is orthogonal;
        - assigning to it yields the Q factor of the assigned matrix, or raises
          when ``right_inverse`` is not implemented;
        - two SGD steps keep it orthogonal.
        """

        def assert_is_orthogonal(X):
            # X^H X == I for tall/square X (wide X is transposed first), up to a
            # tolerance that grows with the matrix size.
            n, k = X.size(-2), X.size(-1)
            if n < k:
                X = X.mT
                n, k = k, n
            Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(
                *(X.size()[:-2]), k, k
            )
            eps = 10 * n * torch.finfo(X.dtype).eps
            torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.0)

        def assert_weight_allclose_Q(weight, W):
            # Test that weight is equal to the Q part of the QR decomposition of W
            # (or of its transpose if the matrix is wide). Each column of Q is scaled
            # by the sign (phase, if complex) of the matching diagonal entry of R,
            # i.e. Q is taken from the QR factorization whose R has a positive diagonal.
            wide_matrix = W.size(-2) < W.size(-1)
            if wide_matrix:
                W = W.mT
            Q, R = torch.linalg.qr(W)
            Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
            if wide_matrix:
                Q = Q.mT
            torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.0)

        # NOTE(review): only nn.Linear weights are covered here. Conv2d (batched 4-D
        # weights, e.g. nn.Conv2d(2, 3, shape) fed an input of shape
        # (2, 2, shape[0] + 2, shape[1] + 1)) is not exercised by this test.
        for shape, dtype in product(
            ((4, 4), (5, 3), (3, 5)),  # square/ tall / wide
            (torch.float32, torch.complex64),
        ):
            input = torch.randn(3, shape[0], dtype=dtype)

            for parametrization, use_trivialization in product(
                ("matrix_exp", "cayley", "householder"), (False, True)
            ):
                # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False
                # See Note [right_inverse expm cayley]
                can_initialize = use_trivialization or parametrization == "householder"

                # We generate them every time to always start with fresh weights
                m = nn.Linear(*shape, dtype=dtype)

                # When using the swap_tensors path, this is needed so that the autograd
                # graph is not alive anymore.
                if get_swap_module_params_on_conversion():
                    w_init = m.weight.clone().detach()
                else:
                    w_init = m.weight.clone()

                # We do not support householder for complex inputs
                # See Note [Householder complex]
                if parametrization == "householder" and m.weight.is_complex():
                    msg = "householder parametrization does not support complex tensors"
                    with self.assertRaisesRegex(ValueError, msg):
                        torch.nn.utils.parametrizations.orthogonal(
                            m,
                            "weight",
                            parametrization,
                            use_trivialization=use_trivialization,
                        )
                    continue

                wide_matrix = w_init.size(-2) < w_init.size(-1)
                torch.nn.utils.parametrizations.orthogonal(
                    m, "weight", parametrization, use_trivialization=use_trivialization
                )
                # Forward works as expected
                self.assertEqual(w_init.shape, m.weight.shape)
                assert_is_orthogonal(m.weight)
                if can_initialize:
                    assert_weight_allclose_Q(m.weight, w_init)

                # Initializing with a given orthogonal matrix works
                X = torch.randn_like(m.weight)
                if wide_matrix:
                    X = X.mT
                w_new = torch.linalg.qr(X).Q
                if wide_matrix:
                    w_new = w_new.mT
                if can_initialize:
                    m.weight = w_new
                    torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.0)
                else:
                    msg = (
                        "assign to the matrix exponential or the Cayley parametrization"
                    )
                    with self.assertRaisesRegex(NotImplementedError, msg):
                        m.weight = w_new

                # Initializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix
                w_new = torch.randn_like(m.weight)
                if can_initialize:
                    m.weight = w_new
                    assert_weight_allclose_Q(m.weight, w_new)
                else:
                    msg = (
                        "assign to the matrix exponential or the Cayley parametrization"
                    )
                    with self.assertRaisesRegex(NotImplementedError, msg):
                        m.weight = w_new

                opt = torch.optim.SGD(m.parameters(), lr=0.1)
                for _ in range(2):
                    opt.zero_grad()
                    m(input).norm().backward()
                    grad = m.parametrizations.weight.original.grad
                    self.assertIsNotNone(grad)
                    # The strictly upper triangle (tall/square) or strictly lower
                    # triangle (wide) of the original must receive zero gradient
                    if grad.size(-2) >= grad.size(-1):
                        zeros_grad = grad.triu(1)
                    else:
                        zeros_grad = grad.tril(-1)
                    self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad))
                    # The gradient in the diagonal can only be imaginary because a skew-Hermitian
                    # matrix has imaginary diagonal
                    diag_grad = grad.diagonal(dim1=-2, dim2=-1)
                    if grad.is_complex():
                        diag_grad = diag_grad.real
                    self.assertEqual(diag_grad, torch.zeros_like(diag_grad))
                    opt.step()
                    assert_is_orthogonal(m.weight)

1699
    @skipIfNoLapack
    @swap([True, False])
    def test_orthogonal_errors(self):
        # Invalid arguments to ``orthogonal`` and a wrongly-shaped assignment to an
        # orthogonally-parametrized weight must raise ValueError.
        linear = nn.Linear(3, 4)
        orthogonal = torch.nn.utils.parametrizations.orthogonal

        # unknown orthogonal map name
        with self.assertRaisesRegex(ValueError, "has to be one of"):
            orthogonal(linear, "weight", "foo")

        # the bias is not a matrix
        with self.assertRaisesRegex(ValueError, "Expected a matrix"):
            orthogonal(linear, "bias")

        orthogonal(linear, "weight")
        # the assigned tensor does not match the weight's shape
        with self.assertRaisesRegex(ValueError, "matrices of shape"):
            linear.weight = torch.randn(5, 5)
        torch.nn.utils.parametrize.remove_parametrizations(linear, "weight")

1714
    @swap([True, False])
    def test_weight_norm_state_dict_compat(self):
        # A state dict from the legacy ``torch.nn.utils.weight_norm`` must load into a
        # module using the parametrization-based ``weight_norm`` and give equal outputs.
        legacy = torch.nn.utils.weight_norm(nn.Linear(4, 5))
        legacy_state = legacy.state_dict()

        parametrized = torch.nn.utils.parametrizations.weight_norm(nn.Linear(4, 5))
        parametrized.load_state_dict(legacy_state)

        x = torch.randn(3, 4)
        self.assertEqual(legacy(x), parametrized(x))

1727
    @swap([True, False])
    def test_weight_norm_pickle(self):
        # Pickling a weight-norm-parametrized module must fail with a RuntimeError
        # whose message mentions ``state_dict``.
        module = torch.nn.utils.parametrizations.weight_norm(nn.Linear(4, 5))
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            pickle.dumps(module)

1734
    @swap([True, False])
    def test_weight_norm_deepcopy(self):
        # A deep copy of a weight-norm-parametrized module computes the same outputs.
        module = torch.nn.utils.parametrizations.weight_norm(nn.Linear(4, 5))
        module_copy = deepcopy(module)
        x = torch.randn(3, 4)
        self.assertEqual(module(x), module_copy(x))

1742
    @swap([True])
    def test_wrapper_subclass_parametrization(self):
        r"""Test parametrizations that create, unwrap, or store wrapper subclasses.

        The same parametrization is registered on a parameter (``weight``) and a
        buffer (``buf``) of an ``nn.Linear``, and the types of the stored originals
        and of the parametrized attributes are checked. The parametrizations are
        then removed, optionally leaving the parametrized value in place. Finally,
        the resulting types, object identities and, where applicable, values are
        checked.
        """

        class Subclassify(nn.Module):
            # wraps a plain tensor into a TwoTensor
            def forward(self, X):
                return TwoTensor(X, X)

        class UnSubclassify(nn.Module):
            # unwraps a TwoTensor by returning its ``a`` component
            def forward(self, X):
                return X.a

        class IdentityWithRightInverse(nn.Module):
            # identity forward; right_inverse makes the stored original a TwoTensor
            def forward(self, X):
                return X

            def right_inverse(self, X):
                return TwoTensor(X, X)

        def _check_parametrization(
            parametrization,
            type_before_registration,
            type_after_registration,
            leave_parametrized=False,
            type_after_right_inverse=None,
        ):
            model = nn.Linear(2, 2)
            model.buf = torch.nn.Buffer(torch.randn(2, 2))
            starts_as_subclass = (
                type_before_registration is TwoTensor
                and type_after_registration is Tensor
            )
            if starts_as_subclass:
                # the parametrization unwraps, so begin with TwoTensor weight/buffer
                model._apply(lambda t: TwoTensor(t, t))
            initial_weight = model.weight.clone().detach()
            initial_weight_id = id(model.weight)
            initial_buf = model.buf.clone().detach()
            initial_buf_id = id(model.buf)

            # expected types of the stored ``original`` tensors
            if type_after_right_inverse is None:
                type_original_weight = type_before_registration
            else:
                type_original_weight = type_after_right_inverse
            if type_original_weight is nn.Parameter:
                type_original_buf = Tensor
            else:
                type_original_buf = type_original_weight

            # expected types once the parametrizations are removed
            if leave_parametrized:
                type_after_removal_buf = type_after_registration
                if type_after_registration is Tensor:
                    type_after_removal_weight = nn.Parameter
                else:
                    type_after_removal_weight = type_after_registration
            else:
                type_after_removal_buf = type_original_buf
                type_after_removal_weight = type_original_weight

            for tensor_name in ("weight", "buf"):
                parametrize.register_parametrization(
                    model, tensor_name, parametrization()
                )
            self.assertTrue(hasattr(model, "parametrizations"))
            self.assertTrue(parametrize.is_parametrized(model))
            self.assertFalse(parametrize.is_parametrized(model, "bias"))
            # checks for weight
            self.assertTrue(parametrize.is_parametrized(model, "weight"))
            self.assertIsInstance(
                model.parametrizations.weight.original, nn.Parameter
            )
            self.assertIs(
                type(model.parametrizations.weight.original), type_original_weight
            )
            self.assertNotIn("weight", model._parameters)
            self.assertIs(type(model.weight), type_after_registration)
            # checks for buf
            self.assertTrue(parametrize.is_parametrized(model, "buf"))
            self.assertNotIsInstance(
                model.parametrizations.buf.original, nn.Parameter
            )
            self.assertIs(type(model.parametrizations.buf.original), type_original_buf)
            self.assertIs(type(model.buf), type_after_registration)

            for tensor_name in ("weight", "buf"):
                parametrize.remove_parametrizations(
                    model, tensor_name, leave_parametrized=leave_parametrized
                )
            self.assertFalse(hasattr(model, "parametrizations"))
            self.assertEqual(model.__class__, nn.Linear)
            # checks for weight
            self.assertIs(type(model.weight), type_after_removal_weight)
            self.assertIsInstance(model.weight, nn.Parameter)
            self.assertEqual(id(model.weight), initial_weight_id)
            # checks for buf
            self.assertIs(type(model.buf), type_after_removal_buf)
            self.assertNotIsInstance(model.buf, nn.Parameter)
            self.assertEqual(id(model.buf), initial_buf_id)
            if not leave_parametrized and type_after_right_inverse is None:
                self.assertEqual(model.weight, initial_weight)
                self.assertEqual(model.buf, initial_buf)

        for leave_parametrized in (False, True):
            _check_parametrization(
                Subclassify,
                nn.Parameter,
                TwoTensor,
                leave_parametrized=leave_parametrized,
            )
            _check_parametrization(
                UnSubclassify,
                TwoTensor,
                Tensor,
                leave_parametrized=leave_parametrized,
            )
            _check_parametrization(
                IdentityWithRightInverse,
                nn.Parameter,
                TwoTensor,
                leave_parametrized=leave_parametrized,
                type_after_right_inverse=TwoTensor,
            )

1863

1864
class TestNNParametrizationDevice(NNTestCase):
    @swap([True, False])
    def test_weight_norm_parametrization(self, device):
        r"""Weight-norm parametrization preserves an ``nn.Linear``'s output.

        Runs for float32 and bfloat16 with the default ``dim``, ``dim=1`` and
        ``dim=None``. It checks the shapes of ``original0`` and ``original1`` and
        that removing the parametrization keeps the output unchanged.
        """

        def check_originals(module, original0_size):
            originals = module.parametrizations.weight
            self.assertEqual(originals.original1.size(), module.weight.size())
            self.assertEqual(originals.original0.size(), original0_size)

        for dtype in (torch.float, torch.bfloat16):
            factory_kwargs = {"dtype": dtype, "device": device}
            x = torch.randn(3, 4, **factory_kwargs)
            linear = nn.Linear(4, 5, **factory_kwargs)
            expected = linear(x)

            # add weight normalization with the default dim
            linear = torch.nn.utils.parametrizations.weight_norm(linear)
            check_originals(linear, (5, 1))
            self.assertEqual(linear(x), expected)

            # removing the parametrization keeps the output
            torch.nn.utils.parametrize.remove_parametrizations(linear, "weight")
            self.assertFalse(hasattr(linear, "parametrizations"))
            self.assertEqual(linear(x), expected)

            # normalize along dim=1
            linear = torch.nn.utils.parametrizations.weight_norm(linear, dim=1)
            check_originals(linear, (1, 4))
            self.assertEqual(linear(x), expected)

            # dim=None, on a freshly initialized layer
            linear = nn.Linear(4, 5, **factory_kwargs)
            expected = linear(x)
            linear = torch.nn.utils.parametrizations.weight_norm(linear, dim=None)
            self.assertEqual(linear(x), expected)

1899

1900
# Device types for which TestNNParametrizationDevice is instantiated.
only_for = ("cpu", "cuda")
instantiate_device_type_tests(
    TestNNParametrizationDevice, globals(), only_for=only_for
)
# Generate the parametrized variants of TestNNParametrization's tests.
instantiate_parametrized_tests(TestNNParametrization)

if __name__ == "__main__":
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.