1
# Owner(s): ["module: nn"]
11
import torch.nn.utils.stateless as stateless
12
from torch.testing._internal.common_cuda import TEST_MULTIGPU
13
from torch.testing._internal.common_utils import run_tests, TestCase, parametrize, instantiate_parametrized_tests, \
17
class MockModule(torch.nn.Module):
18
def __init__(self) -> None:
20
self.l1 = torch.nn.Linear(1, 1)
21
self.buffer = torch.nn.Buffer(torch.ones(1))
25
return self.l1(x) + self.buffer
28
class MockTiedModule(torch.nn.Module):
29
def __init__(self) -> None:
31
self.l1 = torch.nn.Linear(1, 1)
32
self.tied_bias = self.l1.bias
33
self.buffer = torch.nn.Buffer(torch.ones(1))
34
self.tied_buffer = self.buffer
37
return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
40
class TestStatelessFunctionalAPI(TestCase):
41
def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
43
x = torch.rand((1, 1)).to(device)
44
weight = torch.tensor([[1.0]], device=device)
45
bias = torch.tensor([0.0], device=device)
46
buffer = torch.tensor([0.0], device=device)
48
parameters = {f'{prefix}.l1.weight': weight,
49
f'{prefix}.l1.bias': bias,
50
f'{prefix}.buffer': buffer}
52
parameters = {'l1.weight': weight,
57
to_check = getattr(module, prefix)
58
prev_weight = to_check.l1.weight.clone()
59
prev_buffer = to_check.buffer.clone()
60
# the parameters represent an identity function contrary to the
61
# existing params in module. So here we expect the result to be the
62
# same as the input if the weight swapping went well.
63
res = functional_call(module, parameters, x)
64
self.assertEqual(x, res)
65
# check that the weight remain unmodified
66
cur_weight = to_check.l1.weight
67
cur_buffer = to_check.buffer
68
self.assertEqual(cur_weight, prev_weight)
69
self.assertEqual(cur_buffer, prev_buffer)
71
@contextlib.contextmanager
72
def _ensure_module_unchanged(self, module, message):
73
orig_parameters, orig_buffers = tuple(module.parameters()), tuple(module.buffers())
74
orig_tensors = orig_parameters + orig_buffers
75
orig_tensors_values = tuple(t.clone() for t in orig_tensors)
79
parameters, buffers = tuple(module.parameters()), tuple(module.buffers())
81
len(parameters) == len(orig_parameters)
82
and len(buffers) == len(orig_buffers)
84
t1 is t2 and torch.allclose(t1, t3)
85
for t1, t2, t3 in zip(
94
@parametrize("functional_call", [
95
subtest(torch.func.functional_call, "torch_func"),
96
subtest(stateless.functional_call, "stateless")
98
def test_functional_call(self, functional_call):
100
self._run_call_with_mock_module(module, functional_call)
102
@parametrize("functional_call", [
103
subtest(torch.func.functional_call, "torch_func"),
104
subtest(stateless.functional_call, "stateless")
106
def test_functional_call_with_jit(self, functional_call):
107
module = MockModule()
108
jit_module = torch.jit.script(module)
109
with self.assertRaisesRegex(
111
r'used with Jitted modules'
113
self._run_call_with_mock_module(jit_module, functional_call)
114
x = torch.rand((1, 1))
115
traced_module = torch.jit.trace(module, x)
116
with self.assertRaisesRegex(
118
r'used with Jitted modules'
120
self._run_call_with_mock_module(traced_module, functional_call)
122
@unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
123
@unittest.skip("This doesn't work right now")
124
@parametrize("functional_call", [
125
subtest(torch.func.functional_call, "torch_func"),
126
subtest(stateless.functional_call, "stateless")
128
def test_functional_call_with_data_parallel(self, functional_call):
129
module = MockModule()
131
dp_module = torch.nn.DataParallel(module, [0, 1])
132
self._run_call_with_mock_module(dp_module, functional_call, device='cuda', prefix='module')
134
@unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
135
@parametrize("functional_call", [
136
subtest(torch.func.functional_call, "torch_func"),
137
subtest(stateless.functional_call, "stateless")
139
def test_functional_call_with_data_parallel_error(self, functional_call):
140
module = MockModule()
142
dp_module = torch.nn.DataParallel(module, [0, 1])
143
with self.assertRaisesRegex(RuntimeError, r'used with nn.DataParallel module'):
146
{'module.weight': torch.zeros(5, device='cuda')},
147
(torch.ones(2, 5, device='cuda'),))
149
@parametrize("functional_call", [
150
subtest(torch.func.functional_call, "torch_func"),
151
subtest(stateless.functional_call, "stateless")
153
def test_functional_call_with_gradient(self, functional_call):
154
module = MockModule()
155
x = torch.rand((1, 1))
156
weight = torch.tensor([[1.0]], requires_grad=True)
157
bias = torch.tensor([0.0], requires_grad=True)
158
buffer = torch.tensor([0.0])
159
parameters = {'l1.weight': weight,
162
res = functional_call(module, parameters, x)
163
# Check that a backward step calculates the gradient of the supplied parameters
165
self.assertIsNotNone(weight.grad)
166
self.assertIsNotNone(bias.grad)
167
self.assertIsNone(buffer.grad)
168
# Gradient was not calculated for the module stated and buffers
169
self.assertIsNone(module.l1.weight.grad)
170
self.assertIsNone(module.l1.bias.grad)
171
self.assertIsNone(module.buffer.grad)
173
@parametrize("functional_call", [
174
subtest(torch.func.functional_call, "torch_func"),
175
subtest(stateless.functional_call, "stateless")
177
def test_functional_batch_norm(self, functional_call):
178
module = torch.nn.BatchNorm1d(10)
179
module.train() # Allow stats update
180
# lets replace the running_mean buffer and check if its correctly updated
181
x = torch.full((20, 10), 128.0)
183
parameters = {'running_mean': rm}
184
prev_rm = module.running_mean.clone()
185
res = functional_call(module, parameters, x)
186
cur_rm = module.running_mean
187
self.assertEqual(cur_rm, prev_rm)
188
self.assertEqual(rm, torch.full((10,), 12.8))
189
# Now run functional without reparametrization and check that the module has
191
res = functional_call(module, {}, x)
192
self.assertEqual(module.running_mean, torch.full((10,), 12.8))
194
@parametrize("functional_call", [
195
subtest(torch.func.functional_call, "torch_func"),
196
subtest(stateless.functional_call, "stateless")
198
def test_circular_references(self, functional_call):
199
module = MockModule()
200
# Add a circular reference
202
x = torch.rand((1, 1))
203
weight = torch.tensor([[1.0]])
204
bias = torch.tensor([0.0])
205
buffer = torch.tensor([0.0])
206
parameters = {'l1.m.l1.weight': weight,
208
'l1.m.buffer': buffer}
209
prev_weight = module.l1.weight.clone()
210
prev_buffer = module.buffer.clone()
211
res = functional_call(module, parameters, x, tie_weights=False)
212
self.assertEqual(x, res)
213
# check that the weights remain unmodified and were correctly accesed
214
cur_weight = module.l1.weight
215
cur_buffer = module.buffer
216
self.assertEqual(cur_weight, prev_weight)
217
self.assertEqual(cur_buffer, prev_buffer)
219
@parametrize("functional_call", [
220
subtest(torch.func.functional_call, "torch_func"),
221
subtest(stateless.functional_call, "stateless")
223
def test_reparametrized_module_change_parametrization_original(self, functional_call):
224
module = MockModule()
225
torch.nn.utils.parametrizations.spectral_norm(module.l1)
226
self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
227
orig_sn_weight = module.l1.weight.clone()
228
x = torch.rand((1, 1))
229
# We substitute the parameter inside the parametrization
230
# the parametrization itself is not overwritten so it will be applied with a different
231
# value for the original tensor
232
parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])),
233
'l1.bias': torch.tensor([0.0]),
234
'buffer': torch.tensor([0.0])}
235
res = functional_call(module, parameters, x)
236
self.assertEqual(x, res)
237
# verify that the spectral normalization is still applied
238
self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
239
self.assertEqual(orig_sn_weight, module.l1.weight)
241
@parametrize("functional_call", [
242
subtest(torch.func.functional_call, "torch_func"),
243
subtest(stateless.functional_call, "stateless")
245
def test_reparametrize_module_fail_reset_to_original(self, functional_call):
246
module = MockModule()
247
torch.nn.utils.parametrizations.spectral_norm(module.l1)
248
self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
249
orig_sn_weight = module.l1.weight.clone()
250
# We substitute the parameter inside the parametrization
251
# the parametrization itself is not overwritten so it will be applied with a different
252
# value for the original tensor
253
parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])),
254
'l1.bias': torch.tensor([0.0]),
255
'buffer': torch.tensor([0.0])}
257
with self.assertRaisesRegex(RuntimeError, "shapes cannot be multiplied"):
258
@torch._dynamo.disable
260
x = torch.rand((4, 5)) # to work, it should be of size (1, 1)
261
functional_call(module, parameters, x) # this call will fail because x is the wrong size
264
# verify that the spectral normalization is still applied
265
self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
266
self.assertEqual(orig_sn_weight, module.l1.weight)
268
@parametrize("functional_call", [
269
subtest(torch.func.functional_call, "torch_func"),
270
subtest(stateless.functional_call, "stateless")
272
def test_reparametrize_some_weights(self, functional_call):
273
module = MockModule()
274
weight = torch.tensor([[2.0]])
275
bias = torch.tensor([5.0])
276
buffer = torch.tensor([3.0])
277
extra = torch.tensor([1.0])
279
parameters = {'l1.weight': weight}
280
x = torch.randn(1, 1)
281
out = functional_call(module, parameters, x)
282
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
284
parameters = {'l1.weight': weight,
286
x = torch.randn(1, 1)
287
out = functional_call(module, parameters, x)
288
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
290
@parametrize("functional_call", [
291
subtest(torch.func.functional_call, "torch_func"),
292
subtest(stateless.functional_call, "stateless")
294
def test_reparametrize_strict(self, functional_call):
295
module = MockModule()
296
weight = torch.tensor([[2.0]])
297
bias = torch.tensor([5.0])
298
buffer = torch.tensor([3.0])
299
extra = torch.tensor([1.0])
301
# All weights no error
302
parameters = {'l1.weight': weight,
305
x = torch.randn(1, 1)
306
with self._ensure_module_unchanged(
308
'the module should not have been modified by a successful call',
310
out = functional_call(module, parameters, x, strict=True)
311
self.assertEqual(out, x * weight + bias + buffer)
314
parameters = {'l1.weight': weight}
315
x = torch.randn(1, 1)
316
with self._ensure_module_unchanged(
318
'the module should not have been modified by a failed call',
320
with self.assertRaisesRegex(
322
re.escape("Missing key(s): 'buffer', 'l1.bias'."),
324
out = functional_call(module, parameters, x, strict=True)
327
parameters = {'l1.weight': weight,
331
x = torch.randn(1, 1)
332
with self._ensure_module_unchanged(
334
'the module should not have been modified by a failed call',
336
with self.assertRaisesRegex(
338
re.escape("Unexpected key(s): 'extra'."),
340
out = functional_call(module, parameters, x, strict=True)
342
# Some weights with extra keys
343
parameters = {'l1.weight': weight,
345
x = torch.randn(1, 1)
346
with self._ensure_module_unchanged(
348
'the module should not have been modified by a failed call',
350
with self.assertRaisesRegex(
352
re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'buffer', 'l1.bias'."),
354
out = functional_call(module, parameters, x, strict=True)
356
@parametrize("functional_call", [
357
subtest(torch.func.functional_call, "torch_func"),
358
subtest(stateless.functional_call, "stateless")
360
def test_reparametrize_special(self, functional_call):
363
return f'<{self.__class__.__name__}>'
365
module = MockModule()
366
weight = torch.tensor([[2.0]])
367
bias = torch.tensor([5.0])
368
buffer = torch.tensor([3.0])
369
non_tensor = NonTensor()
372
parameters = {'l1.weight': weight,
375
x = torch.randn(1, 1)
376
with self._ensure_module_unchanged(
378
'the module should not have been modified by a successful call',
380
out = functional_call(module, parameters, x)
381
self.assertEqual(out, x * weight + buffer)
384
parameters = {'l1.weight': non_tensor}
385
x = torch.randn(1, 1)
386
with self._ensure_module_unchanged(
388
'the module should not have been modified by a failed call',
390
with self.assertRaisesRegex(
392
re.escape("<NonTensor> is not an instance of torch.Tensor"),
394
out = functional_call(module, parameters, x)
396
# Set non-tensor attribute
397
parameters = {'l1.weight': weight, 'foo': torch.tensor([1.0])}
398
x = torch.randn(1, 1)
399
with self._ensure_module_unchanged(
401
'the module should not have been modified by a failed call',
403
with self.assertRaisesRegex(
405
re.escape("attribute `foo`: 0.0 is not an instance of torch.Tensor"),
407
out = functional_call(module, parameters, x)
409
# Set non-exist submodule
410
parameters = {'l1.weight': weight,
412
x = torch.randn(1, 1)
413
with self._ensure_module_unchanged(
415
'the module should not have been modified by a failed call',
417
with self.assertRaisesRegex(
419
re.escape("MockModule has no attribute `l2`"),
421
out = functional_call(module, parameters, x)
423
@parametrize("functional_call", [
424
subtest(torch.func.functional_call, "torch_func"),
425
subtest(stateless.functional_call, "stateless")
427
def test_tied_weights_warns(self, functional_call):
428
module = MockModule()
429
module.tied_bias = module.l1.bias
430
module.tied_buffer = torch.nn.Buffer(module.buffer)
432
@parametrize("functional_call", [
433
subtest(torch.func.functional_call, "torch_func"),
434
subtest(stateless.functional_call, "stateless")
436
def test_reparametrize_tie_weights(self, functional_call):
437
module = MockTiedModule()
438
weight = torch.tensor([[2.0]])
439
bias = torch.tensor([5.0])
440
buffer = torch.tensor([3.0])
441
extra = torch.tensor([1.0])
443
parameters = {'l1.weight': weight,
446
x = torch.randn(1, 1)
447
out = functional_call(module, parameters, x, tie_weights=True)
448
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
450
parameters = {'l1.weight': weight,
454
x = torch.randn(1, 1)
455
out = functional_call(module, parameters, x, tie_weights=True)
456
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
458
@parametrize("functional_call", [
459
subtest(torch.func.functional_call, "torch_func"),
460
subtest(stateless.functional_call, "stateless")
462
def test_reparametrize_tie_some_weights(self, functional_call):
463
module = MockTiedModule()
464
weight = torch.tensor([[2.0]])
465
buffer = torch.tensor([3.0])
467
parameters = {'l1.weight': weight,
469
x = torch.randn(1, 1)
470
out = stateless.functional_call(module, parameters, x, tie_weights=True)
471
self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)
473
@parametrize("functional_call", [
474
subtest(torch.func.functional_call, "torch_func"),
475
subtest(stateless._functional_call, "stateless")
477
def test_tied_weights_errors(self, functional_call):
478
module = MockTiedModule()
479
weight = torch.tensor([[1.0]])
480
bias = torch.tensor([0.0])
481
buffer = torch.tensor([0.0])
483
parameters = {'l1.weight': weight,
486
x = torch.randn(1, 1)
487
self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
489
# if tied values are the same tensors, shouldn't warn
490
parameters['tied_bias'] = bias
491
parameters['tied_buffer'] = buffer
492
self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
493
del parameters['tied_bias']
494
del parameters['tied_buffer']
496
with self.assertRaisesRegex(
498
re.escape("functional_call got multiple values for keys ['l1.bias', 'tied_bias']"),
500
parameters['tied_bias'] = torch.tensor([5.0])
501
functional_call(module, parameters, x, tie_weights=True)
502
del parameters['tied_bias']
504
with self.assertRaisesRegex(
506
re.escape("functional_call got multiple values for keys ['buffer', 'tied_buffer']"),
508
parameters['tied_buffer'] = torch.tensor([5.0])
509
functional_call(module, parameters, x, tie_weights=True)
511
def test_tied_weights_no_error_without_flag(self):
512
module = MockTiedModule()
513
weight = torch.tensor([[1.0]])
514
bias = torch.tensor([0.0])
515
buffer = torch.tensor([0.0])
517
parameters = {'l1.weight': weight,
520
x = torch.randn(1, 1)
521
self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
522
parameters['tied_bias'] = torch.tensor([5.0])
523
self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
524
del parameters['tied_bias']
525
parameters['tied_buffer'] = torch.tensor([5.0])
526
self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
528
@parametrize("functional_call", [
529
subtest(torch.func.functional_call, "torch_func"),
530
subtest(stateless.functional_call, "stateless")
532
def test_reparametrize_tie_weights_strict(self, functional_call):
533
module = MockTiedModule()
534
weight = torch.tensor([[2.0]])
535
bias = torch.tensor([5.0])
536
buffer = torch.tensor([3.0])
537
extra = torch.tensor([1.0])
539
# Tie weights no error
540
parameters = {'l1.weight': weight,
543
x = torch.randn(1, 1)
544
with self._ensure_module_unchanged(
546
'the module should not have been modified by a successful call',
548
out = functional_call(module, parameters, x, tie_weights=True, strict=True)
549
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
551
# Tie weights without flag
552
parameters = {'l1.weight': weight,
555
x = torch.randn(1, 1)
556
with self._ensure_module_unchanged(
558
'the module should not have been modified by a failed call',
560
with self.assertRaisesRegex(
562
re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
564
out = functional_call(module, parameters, x, tie_weights=False, strict=True)
567
parameters = {'l1.weight': weight,
569
x = torch.randn(1, 1)
570
with self._ensure_module_unchanged(
572
'the module should not have been modified by a failed call',
574
with self.assertRaisesRegex(
576
re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
578
out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)
580
# Tie weights with extra keys
581
parameters = {'l1.weight': weight,
585
x = torch.randn(1, 1)
586
with self._ensure_module_unchanged(
588
'the module should not have been modified by a failed call',
590
with self.assertRaisesRegex(
592
re.escape("Unexpected key(s): 'extra'."),
594
out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)
596
# Tie weights with extra keys and without flag
597
parameters = {'l1.weight': weight,
601
x = torch.randn(1, 1)
602
with self._ensure_module_unchanged(
604
'the module should not have been modified by a failed call',
606
with self.assertRaisesRegex(
608
re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
610
out = stateless.functional_call(module, parameters, x, tie_weights=False, strict=True)
612
# Tie some weights with extra keys
613
parameters = {'l1.weight': weight,
616
x = torch.randn(1, 1)
617
with self._ensure_module_unchanged(
619
'the module should not have been modified by a failed call',
621
with self.assertRaisesRegex(
623
re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
625
out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)
627
@parametrize("functional_call", [
628
subtest(torch.func.functional_call, "torch_func"),
629
subtest(stateless.functional_call, "stateless")
631
def test_setattr(self, functional_call):
632
class Foo(torch.nn.Module):
633
def __init__(self) -> None:
635
self.foo = torch.nn.Buffer(torch.tensor([0.0]))
637
def forward(self, x):
638
self.foo = self.foo + 1
641
foo = torch.tensor([2.0])
645
functional_call(mod, a, x)
646
self.assertEqual(mod.foo, torch.tensor([0.0]))
647
self.assertEqual(a['foo'], torch.tensor([3.0]))
648
self.assertEqual(foo, torch.tensor([2.0]))
649
self.assertTrue(a['foo'] is not foo)
651
@parametrize("functional_call", [
652
subtest(torch.func.functional_call, "torch_func"),
653
subtest(stateless.functional_call, "stateless")
655
def test_in_place_operator(self, functional_call):
656
class Foo(torch.nn.Module):
657
def __init__(self) -> None:
659
self.foo = torch.nn.Buffer(torch.tensor([0.0]))
661
def forward(self, x):
665
foo = torch.tensor([2.0])
669
functional_call(mod, a, x)
670
self.assertEqual(mod.foo, torch.tensor([0.0]))
671
self.assertEqual(a['foo'], torch.tensor([3.0]))
672
self.assertEqual(foo, torch.tensor([3.0]))
673
self.assertTrue(a['foo'] is foo)
675
@parametrize("functional_call", [
676
subtest(torch.func.functional_call, "torch_func"),
677
subtest(stateless.functional_call, "stateless")
679
def test_setattr_strict(self, functional_call):
680
class Bar(torch.nn.Module):
681
def __init__(self) -> None:
683
assert not hasattr(self, 'extra')
685
def forward(self, x):
686
return x + self.extra
688
a = {'extra': torch.zeros(())}
690
self.assertTrue(not hasattr(mod, 'extra'))
691
out = functional_call(mod, a, torch.ones(()))
692
self.assertEqual(out, torch.ones(()))
693
self.assertTrue(not hasattr(mod, 'extra'))
695
a = {'extra': torch.zeros(())}
696
with self.assertRaisesRegex(
698
re.escape("Unexpected key(s): 'extra'."),
700
out = functional_call(mod, a, torch.ones(()), strict=True)
701
self.assertTrue(not hasattr(mod, 'extra'))
704
with self.assertRaisesRegex(
706
re.escape("'Bar' object has no attribute 'extra'"),
708
out = functional_call(mod, a, torch.ones(()))
709
self.assertTrue(not hasattr(mod, 'extra'))
712
with self.assertRaisesRegex(
714
re.escape("'Bar' object has no attribute 'extra'"),
716
out = functional_call(mod, a, torch.ones(()), strict=True)
717
self.assertTrue(not hasattr(mod, 'extra'))
719
@parametrize("functional_call", [
720
subtest(torch.func.functional_call, "torch_func"),
721
subtest(stateless.functional_call, "stateless")
723
def test_functional_call_with_kwargs(self, functional_call):
724
class Foo(torch.nn.Module):
725
def __init__(self, x):
729
def forward(self, inp, *, other_inp):
730
return inp * self.x + other_inp
732
a = {'x': torch.zeros(2, 3)}
733
mod = Foo(torch.randn(2, 3))
734
inp, other_inp = torch.randn(2, 3), torch.randn(2, 3)
735
with self.assertRaisesRegex(TypeError, "missing 1 required keyword-only argument: 'other_inp'"):
736
functional_call(mod, a, inp)
737
res = functional_call(mod, a, inp, {'other_inp': other_inp})
738
self.assertEqual(res, other_inp)
739
res_1 = functional_call(mod, a, (), {'inp': inp, 'other_inp': other_inp})
740
self.assertEqual(res, res_1)
742
def test_functional_call_tuple_dicts(self):
744
x = torch.rand((1, 1))
745
parameters = {k: torch.ones_like(v) for k, v in mod.named_parameters()}
746
buffers = {k: torch.zeros_like(v) for k, v in mod.named_buffers()}
749
res = torch.func.functional_call(mod, (parameters, buffers), x)
750
self.assertEqual(res, x + 1)
753
res = torch.func.functional_call(mod, (), x)
754
self.assertEqual(res, mod(x))
757
a = ({'l1.weight': torch.ones(1, 1)}, {'l1.bias': torch.ones(1)}, {'buffer': torch.zeros(1)})
758
res = torch.func.functional_call(mod, a, x)
759
self.assertEqual(res, x + 1)
761
def test_functional_call_multiple_dicts_error(self):
763
x = torch.rand((1, 1))
764
parameters = {'l1.weight': torch.zeros((1, 1)), 'l1.bias': torch.zeros((1, 1))}
765
repeated_parameters = {'l1.weight': torch.ones((1, 1))}
766
with self.assertRaisesRegex(
768
re.escape("['l1.weight'] appeared in multiple dictionaries"),
770
torch.func.functional_call(mod, (parameters, repeated_parameters), x)
772
@parametrize("functional_call", [
773
subtest(torch.func.functional_call, "torch_func"),
774
subtest(stateless.functional_call, "stateless")
776
def test_functional_call_member_reference(self, functional_call):
777
class Module(torch.nn.Module):
778
def __init__(self) -> None:
780
self.l1 = torch.nn.Linear(1, 1)
781
self.buffer = torch.nn.Buffer(torch.ones(1))
783
def forward(self, x):
784
parameters = tuple(self.parameters())
785
buffers = tuple(self.buffers())
786
return self.l1(x) + self.buffer, parameters, buffers
789
weight = torch.tensor([[2.0]])
790
bias = torch.tensor([5.0])
791
buffer = torch.tensor([3.0])
792
extra = torch.tensor([1.0])
793
extra_p = torch.nn.Parameter(extra)
796
parameters = {'l1.weight': weight,
799
x = torch.randn(1, 1)
800
out, parameters, buffers = functional_call(module, parameters, x)
801
self.assertEqual(out, x * weight + bias + buffer)
802
self.assertEqual(parameters, (weight, bias))
803
self.assertEqual(buffers, (buffer,))
804
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
805
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
808
parameters = {'l1.weight': weight}
809
x = torch.randn(1, 1)
810
out, parameters, buffers = functional_call(module, parameters, x)
811
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
812
self.assertEqual(parameters, (weight, module.l1.bias))
813
self.assertEqual(buffers, (module.buffer,))
814
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
815
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
817
# All weights with extra keys
818
parameters = {'l1.weight': weight,
822
x = torch.randn(1, 1)
823
out, parameters, buffers = functional_call(module, parameters, x)
824
self.assertEqual(out, x * weight + bias + buffer)
825
self.assertEqual(parameters, (weight, bias))
826
self.assertEqual(buffers, (buffer,))
827
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
828
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
830
# All weights with extra keys with parameters
831
parameters = {'l1.weight': weight,
835
x = torch.randn(1, 1)
836
out, parameters, buffers = functional_call(module, parameters, x)
837
self.assertEqual(out, x * weight + bias + buffer)
838
self.assertEqual(parameters, (weight, bias, extra_p))
839
self.assertEqual(buffers, (buffer,))
840
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias, extra_p))))
841
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
843
# Some weights with extra keys
844
parameters = {'l1.weight': weight,
846
x = torch.randn(1, 1)
847
out, parameters, buffers = functional_call(module, parameters, x)
848
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
849
self.assertEqual(parameters, (weight, module.l1.bias))
850
self.assertEqual(buffers, (module.buffer))
851
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
852
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
854
# Some weights with extra keys with parameters
855
parameters = {'l1.weight': weight,
857
x = torch.randn(1, 1)
858
out, parameters, buffers = functional_call(module, parameters, x)
859
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
860
self.assertEqual(parameters, (weight, module.l1.bias, extra_p))
861
self.assertEqual(buffers, (module.buffer))
862
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias, extra_p))))
863
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
866
parameters = {'l1.weight': weight,
868
x = torch.randn(1, 1)
869
out, parameters, buffers = functional_call(module, parameters, x)
870
self.assertEqual(out, x * weight + module.buffer)
871
self.assertEqual(parameters, (weight,))
872
self.assertEqual(buffers, (module.buffer))
873
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight,))))
874
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
877
class TestStatelessDeprecation(TestCase):
878
def test_private_stateless_warns(self):
883
with warnings.catch_warnings(record=True) as w:
884
from torch.nn.utils import _stateless
889
subprocess.check_output(
890
[sys.executable, '-W', 'always', '-c', script],
891
stderr=subprocess.STDOUT,
892
# On Windows, opening the subprocess with the default CWD makes `import torch`
893
# fail, so just set CWD to this script's directory
894
cwd=os.path.dirname(os.path.realpath(__file__)),)
895
except subprocess.CalledProcessError as e:
896
self.assertEqual(e.returncode, 1)
898
self.assertTrue(False, "No warning was raised.")
900
def test_stateless_functional_call_warns(self):
    """``torch.nn.utils.stateless.functional_call`` is deprecated in favor of
    ``torch.func.functional_call``; calling it must raise a FutureWarning
    that points users at the replacement."""
    m = torch.nn.Linear(1, 1)
    params = dict(m.named_parameters())
    x = torch.randn(3, 1)
    with self.assertWarnsRegex(FutureWarning, "Please use `torch.func.functional_call`"):
        stateless.functional_call(m, params, x)
907
class TestPythonOptimizeMode(TestCase):
908
def test_runs_with_optimize_flag(self):
909
script = "import torch; import torch._functorch.deprecated"
911
subprocess.check_output(
912
[sys.executable, "-OO", "-c", script],
913
stderr=subprocess.STDOUT,
914
# On Windows, opening the subprocess with the default CWD makes `import torch`
915
# fail, so just set CWD to this script's directory
916
cwd=os.path.dirname(os.path.realpath(__file__)),)
917
except subprocess.CalledProcessError as e:
918
self.assertFalse(e.returncode, "Import failed while running python in optimized mode")
921
instantiate_parametrized_tests(
922
TestStatelessFunctionalAPI,
925
if __name__ == '__main__':