# Owner(s): ["module: nn"]
import pickle
from copy import deepcopy
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.parametrize as parametrize
from torch import Tensor
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn import Buffer, Parameter
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    gradcheck,
    instantiate_parametrized_tests,
    set_default_dtype,
    skipIfNoLapack,
    skipIfTorchDynamo,
    swap,
    TemporaryFileName,
)
from torch.testing._internal.two_tensor import TwoTensor


class TestNNParametrization(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    # torch/nn/utils/parametrize
    @skipIfNoLapack
    @swap([True, False])
    def test_register_and_remove_parametrization(self):
        r"""Test that it is possible to add a few parametrizations
        on a parameter or a buffer and that removing them restores the initial state
        It also tests that backpropagating through them works as expected
        """

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                # Cayley map
                # If X is skew-symmetric it returns an orthogonal matrix
                Id = torch.eye(X.size(0), device=X.device)
                # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous
                # and autograd raises a performance warning.
                # This happens when we remove the parametrization with leave_parametrized=True,
                # which does a set_ with a non-contiguous tensor while the gradient is contiguous
                return torch.linalg.solve(Id + X, Id - X).contiguous()

        class Resize(nn.Module):
            def forward(self, X):
                return X[[0]]

        class NoResize(nn.Module):
            def forward(self, X):
                return X

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)
        initial_weight_id = id(model.weight)
        initial_bias_id = id(model.bias)
        initial_model = deepcopy(model)

        # Registering a parametrization that changes the shape is rejected by default
        with self.assertRaisesRegex(
            ValueError,
            "Registering a parametrization may not change the shape of the tensor",
        ):
            parametrize.register_parametrization(
                model, "weight", Resize()
            )  # default unsafe = False
        model(torch.ones(8, 8))
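        # Note: register_parametrization removes "weight" from model._parameters and
        # recomputes model.weight on access; the raw tensor is kept at
        # model.parametrizations.weight.original.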

        # One parametrization with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Two parametrizations with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test unsafe flag doesn't change expected behavior
        parametrize.register_parametrization(model, "weight", Skew(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test one parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test two parametrizations at the same time and removing them
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        # Result should be orthogonal
        X = model.weight
        Id = torch.eye(X.size(0), device=X.device)
        self.assertEqual(X.T @ X, Id)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del X

        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertIn("weight", model.parametrizations)
        self.assertNotIn("weight", model._parameters)

        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

        # Add everything
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())

        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertEqual(
            len(list(model.parameters())), 2
        )  # Nothing weird has happened
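        # Note: parametrizations compose in registration order, so model.weight is
        # Orthogonal()(Skew()(original)) and the bias has FirstZero applied before
        # LastZero.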

        # Should not throw
        sgd = torch.optim.SGD(model.parameters(), lr=0.01)

        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove first parametrization.
        # Check that the model is still parametrized and so is the second parameter
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertTrue(parametrize.is_parametrized(model))  # Still parametrized
        self.assertFalse(
            parametrize.is_parametrized(model, "weight")
        )  # Parametrization removed
        self.assertTrue(
            parametrize.is_parametrized(model, "bias")
        )  # Still parametrized
        self.assertEqual(model.bias[0].item(), 0.0)  # Still parametrized
        self.assertEqual(model.bias[-1].item(), 0.0)  # Still parametrized
        self.assertNotEqual(model.weight, initial_model.weight)  # Has been updated
        self.assertEqual(id(model.weight), initial_weight_id)  # Keeps the same id
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened

        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove the second parametrization.
        # Check that the module is not parametrized
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=False)
        self.assertFalse(parametrize.is_parametrized(model))  # Not parametrized
        self.assertNotEqual(model.bias, initial_model.bias)  # Has been updated
        self.assertNotEqual(model.bias[0].item(), 0.0)  # Not parametrized
        self.assertNotEqual(model.bias[-1].item(), 0.0)  # Not parametrized
        self.assertEqual(id(model.bias), initial_bias_id)  # Keeps the same id
        self.assertFalse(
            hasattr(model, "parametrizations")
        )  # The module is not parametrized
        self.assertEqual(model.__class__, nn.Linear)  # Restores the previous class
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened

        # Should not throw. Things are updated
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del weight_copy, bias_copy

        # Test leave_parametrized=True
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        parametrize.remove_parametrizations(
            model, "weight", leave_parametrized=True
        )
        # We didn't change the dtype nor had multiple inputs, so the id should be the same
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(id(model.bias), initial_bias_id)

        # Should not throw. Things are updated
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del weight_copy, bias_copy

    @swap([True, False])
    def test_register_and_remove_nested_parametrization(self):
        r"""Test that it is possible to nest the parametrizations
        meaning that the original param is parametrized again
        """

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        model = nn.Linear(8, 8)
        # Add top level parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A

        # Add nested parametrization
        param_mod = model.parametrizations.weight
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertFalse(parametrize.is_parametrized(param_mod))
        self.assertFalse(parametrize.is_parametrized(param_mod, "original"))

        parametrize.register_parametrization(param_mod, "original", Skew())
        self.assertTrue(hasattr(param_mod, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(param_mod))
        self.assertTrue(parametrize.is_parametrized(param_mod, "original"))
        self.assertNotIn("original", param_mod._parameters)
        # Result should be skew-symmetric
        A = param_mod.original
        self.assertEqual(A, -A.T)

        # Remove nested param and check consistency
        parametrize.remove_parametrizations(
            param_mod, "original", leave_parametrized=False
        )
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertEqual(param_mod.__class__, parametrize.ParametrizationList)

        # Remove top level and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

    @swap([True, False])
    def test_register_and_remove_buffer_parametrization(self):
        r"""Test that it is possible to add and remove parametrizations on buffers"""

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)

        # Instantiate parametrizations on buffers. It should work as expected
        delattr(model, "bias")
        model.bias = Buffer(torch.ones(8))
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)

        # Remove parametrizations on buffers. It should work as expected
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=True)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)
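        # When a buffer is parametrized, its unparametrized "original" is stored as
        # a buffer as well, so the parameter count stays at 1 (just the weight).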

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_serialization_parametrization(self):
        r"""Test that it is possible to serialize a parametrized model via state_dict"""

        # A stateful parametrization
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.id = Buffer(torch.eye(n))
                self.B = Buffer(torch.empty(n, n))
                init.orthogonal_(self.B)

            def forward(self, X):
                A = X.triu(1)
                A = A - A.T
                return self.B @ torch.linalg.solve(self.id + A, self.id - A)

        def get_model():
            model = torch.nn.Sequential(
                torch.nn.Linear(5, 5),
                torch.nn.ReLU(),
                torch.nn.Linear(5, 1),
            )

            parametrize.register_parametrization(model[0], "weight", Orthogonal(5))
            return model

        model = get_model()
        prev_weight = model[0].weight
        prev_B = model[0].parametrizations.weight[0].B

        new_model = get_model()
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            new_model.load_state_dict(torch.load(fname))

        # Integrity tests
        self.assertTrue(parametrize.is_parametrized(new_model[0], "weight"))
        self.assertEqual(prev_weight, new_model[0].weight)
        self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B)

        # Trying to save the whole parametrized model raises
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            with TemporaryFileName() as fname:
                torch.save(model, fname)
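        # Pickling the whole module would serialize the dynamically generated
        # parametrized class, which is why only state_dict serialization is supported.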

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_initialization_parametrization(self):
        r"""Test that it is possible to initialize a parametrization when it
        implements a `right_inverse` method
        """

        class Skew(nn.Module):
            def forward(self, X):
                A = X.triu(1)
                return A - A.T

            def is_skew(self, A):
                return torch.allclose(A, -A.T, atol=1e-6)

            def right_inverse(self, X):
                if not self.is_skew(X):
                    raise ValueError("The matrix is not skew-symmetric.")
                return X.triu(1)

        # Implements a Cayley map where right_inverse is not quite the inverse of forward
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.B = Buffer(torch.eye(n))

            def forward(self, X):
                Id = torch.eye(X.size(0))
                return self.B @ torch.linalg.solve(Id + X, Id - X)

            def is_orthogonal(self, X):
                Id = torch.eye(X.size(0))
                return torch.allclose(X.T @ X, Id, atol=1e-4)

            def right_inverse(self, X):
                if not self.is_orthogonal(X):
                    raise ValueError("The input is not orthogonal.")
                # cayley(0) == Id, so B @ cayley(0) == B
                self.B = X
                return torch.zeros_like(X)

        N = 5
        model = nn.Linear(N, N)
        # Register the skew-symmetric constraint. The result is now skew-symmetric
        skew = Skew()
        # Make the weight skew-symmetric before registering the parametrization
        with torch.no_grad():
            model.weight.set_(skew(model.weight))
        parametrize.register_parametrization(model, "weight", skew)
        X = torch.rand(N, N)
        # X is not skew-symmetric, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        # Make X skew-symmetric
        X = X - X.T
        model.weight = X
        self.assertEqual(model.parametrizations.weight.original, X.triu(1))
        self.assertEqual(model.weight, X)

        # Having several parametrizations registered should work in the same way
        parametrize.register_parametrization(model, "weight", Orthogonal(N))
        # Register now the Cayley map. The result is now orthogonal
        X = torch.rand(N, N)
        # X is not orthogonal, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        init.orthogonal_(X)
        model.weight = X
        self.assertEqual(model.weight, X)
        self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))
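        # Assigning `model.weight = X` stores right_inverse(X) as the original
        # tensor, so that forward(right_inverse(X)) reproduces X (up to tolerance).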

    @swap([True, False])
    def test_errors_unparametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on an unparametrized tensor
        module = nn.Linear(3, 4)
        weight_init = module.weight.clone()

        class Identity(nn.Module):
            def forward(self, x):
                return x

        # Register a parametrization on a non-existing parameter throws
        with self.assertRaisesRegex(ValueError, "does not have a parameter"):
            parametrize.register_parametrization(module, "foo", Identity())
        self.assertFalse(parametrize.is_parametrized(module))

        # Removing parametrizations from an unparametrized tensor throws
        with self.assertRaisesRegex(ValueError, "does not have a parametrization"):
            parametrize.remove_parametrizations(module, "bias")
        self.assertFalse(parametrize.is_parametrized(module))

        # A correct parametrization with several outputs
        class Sum(nn.Module):
            def forward(self, x, y):
                return x + y

            def right_inverse(self, z):
                return z, torch.zeros_like(z)

        parametrize.register_parametrization(module, "weight", Sum())
        # Cannot remove a parametrization with several outputs with `leave_parametrized=False`
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            parametrize.remove_parametrizations(
                module, "weight", leave_parametrized=False
            )
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # A parametrization with an incorrect number of outputs
        class WrongNumberParams(nn.Module):
            def forward(self, x, y, z):
                return x + y + z

            def right_inverse(self, w):
                return w, torch.zeros_like(w)

        # Makes param(*param.right_inverse(X)) fail
        with self.assertRaisesRegex(TypeError, "positional argument"):
            parametrize.register_parametrization(module, "weight", WrongNumberParams())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor]
        class WrongRightInverse(Identity):
            def right_inverse(self, z):
                return 3

        # right_inverse should return a Tensor or a Sequence[Tensor]
        with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"):
            parametrize.register_parametrization(module, "weight", WrongRightInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # If it's a sequence, it must be a sequence of tensors
        class WrongRightInverseSequence(nn.Module):
            def forward(self, x, y):
                return x

            def right_inverse(self, z):
                return None, z

        with self.assertRaisesRegex(ValueError, "of the sequence with type"):
            parametrize.register_parametrization(
                module, "weight", WrongRightInverseSequence()
            )
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization from one tensor to one tensor that changes the dtype
        class ChangeDtypeInverse(nn.Module):
            def forward(self, x):
                return x.float()

            def right_inverse(self, w):
                return w.bool()

        # For parametrizations that return one tensor, right_inverse may not change the dtype
        with self.assertRaisesRegex(
            ValueError, "outputs one tensor, it may not change the dtype"
        ):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # Doesn't return a tensor
        class NotTensor(nn.Module):
            def forward(self, x):
                return 2

        # Forward must return a tensor
        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", NotTensor())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization from one tensor to one tensor that changes the dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        # forward should not change the initial dtype
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization that changes the shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        # forward should not change the original shape
        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertFalse(parametrize.is_parametrized(module))

        # Many to one that changes dtype
        class ChangeDtypeMulti(nn.Module):
            def forward(self, x, y):
                return (x + y).bool()

            def right_inverse(self, w):
                return w, w + 1

        # forward should not change the original dtype even for parametrizations with many inputs
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeMulti())
        self.assertFalse(parametrize.is_parametrized(module))

        # Returning a sequence of size one, although weird, it's correct
        class SequenceLen1(nn.Module):
            def forward(self, x):
                return x

            def right_inverse(self, w):
                return (w,)

        parametrize.register_parametrization(module, "weight", SequenceLen1())
        self.assertTrue(hasattr(module.parametrizations.weight, "original0"))
        self.assertFalse(hasattr(module.parametrizations.weight, "original1"))
        _ = module.weight  # Does not throw
        self.assertTrue(parametrize.is_parametrized(module))
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # None of the operations above should have altered the weight
        self.assertFalse(parametrize.is_parametrized(module))
        self.assertEqual(module.weight, weight_init)
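        # Multi-output parametrizations store the outputs of right_inverse as
        # original0, original1, ... instead of a single `original` tensor.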

    @swap([True, False])
    def test_errors_parametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on a parametrized tensor

        class Identity(nn.Module):
            def forward(self, x):
                return x

        module = nn.Linear(3, 4)
        parametrize.register_parametrization(module, "weight", Identity())

        # Has to return a tensor
        class WrongReturn(nn.Module):
            def forward(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturn())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # The following checks are mostly due to bugs in the code of the parametrization

        # right_inverse has to return a tensor
        class WrongReturnInverse(Identity):
            def right_inverse(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturnInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtypeInverse(Identity):
            def right_inverse(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "must have the same dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShapeInverse(Identity):
            def right_inverse(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "must have the same shape"):
            parametrize.register_parametrization(module, "weight", ChangeShapeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))
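        # A failed registration must leave the previously registered parametrization
        # untouched, which is what the len/isinstance checks verify after each error.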

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_multiple_inputs_parametrization(self):
        # A parametrization with several outputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # We project the given matrix onto the rank 1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is ordered in a decreasing way.
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        # Simple parametrization
        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, w):
                return 0.5 * w

        model = nn.Linear(3, 3)
        # Test one parametrization
        parametrize.register_parametrization(model, "weight", RankOne())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(hasattr(model.parametrizations.weight, "original0"))
        self.assertIn("original0", model.parametrizations.weight._parameters)
        self.assertTrue(hasattr(model.parametrizations.weight, "original1"))
        self.assertIn("original1", model.parametrizations.weight._parameters)
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be rank 1
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # Registering parametrizations with one input on top of one with multiple inputs should work
        init_weight = model.weight.clone()
        parametrize.register_parametrization(model, "weight", RankOne())
        # Projecting a rank 1 matrix onto the matrices of rank one does not change the matrix
        self.assertEqual(init_weight, model.weight)
        parametrize.register_parametrization(model, "weight", Double())
        # The matrix now is twice the initial matrix
        self.assertEqual(2.0 * init_weight, model.weight)
        # Multiplying by a scalar does not change the rank
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        # The model has now three parameters
        self.assertEqual(len(list(model.parameters())), 3)

        sgd = torch.optim.SGD(model.parameters(), lr=0.1)

        # Test backward. Should not throw
        sgd.zero_grad()
        loss = (model.weight.T @ model.bias).sum()
        loss.backward()
        sgd.step()

        # Same drill as before, removing should work as expected
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # The model has now two parameters
        self.assertEqual(len(list(model.parameters())), 2)

        # Test backward. Should not throw
        sgd = torch.optim.SGD(model.parameters(), lr=0.1)

        sgd.zero_grad()
        loss = (model.weight.T @ model.bias).sum()
        loss.backward()
        sgd.step()
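        # right_inverse here is a projection (best rank-1 approximation via SVD),
        # so forward(right_inverse(W)) == W only when W is already rank 1.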

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization(self):
        r"""Test the caching system of a parametrization"""

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        # Test that the caching system works
        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))
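        # Inside parametrize.cached() the parametrization is evaluated once and the
        # result reused, so repeated accesses return the very same tensor object.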

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization_with_transfer_parametrizations_and_params(self):
        r"""Test that transferring parametrizations doesn't cause issues with caching"""

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        to_model = nn.Linear(5, 5)
        parametrize.transfer_parametrizations_and_params(model, to_model)

        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))

            A = to_model.weight
            B = to_model.weight
            self.assertEqual(id(A), id(B))

            # test that the results are distinct objects for each module
            self.assertNotEqual(id(A), id(X))

    @swap([True, False])
    def test_parametrization_same_training_mode(self):
        r"""Test training mode updated on parametrization registration"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        module = nn.Linear(4, 4)
        module.eval()
        parametrize.register_parametrization(module, "weight", Identity())
        self.assertFalse(module.parametrizations.weight[0].training)
        module.train()
        parametrize.register_parametrization(module, "weight", Identity().eval())
        self.assertTrue(module.parametrizations.weight[0].training)
        self.assertTrue(module.parametrizations.weight[1].training)
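        # register_parametrization moves the parametrization module to the training
        # mode of the host module, even if it was created in eval mode.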

    @swap([True, False])
    def test_type_before_parametrizations(self):
        r"""Test that type_before_parametrizations always retrieves original type"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        model = nn.Linear(5, 5)
        original_type = type(model)
        self.assertTrue(
            parametrize.type_before_parametrizations(model) == original_type
        )
        parametrize.register_parametrization(model, "weight", Identity())
        self.assertTrue(
            parametrize.type_before_parametrizations(model) == original_type
        )

    @swap([True, False])
    def test_deepcopy_after_parametrization(self):
        r"""Test that we are able to create a deepcopy of the module when it's parametrized."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class ModelWithoutDeepcopy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.tensor([1.0, 1.0, 1.0, 1.0]), requires_grad=True
                )
                self.bias = nn.Parameter(
                    torch.tensor([0.0, 0.0, 0.0, 0.0]), requires_grad=True
                )
                self.attr = [1.0, 2.0, 3.0, 4.0]

        class ActualModel(ModelWithoutDeepcopy):
            # Emulate custom implementation of the deepcopying.
            def __deepcopy__(self, memo):
                result = self.__new__(self.__class__)
                memo[id(self)] = result
                result.__dict__ = deepcopy(self.__dict__, memo)
                return result

        def check_deepcopy(m1: nn.Module, m2: nn.Module):
            w1 = m1.parametrizations.weight.original
            w2 = m2.parametrizations.weight.original
            b1 = (
                m1.parametrizations.bias.original
                if parametrize.is_parametrized(m1, "bias")
                else m1.bias
            )
            b2 = (
                m2.parametrizations.bias.original
                if parametrize.is_parametrized(m2, "bias")
                else m2.bias
            )
            # Weights, biases and attributes should be equal but they must be different objects.
            self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys())
            self.assertIsNot(m1, m2)
            self.assertEqual(w1, w2)
            self.assertIsNot(w1, w2)
            self.assertEqual(b1, b2)
            self.assertIsNot(b1, b2)
            self.assertEqual(m1.attr, m2.attr)
            self.assertIsNot(m1.attr, m2.attr)

        for model in (ModelWithoutDeepcopy(), ActualModel()):
            # General check that we are able to create deepcopy.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models with several parametrized tensors.
            parametrize.register_parametrization(model, "bias", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models where tensors have more than one parametrization.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))

    @swap([True, False])
    def test_transfer_parametrizations_and_params(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # checks that final and original value are correct and the to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del hold_weight

        # testing that changes to one set of parametrizations do not affect the other
        parametrize.remove_parametrizations(to_model, "weight")
        self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight"))

        # also test that parameters that don't exist in to_model get transferred
        model.test_param = Parameter(torch.randn(5, 5))

        self.assertTrue(not hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", Double())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original,
            to_model.parametrizations.test_param.original,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)
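        # transfer_parametrizations_and_params copies both the parametrization
        # modules and their original tensors, optionally restricted to one name.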

    @swap([True, False])
    def test_transfer_parametrizations_and_params_right_inverse(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that transfer occurs successfully
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that transfer doesn't affect the from_model weight
        self.assertEqual(hold_weight, model.weight)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_single_param(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5, bias=True)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        parametrize.register_parametrization(model, "bias", AddOne())
        parametrize.register_parametrization(model, "bias", Double())
        parametrize.register_parametrization(model, "bias", MinusOne())

        to_model = torch.ao.nn.qat.Linear(
            5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model, "weight")

        # check that weight and only weight was transferred
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )
        self.assertTrue("bias" not in to_model.parametrizations)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_transfer_parametrizations_and_params_many_to_one(self):
        # A parametrization with several outputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # We project the given matrix onto the rank 1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is ordered in a decreasing way.
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        model = nn.Linear(3, 3)
        parametrize.register_parametrization(model, "weight", RankOne())
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            3, 3, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # checks that final and original value are correct and the to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original0,
            to_model.parametrizations.weight.original0,
        )
        self.assertEqual(
            model.parametrizations.weight.original1,
            to_model.parametrizations.weight.original1,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)

        # testing that changes to one set of parametrizations do not affect the other
        model.test_param = Parameter(torch.randn(3, 3))

        self.assertTrue(not hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", RankOne())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # also check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original0,
            to_model.parametrizations.test_param.original0,
        )
        self.assertEqual(
            model.parametrizations.test_param.original1,
            to_model.parametrizations.test_param.original1,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_new_spectral_norm(self):
        with set_default_dtype(torch.double):
            input = torch.randn(3, 5)
            m = nn.Linear(5, 7)
            m = torch.nn.utils.parametrizations.spectral_norm(m)
            spectral_norm_m = m.parametrizations.weight[0]

            self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)]))
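            # spectral_norm stores the power-iteration vectors _u and _v as buffers
            # on the parametrization and divides the weight by the estimated top
            # singular value on each forward pass.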

            # .parametrizations.weight.original should be trainable
            self.assertTrue(hasattr(m.parametrizations.weight, "original"))
            self.assertTrue("original" in m.parametrizations.weight._parameters)

            # u should be just a reused buffer
            self.assertTrue(hasattr(spectral_norm_m, "_u"))
            self.assertTrue("_u" in spectral_norm_m._buffers)
            self.assertTrue("_v" in spectral_norm_m._buffers)

            # weight should be a plain attribute, not counted as a buffer or a param
            self.assertIsNotNone(m.weight)
            self.assertFalse("weight" in m._buffers)
            self.assertFalse("weight" in m._parameters)

            # it should also be sharing storage with `weight_orig`
            # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage())
            self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size())
            self.assertEqual(
                m.parametrizations.weight.original.stride(), m.weight.stride()
            )

            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

            # spectral_norm is the only parametrization
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertTrue("weight" in m._parameters)

            # We can register spectral_norm multiple times on the same parameter
            # and on multiple parameters in the same module
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
            m = torch.nn.utils.parametrizations.spectral_norm(m, "bias")

            # If we remove the parametrization on bias, weight is still parametrized
            # Removing a parametrization runs forward in eval mode if leave_parametrized=True
            m = torch.nn.utils.parametrize.remove_parametrizations(m, "bias")
            self.assertTrue("bias" in m._parameters)
            self.assertTrue(hasattr(m, "parametrizations"))
            self.assertFalse("weight" in m._parameters)

            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
            # Neither weight nor bias is parametrized
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertTrue("weight" in m._parameters)
            self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m))

            # test correctness in training/eval modes and cpu/multi-gpu settings
            for apply_dp in (True, False):
                if apply_dp:
                    if not TEST_MULTIGPU:
                        continue
                    device = torch.device("cuda:0")

                    def maybe_wrap(m):
                        return torch.nn.DataParallel(m, [0, 1])

                else:
                    device = torch.device("cpu")

                    def maybe_wrap(m):
                        return m

                for requires_grad in (True, False):

                    def get_modules():
                        m = nn.Linear(3, 4).to(device)
                        m.weight.requires_grad_(requires_grad)
                        m = torch.nn.utils.parametrizations.spectral_norm(m)
                        wrapped_m = maybe_wrap(m)
                        spectral_norm_m = m.parametrizations.weight[0]
                        return m, wrapped_m, spectral_norm_m

                input = torch.randn(2, 3, device=device)

                m, wrapped_m, spectral_norm_m = get_modules()

                self.assertTrue(hasattr(spectral_norm_m, "_u"))
                u0 = spectral_norm_m._u.clone()
                v0 = spectral_norm_m._v.clone()

                # TEST TRAINING BEHAVIOR

                # We perform GD first to modify the initial matrix
                opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1)

                opt.zero_grad()
                wrapped_m(input).sum().backward()
                opt.step()

                out = wrapped_m(input)
                if requires_grad:
                    # run forward again and assert that u and v are updated
                    self.assertNotEqual(u0, spectral_norm_m._u)
                    self.assertNotEqual(v0, spectral_norm_m._v)

                # assert that backprop reaches original weight
                # can't use gradcheck because the function changes as we
                # activate through it in training mode
                if requires_grad:
                    torch.autograd.grad(
                        out.sum(), m.parametrizations.weight.original
                    )

                # test backward works with multiple forwards
                # it uses training mode so we need to reset `u` and `v` vectors
                # to same value at beginning for finite difference test to pass
                saved_u = spectral_norm_m._u.clone()
                saved_v = spectral_norm_m._v.clone()

                def fn(input):
                    spectral_norm_m._u.data.copy_(saved_u)
                    spectral_norm_m._v.data.copy_(saved_v)
                    out0 = wrapped_m(input)
                    out1 = wrapped_m(input)
                    return out0 + out1

                # Make sure we can compute gradients wrt all the parameters in the case
                # of double forward
                fn(input.clone().requires_grad_()).sum().backward()
                gradcheck(
                    fn, (input.clone().requires_grad_(),), check_batched_grad=False
                )

                # spectral norm module needs to be in eval mode if we'd like to
                # avoid doing another power iteration
                m, wrapped_m, _ = get_modules()
                pre_remove_out = wrapped_m(input)
                if get_swap_module_params_on_conversion():
                    # When using the swap_tensors path, this is needed so that the autograd
                    # graph is not alive anymore.
                    pre_remove_out_ref = pre_remove_out.detach()
                    del pre_remove_out
                else:
                    pre_remove_out_ref = pre_remove_out
                m.eval()
                m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
                self.assertEqual(wrapped_m(input), pre_remove_out_ref)

                torch.nn.utils.parametrizations.spectral_norm(m)

                pre_remove_out = wrapped_m(input)
                if get_swap_module_params_on_conversion():
                    # When using the swap_tensors path, this is needed so that the autograd
                    # graph is not alive anymore.
                    pre_remove_out_ref = pre_remove_out.detach()
                    del pre_remove_out
                else:
                    pre_remove_out_ref = pre_remove_out
                m.eval()
                m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
                self.assertEqual(wrapped_m(input), pre_remove_out_ref)

                # TEST EVAL BEHAVIOR
                m, wrapped_m, spectral_norm_m = get_modules()
                wrapped_m.train()
                last_train_out = wrapped_m(input)
                last_train_u = spectral_norm_m._u.clone()
                last_train_v = spectral_norm_m._v.clone()
                wrapped_m.zero_grad()
                wrapped_m.eval()

                eval_out0 = wrapped_m(input)
                # assert eval gives same result as last training iteration
                self.assertEqual(eval_out0, last_train_out)
                # assert doing more iterations in eval does not change things
                self.assertEqual(eval_out0, wrapped_m(input))
                self.assertEqual(last_train_u, spectral_norm_m._u)
                self.assertEqual(last_train_v, spectral_norm_m._v)

                # FIXME: the code below is flaky when executed with DataParallel
                # see https://github.com/pytorch/pytorch/issues/13818
                if apply_dp:
                    continue

                # test backward works with multiple forwards in mixed training
                # and eval modes
                # it uses training mode so we need to reset `u` and `v` vectors
                # to same value at beginning for finite difference test to pass
                saved_u = spectral_norm_m._u.clone()
                saved_v = spectral_norm_m._v.clone()

                def fn(input):
                    spectral_norm_m._u.data.copy_(saved_u)
                    spectral_norm_m._v.data.copy_(saved_v)
                    wrapped_m.train()
                    out0 = wrapped_m(input)
                    wrapped_m.eval()
                    out1 = wrapped_m(input)
                    wrapped_m.train()
                    out2 = wrapped_m(input)
                    wrapped_m.eval()
                    out3 = wrapped_m(input)
                    return out0 + out1 + out2 + out3

                gradcheck(fn, (input.clone().requires_grad_(),))

                # assert that backprop reaches weight_orig in eval
                if requires_grad:

                    def fn(weight):
                        return wrapped_m(input)

                    gradcheck(fn, (m.parametrizations.weight.original,))

    def test_register_parametrization_no_grad(self):
        r"""Test that it is possible to register a parametrization without gradient"""

        class SplitAndCat(nn.Module):
            def right_inverse(self, x):
                # split the tensor into two halves
                return torch.split(x, x.shape[1] // 2)

            def forward(self, x0, x1):
                return torch.cat([x0, x1])

        model = nn.Linear(8, 8)

        model.weight.requires_grad = False
        parametrize.register_parametrization(model, "weight", SplitAndCat())
        # making sure the parametrized and decomposed Tensors both have requires_grad == False
        self.assertFalse(model.weight.requires_grad)
        self.assertFalse(model.parametrizations.weight.original0.requires_grad)
        self.assertFalse(model.parametrizations.weight.original1.requires_grad)
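        # requires_grad of the decomposed original tensors is inherited from the
        # tensor that was parametrized, so frozen weights stay frozen.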

    @swap([True, False])
    def test_new_spectral_norm_load_state_dict(self):
        for activate_times in (0, 3):
            inp = torch.randn(2, 3)
            m = nn.Linear(3, 5)
            snm = torch.nn.utils.parametrizations.spectral_norm(m)
            snm.train()

            for _ in range(activate_times):
                snm(inp)

            state_dict = deepcopy(snm.state_dict())
            self.assertEqual(
                {
                    "parametrizations.weight.original",
                    "bias",
                    "parametrizations.weight.0._v",
                    "parametrizations.weight.0._u",
                },
                set(state_dict.keys()),
            )

            # test that non-strict loading works
            non_strict_state_dict = deepcopy(state_dict)
            non_strict_state_dict["nonsense"] = "nonsense"
            with self.assertRaisesRegex(
                RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'
            ):
                snm.load_state_dict(non_strict_state_dict, strict=True)
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["parametrizations.weight.original"]
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["parametrizations.weight.0._u"]
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["parametrizations.weight.0._v"]
            snm.load_state_dict(non_strict_state_dict, strict=False)
            non_strict_state_dict[
                "weight"
            ] = snm.weight.detach().clone()  # set W as a buffer
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict._metadata[
                "parametrizations.weight.0"
            ]  # remove metadata info
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["weight"]  # remove W buffer
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["bias"]
            snm.load_state_dict(non_strict_state_dict, strict=False)

            # test that re-wrapping does not matter
            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
            snm = torch.nn.utils.parametrizations.spectral_norm(m)

            snm.load_state_dict(state_dict)
            with torch.no_grad():
                snm.eval()
                out0_eval = snm(inp)
                snm.train()
                out1_train = snm(inp)
                out2_train = snm(inp)
                snm.eval()
                out3_eval = snm(inp)

            # test that re-wrapping does not matter
            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
            snm = torch.nn.utils.parametrizations.spectral_norm(m)

            # Test normal loading
            snm.load_state_dict(state_dict)
            with torch.no_grad():
                snm.eval()
                self.assertEqual(out0_eval, snm(inp))
                snm.train()
                self.assertEqual(out1_train, snm(inp))
                self.assertEqual(out2_train, snm(inp))
                snm.eval()
                self.assertEqual(out3_eval, snm(inp))

    @swap([True, False])
    def test_new_spectral_norm_dim(self):
        inp = torch.randn(2, 3, 10, 12)
        m = nn.ConvTranspose2d(3, 4, (5, 6))
        m = torch.nn.utils.parametrizations.spectral_norm(m)
        snm = m.parametrizations.weight[0]
        # this should not run into incompatible shapes
        m(inp)
        # check that u refers to the same dimension
        self.assertEqual(
            snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape
        )
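        # For ConvTranspose2d, spectral_norm reshapes the weight around dim=1
        # rather than dim=0, which is why _u matches that slice of the weight.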

    @swap([True, False])
    def test_new_spectral_norm_forward(self):
        input = torch.randn(3, 5)
        m = nn.Linear(5, 7)
        m = torch.nn.utils.parametrizations.spectral_norm(m)
        snm = m.parametrizations.weight[0]
        # naive forward
        _weight = m.parametrizations.weight.original
        _bias, _v = m.bias, snm._v
        _weight_mat = _weight.view(_weight.size(0), -1)
        _u = torch.mv(_weight_mat, _v)
        _u = F.normalize(_u, dim=0, eps=1e-12)
        _v = torch.mv(_weight_mat.t(), _u)
        _v = F.normalize(_v, dim=0, eps=1e-12)
        _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v))
        out_hat = torch.nn.functional.linear(input, _weight, _bias)
        expect_out = m(input)
        self.assertEqual(expect_out, out_hat)
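        # The hand-rolled computation above mirrors one power-iteration step of the
        # spectral norm parametrization: update u and v, then divide W by sigma.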

    @swap([True, False])
    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def test_new_spectral_norm_value(self):
        # a test that the spectral norm (= top singular value)
        # is in fact properly calculated, using example of a simple diagonal matrix.
        for dtype in (torch.float, torch.cfloat):
            m = nn.Linear(2, 2, dtype=dtype)
            with torch.no_grad():
                # set weight to be diagonal
                x = torch.diagonal(m.weight)
                m.weight = nn.Parameter(torch.diag(x))
                torch.nn.utils.parametrizations.spectral_norm(m)
                # weights should be rescaled by spectral norm (i.e., the largest diagonal element in norm)
                expected = torch.diag(x / x.abs().max())
                self.assertEqual(m.weight.data, expected)
1565
@swap([True, False])
1566
def test_orthogonal_parametrization(self):
1567
# Orthogonal implements 6 algorithms (3x parametrizations times 2 options of use_trivialization)
1569
def assert_is_orthogonal(X):
1570
n, k = X.size(-2), X.size(-1)
1574
Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(
1575
*(X.size()[:-2]), k, k
1577
eps = 10 * n * torch.finfo(X.dtype).eps
1578
torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.0)
1580
def assert_weight_allclose_Q(weight, W):
1581
# Test that weight is equal to the Q part of the QR decomposition of W
1582
# (or of its transpose if the matrix is wide)
1583
wide_matrix = W.size(-2) < W.size(-1)
1586
Q, R = torch.linalg.qr(W)
1587
Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
1590
torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.0)

        for shape, dtype, use_linear in product(
            ((4, 4), (5, 3), (3, 5)),  # square / tall / wide
            (torch.float32, torch.complex64),
            (True, False),
        ):
            # Conv2d does not support complex yet
            if not use_linear and dtype.is_complex:
                continue

            if use_linear:
                input = torch.randn(3, shape[0], dtype=dtype)
            else:
                input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype)

            for parametrization, use_trivialization in product(
                ("matrix_exp", "cayley", "householder"), (False, True)
            ):
                # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False
                # See Note [right_inverse expm cayley]
                can_initialize = use_trivialization or parametrization == "householder"

                # We generate them every time to always start with fresh weights
                if use_linear:
                    m = nn.Linear(*shape, dtype=dtype)
                else:
                    m = nn.Conv2d(2, 3, shape, dtype=dtype)

                # When using the swap_tensors path, this is needed so that the autograd
                # graph is not alive anymore.
                if get_swap_module_params_on_conversion():
                    w_init = m.weight.clone().detach()
                else:
                    w_init = m.weight.clone()

                # We do not support householder for complex inputs
                # See Note [Householder complex]
                if parametrization == "householder" and m.weight.is_complex():
                    msg = "householder parametrization does not support complex tensors"
                    with self.assertRaisesRegex(ValueError, msg):
                        torch.nn.utils.parametrizations.orthogonal(
                            m,
                            "weight",
                            parametrization,
                            use_trivialization=use_trivialization,
                        )
                    continue

                wide_matrix = w_init.size(-2) < w_init.size(-1)
                torch.nn.utils.parametrizations.orthogonal(
                    m, "weight", parametrization, use_trivialization=use_trivialization
                )
                # Forwards works as expected
                self.assertEqual(w_init.shape, m.weight.shape)
                assert_is_orthogonal(m.weight)
                if can_initialize:
                    assert_weight_allclose_Q(m.weight, w_init)

                # Initializing with a given orthogonal matrix works
                X = torch.randn_like(m.weight)
                if wide_matrix:
                    X = X.mT
                w_new = torch.linalg.qr(X).Q
                if wide_matrix:
                    w_new = w_new.mT
                if can_initialize:
                    m.weight = w_new
                    torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.0)
                else:
                    msg = "assign to the matrix exponential or the Cayley parametrization"
                    with self.assertRaisesRegex(NotImplementedError, msg):
                        m.weight = w_new

                # Initializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix
                w_new = torch.randn_like(m.weight)
                if can_initialize:
                    m.weight = w_new
                    assert_weight_allclose_Q(m.weight, w_new)
                else:
                    msg = "assign to the matrix exponential or the Cayley parametrization"
                    with self.assertRaisesRegex(NotImplementedError, msg):
                        m.weight = w_new

                opt = torch.optim.SGD(m.parameters(), lr=0.1)
                for _ in range(2):
                    opt.zero_grad()
                    m(input).norm().backward()
                    grad = m.parametrizations.weight.original.grad
                    self.assertIsNotNone(grad)
                    # We do not update the upper triangular part of the matrix if tall
                    # (resp. the lower triangular part if wide)
                    if grad.size(-2) >= grad.size(-1):
                        zeros_grad = grad.triu(1)
                    else:
                        zeros_grad = grad.tril(-1)
                    self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad))
                    # The gradient in the diagonal can only be imaginary because a skew-Hermitian
                    # matrix has an imaginary diagonal
                    diag_grad = grad.diagonal(dim1=-2, dim2=-1)
                    if grad.is_complex():
                        diag_grad = diag_grad.real
                    self.assertEqual(diag_grad, torch.zeros_like(diag_grad))
                    opt.step()

                assert_is_orthogonal(m.weight)
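
    # For context, one way to see why these parametrizations yield orthogonal
    # matrices: the matrix exponential of a skew-symmetric matrix is orthogonal.
    # A standalone sketch (illustrative only, not part of the test suite):
    #
    #   A = torch.randn(4, 4)
    #   A = A - A.T                  # skew-symmetric: A.T == -A
    #   Q = torch.matrix_exp(A)      # what the "matrix_exp" option computes, in essence
    #   torch.testing.assert_close(Q.T @ Q, torch.eye(4), atol=1e-5, rtol=0.0)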

    @swap([True, False])
    def test_orthogonal_errors(self):
        m = nn.Linear(3, 4)
        with self.assertRaisesRegex(ValueError, "has to be one of"):
            torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo")

        with self.assertRaisesRegex(ValueError, "Expected a matrix"):
            torch.nn.utils.parametrizations.orthogonal(m, "bias")

        torch.nn.utils.parametrizations.orthogonal(m, "weight")
        with self.assertRaisesRegex(ValueError, "matrices of shape"):
            m.weight = torch.randn(5, 5)
        torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

    @swap([True, False])
    def test_weight_norm_state_dict_compat(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.weight_norm(m)
        old_dict = m.state_dict()

        m2 = nn.Linear(4, 5)
        m2 = torch.nn.utils.parametrizations.weight_norm(m2)
        m2.load_state_dict(old_dict)

        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))
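
    # The legacy torch.nn.utils.weight_norm stores keys "weight_g" / "weight_v",
    # while the parametrization version stores
    # "parametrizations.weight.original0" / "parametrizations.weight.original1";
    # loading the old dict into the new module presumably works via a
    # load-state-dict hook that remaps the legacy keys, which is what this test
    # exercises.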

    @swap([True, False])
    def test_weight_norm_pickle(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            # parametrized modules may only be serialized through their state_dict
            pickle.dumps(m)

    @swap([True, False])
    def test_weight_norm_deepcopy(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
        m2 = deepcopy(m)
        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))

    def test_wrapper_subclass_parametrization(self):
        class Subclassify(nn.Module):
            def forward(self, X):
                return TwoTensor(X, X)

        class UnSubclassify(nn.Module):
            def forward(self, X):
                # unwrap back to a plain Tensor (TwoTensor stores its components as .a / .b)
                return X.a

        class IdentityWithRightInverse(nn.Module):
            def forward(self, X):
                return X

            def right_inverse(self, X):
                return TwoTensor(X, X)
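
        # Contract exercised below: after registration, `model.weight` is computed
        # as parametrization.forward(original); assignment (and registration over
        # an existing value) goes through right_inverse, when defined, to produce
        # the stored `original`. That is why IdentityWithRightInverse ends up
        # storing a TwoTensor even though its forward is the identity.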

        def _check_parametrization(
            parametrization,
            type_before_registration,
            type_after_registration,
            leave_parametrized=False,
            type_after_right_inverse=None,
        ):
            model = nn.Linear(2, 2)
            buf = torch.randn(2, 2)
            model.buf = torch.nn.Buffer(buf)
            if (
                type_before_registration == TwoTensor
                and type_after_registration == Tensor
            ):
                model._apply(lambda t: TwoTensor(t, t))
            initial_weight = model.weight.clone().detach()
            initial_weight_id = id(model.weight)
            initial_buf = model.buf.clone().detach()
            initial_buf_id = id(model.buf)
            type_original_weight = (
                type_before_registration
                if type_after_right_inverse is None
                else type_after_right_inverse
            )
            type_original_buf = (
                Tensor if type_original_weight is nn.Parameter else type_original_weight
            )
            type_after_removal_buf = (
                type_after_registration if leave_parametrized else type_original_buf
            )
            if leave_parametrized:
                if type_after_registration is Tensor:
                    type_after_removal_weight = nn.Parameter
                else:
                    type_after_removal_weight = type_after_registration
            else:
                type_after_removal_weight = type_original_weight

            parametrize.register_parametrization(model, "weight", parametrization())
            parametrize.register_parametrization(model, "buf", parametrization())
            self.assertTrue(hasattr(model, "parametrizations"))
            self.assertTrue(parametrize.is_parametrized(model))
            self.assertFalse(parametrize.is_parametrized(model, "bias"))
            # checks for weight
            self.assertTrue(parametrize.is_parametrized(model, "weight"))
            self.assertTrue(
                isinstance(model.parametrizations.weight.original, nn.Parameter)
            )
            self.assertTrue(
                type(model.parametrizations.weight.original) is type_original_weight
            )
            self.assertNotIn("weight", model._parameters)
            self.assertTrue(type(model.weight) is type_after_registration)
            # checks for buf
            self.assertTrue(parametrize.is_parametrized(model, "buf"))
            self.assertFalse(
                isinstance(model.parametrizations.buf.original, nn.Parameter)
            )
            self.assertTrue(
                type(model.parametrizations.buf.original) is type_original_buf
            )
            self.assertTrue(type(model.buf) is type_after_registration)
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=leave_parametrized
            )
            parametrize.remove_parametrizations(
                model, "buf", leave_parametrized=leave_parametrized
            )
            self.assertFalse(hasattr(model, "parametrizations"))
            self.assertEqual(model.__class__, nn.Linear)
            # checks for weight
            self.assertTrue(type(model.weight) is type_after_removal_weight)
            self.assertTrue(isinstance(model.weight, nn.Parameter))
            self.assertEqual(id(model.weight), initial_weight_id)
            # checks for buf
            self.assertTrue(type(model.buf) is type_after_removal_buf)
            self.assertFalse(isinstance(model.buf, nn.Parameter))
            self.assertEqual(id(model.buf), initial_buf_id)
            if not leave_parametrized and type_after_right_inverse is None:
                self.assertEqual(model.weight, initial_weight)
                self.assertEqual(model.buf, initial_buf)

        _check_parametrization(Subclassify, nn.Parameter, TwoTensor)
        _check_parametrization(UnSubclassify, TwoTensor, Tensor)
        _check_parametrization(
            IdentityWithRightInverse,
            nn.Parameter,
            TwoTensor,
            type_after_right_inverse=TwoTensor,
        )
        _check_parametrization(
            Subclassify, nn.Parameter, TwoTensor, leave_parametrized=True
        )
        _check_parametrization(
            UnSubclassify, TwoTensor, Tensor, leave_parametrized=True
        )
        _check_parametrization(
            IdentityWithRightInverse,
            nn.Parameter,
            TwoTensor,
            leave_parametrized=True,
            type_after_right_inverse=TwoTensor,
        )


class TestNNParametrizationDevice(NNTestCase):
    @swap([True, False])
    def test_weight_norm_parametrization(self, device):
        for dtype in [torch.float, torch.bfloat16]:
            input = torch.randn(3, 4, dtype=dtype, device=device)
            m = nn.Linear(4, 5, dtype=dtype, device=device)
            expected_output = m(input)

            # add weight normalization
            m = torch.nn.utils.parametrizations.weight_norm(m)
            self.assertEqual(
                m.parametrizations.weight.original1.size(), m.weight.size()
            )
            self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1))
            self.assertEqual(m(input), expected_output)

            # remove weight norm
            torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertEqual(m(input), expected_output)

            # test with dim=1
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=1)
            self.assertEqual(
                m.parametrizations.weight.original1.size(), m.weight.size()
            )
            self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4))
            self.assertEqual(m(input), expected_output)

            # test with dim=None
            m = nn.Linear(4, 5, dtype=dtype, device=device)
            expected_output = m(input)
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=None)
            self.assertEqual(m(input), expected_output)
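
    # Recall the decomposition weight_norm uses: w = g * v / ||v||, with the norm
    # taken over every dim except `dim`. For the (5, 4) weight here, g (original0)
    # has shape (5, 1) when dim=0 and (1, 4) when dim=1, while v (original1) always
    # keeps the full weight shape; dim=None takes the norm of the whole tensor.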


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestNNParametrizationDevice, globals(), only_for=only_for)
instantiate_parametrized_tests(TestNNParametrization)

if __name__ == "__main__":
    run_tests()