pytorch

Форк
0
/
test_stateless.py 
926 строк · 37.1 Кб
1
# Owner(s): ["module: nn"]
2

3
import contextlib
4
import os
5
import re
6
import subprocess
7
import sys
8
import unittest
9

10
import torch
11
import torch.nn.utils.stateless as stateless
12
from torch.testing._internal.common_cuda import TEST_MULTIGPU
13
from torch.testing._internal.common_utils import run_tests, TestCase, parametrize, instantiate_parametrized_tests, \
14
    subtest
15

16

17
class MockModule(torch.nn.Module):
18
    def __init__(self) -> None:
19
        super().__init__()
20
        self.l1 = torch.nn.Linear(1, 1)
21
        self.buffer = torch.nn.Buffer(torch.ones(1))
22
        self.foo = 0.0
23

24
    def forward(self, x):
25
        return self.l1(x) + self.buffer
26

27

28
class MockTiedModule(torch.nn.Module):
29
    def __init__(self) -> None:
30
        super().__init__()
31
        self.l1 = torch.nn.Linear(1, 1)
32
        self.tied_bias = self.l1.bias
33
        self.buffer = torch.nn.Buffer(torch.ones(1))
34
        self.tied_buffer = self.buffer
35

36
    def forward(self, x):
37
        return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
38

39

40
class TestStatelessFunctionalAPI(TestCase):
41
    def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
        """Call ``functional_call`` on a MockModule with identity-producing values.

        The substituted values (weight 1, bias 0, buffer 0) turn the MockModule
        into the identity map, so the output must equal the input.  Afterwards
        the module's own ``l1.weight`` and ``buffer`` must be unchanged.

        Args:
            module: the module passed to ``functional_call``.
            functional_call: the functional-call implementation under test.
            device: device for the input and the substituted tensors.
            prefix: when non-empty, the MockModule is ``getattr(module, prefix)``
                (e.g. ``'module'`` for ``nn.DataParallel``) and every key is
                qualified with ``prefix + '.'``.
        """
        x = torch.rand((1, 1)).to(device)
        identity_weight = torch.tensor([[1.0]], device=device)
        zero_bias = torch.tensor([0.0], device=device)
        zero_buffer = torch.tensor([0.0], device=device)
        key_prefix = f'{prefix}.' if prefix != '' else ''
        parameters = {
            f'{key_prefix}l1.weight': identity_weight,
            f'{key_prefix}l1.bias': zero_bias,
            f'{key_prefix}buffer': zero_buffer,
        }
        inner = getattr(module, prefix) if prefix != '' else module
        weight_before = inner.l1.weight.clone()
        buffer_before = inner.buffer.clone()

        result = functional_call(module, parameters, x)
        self.assertEqual(x, result)

        # The module's own state must survive the temporary substitution.
        self.assertEqual(inner.l1.weight, weight_before)
        self.assertEqual(inner.buffer, buffer_before)

    @contextlib.contextmanager
    def _ensure_module_unchanged(self, module, message):
        """Context manager asserting that the enclosed code leaves ``module`` alone.

        On exit, whether normal or via an exception, it checks three things.
        ``module`` has the same number of parameters and buffers as on entry.
        Each one is the identical tensor object it was on entry.
        Each one is still ``allclose`` to its entry value.

        Args:
            module: the module to guard; it is also the yielded value.
            message: assertion failure message.
        """
        params_before = tuple(module.parameters())
        buffers_before = tuple(module.buffers())
        tensors_before = params_before + buffers_before
        snapshots = tuple(tensor.clone() for tensor in tensors_before)
        try:
            yield module
        finally:
            params_after = tuple(module.parameters())
            buffers_after = tuple(module.buffers())
            counts_match = (
                len(params_after) == len(params_before)
                and len(buffers_after) == len(buffers_before)
            )
            # Identity is checked first so allclose only runs on the same tensor.
            unchanged = counts_match and all(
                before is after and torch.allclose(before, snapshot)
                for before, after, snapshot in zip(
                    tensors_before, params_after + buffers_after, snapshots
                )
            )
            self.assertTrue(unchanged, message)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_functional_call(self, functional_call):
        """Identity values substituted into a fresh MockModule give an identity map."""
        self._run_call_with_mock_module(MockModule(), functional_call)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_functional_call_with_jit(self, functional_call):
        """Scripted and traced modules are both rejected with a RuntimeError."""
        eager = MockModule()
        jit_builders = (
            lambda: torch.jit.script(eager),
            lambda: torch.jit.trace(eager, torch.rand((1, 1))),
        )
        for build in jit_builders:
            jitted = build()
            with self.assertRaisesRegex(RuntimeError, r'used with Jitted modules'):
                self._run_call_with_mock_module(jitted, functional_call)

    @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
    @unittest.skip("This doesn't work right now")
    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_functional_call_with_data_parallel(self, functional_call):
        """functional_call through nn.DataParallel using 'module.'-prefixed keys."""
        replicated = torch.nn.DataParallel(MockModule().cuda(), [0, 1])
        self._run_call_with_mock_module(replicated, functional_call, device='cuda', prefix='module')

    @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_functional_call_with_data_parallel_error(self, functional_call):
        """functional_call on an nn.DataParallel wrapper raises a RuntimeError."""
        replicated = torch.nn.DataParallel(MockModule().cuda(), [0, 1])
        with self.assertRaisesRegex(RuntimeError, r'used with nn.DataParallel module'):
            replacements = {'module.weight': torch.zeros(5, device='cuda')}
            call_args = (torch.ones(2, 5, device='cuda'),)
            functional_call(replicated, replacements, call_args)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_functional_call_with_gradient(self, functional_call):
        """Backward through functional_call populates grads of the supplied tensors only."""
        module = MockModule()
        x = torch.rand((1, 1))
        weight = torch.tensor([[1.0]], requires_grad=True)
        bias = torch.tensor([0.0], requires_grad=True)
        buffer = torch.tensor([0.0])
        replacements = {'l1.weight': weight, 'l1.bias': bias, 'buffer': buffer}

        functional_call(module, replacements, x).backward()

        # The supplied trainable tensors received gradients ...
        for supplied in (weight, bias):
            self.assertIsNotNone(supplied.grad)
        # ... while the supplied buffer and the module's own parameters and
        # buffer did not.
        for untouched in (buffer, module.l1.weight, module.l1.bias, module.buffer):
            self.assertIsNone(untouched.grad)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_functional_batch_norm(self, functional_call):
        """A substituted BatchNorm ``running_mean`` receives the statistics update."""
        bn = torch.nn.BatchNorm1d(10)
        bn.train()  # training mode so the running statistics are updated
        inputs = torch.full((20, 10), 128.0)
        # 12.8 == 0.1 * 128: one running-mean update, assuming the default
        # momentum of 0.1.
        updated_mean = torch.full((10,), 12.8)

        replacement_mean = torch.zeros(10)
        original_mean = bn.running_mean.clone()
        functional_call(bn, {'running_mean': replacement_mean}, inputs)
        # The update lands in the substituted tensor, not the module's buffer.
        self.assertEqual(bn.running_mean, original_mean)
        self.assertEqual(replacement_mean, updated_mean)

        # With nothing substituted, the module's own buffer is updated.
        functional_call(bn, {}, inputs)
        self.assertEqual(bn.running_mean, updated_mean)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_circular_references(self, functional_call):
        """Keys may reach tensors through a reference cycle (``l1.m`` is the root module)."""
        module = MockModule()
        module.l1.m = module  # cycle: module -> l1 -> module
        x = torch.rand((1, 1))
        # Identity values, addressed both directly and through the cycle.
        replacements = {
            'l1.m.l1.weight': torch.tensor([[1.0]]),
            'l1.bias': torch.tensor([0.0]),
            'l1.m.buffer': torch.tensor([0.0]),
        }
        weight_before = module.l1.weight.clone()
        buffer_before = module.buffer.clone()

        result = functional_call(module, replacements, x, tie_weights=False)
        self.assertEqual(x, result)

        # The module's own tensors were correctly accessed and left unmodified.
        self.assertEqual(module.l1.weight, weight_before)
        self.assertEqual(module.buffer, buffer_before)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_reparametrized_module_change_parametrization_original(self, functional_call):
        """Substituting a parametrization's ``original`` keeps the parametrization active."""
        original_key = 'l1.parametrizations.weight.original'
        module = MockModule()
        torch.nn.utils.parametrizations.spectral_norm(module.l1)

        def original_is_registered():
            return original_key in dict(module.named_parameters())

        self.assertTrue(original_is_registered())
        normalized_weight_before = module.l1.weight.clone()
        x = torch.rand((1, 1))
        # Only the tensor inside the parametrization is replaced.  The
        # parametrization itself is not overwritten, so spectral normalization
        # is applied to the substituted value.
        replacements = {
            original_key: torch.nn.Parameter(torch.tensor([[1.0]])),
            'l1.bias': torch.tensor([0.0]),
            'buffer': torch.tensor([0.0]),
        }
        result = functional_call(module, replacements, x)
        self.assertEqual(x, result)

        # The parametrization and the module's own weight survive the call.
        self.assertTrue(original_is_registered())
        self.assertEqual(normalized_weight_before, module.l1.weight)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_reparametrize_module_fail_reset_to_original(self, functional_call):
        """A call that raises still restores the parametrization and the weight."""
        original_key = 'l1.parametrizations.weight.original'
        module = MockModule()
        torch.nn.utils.parametrizations.spectral_norm(module.l1)

        def original_is_registered():
            return original_key in dict(module.named_parameters())

        self.assertTrue(original_is_registered())
        normalized_weight_before = module.l1.weight.clone()
        # Only the tensor inside the parametrization is replaced.  The
        # parametrization itself stays in place and is applied to the new value.
        replacements = {
            original_key: torch.nn.Parameter(torch.tensor([[1.0]])),
            'l1.bias': torch.tensor([0.0]),
            'buffer': torch.tensor([0.0]),
        }

        with self.assertRaisesRegex(RuntimeError, "shapes cannot be multiplied"):
            # NOTE(review): dynamo is disabled for this helper, presumably so the
            # error surfaces eagerly when the suite runs under dynamo; confirm.
            @torch._dynamo.disable
            def _call_with_bad_shape():
                bad_input = torch.rand((4, 5))  # l1 is Linear(1, 1); (1, 1) would work
                functional_call(module, replacements, bad_input)  # fails on the wrong size
            _call_with_bad_shape()

        # Even after the failure, the parametrization and the weight are restored.
        self.assertTrue(original_is_registered())
        self.assertEqual(normalized_weight_before, module.l1.weight)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_reparametrize_some_weights(self, functional_call):
        """Unlisted state falls back to the module; unknown keys are ignored when not strict."""
        module = MockModule()
        weight = torch.tensor([[2.0]])
        extra = torch.tensor([1.0])

        for replacements in ({'l1.weight': weight}, {'l1.weight': weight, 'extra': extra}):
            x = torch.randn(1, 1)
            out = functional_call(module, replacements, x)
            # Only the weight is swapped; the bias and buffer come from the module.
            self.assertEqual(out, x * weight + module.l1.bias + module.buffer)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_reparametrize_strict(self, functional_call):
        """With ``strict=True`` the supplied keys must exactly match the module's state.

        A complete mapping succeeds.  Missing keys, unexpected keys, or both
        raise RuntimeError, and in every case the module is left unmodified.
        """
        module = MockModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])
        failed_msg = 'the module should not have been modified by a failed call'
        missing = re.escape("Missing key(s): 'buffer', 'l1.bias'.")
        unexpected = re.escape("Unexpected key(s): 'extra'.")

        # Every parameter and buffer supplied: no error.
        complete = {'l1.weight': weight, 'l1.bias': bias, 'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, complete, x, strict=True)
            self.assertEqual(out, x * weight + bias + buffer)

        failing_cases = (
            # Some weights only.
            ({'l1.weight': weight}, missing),
            # All weights plus an unknown key.
            ({**complete, 'extra': extra}, unexpected),
            # Some weights plus an unknown key: both problems are reported.
            ({'l1.weight': weight, 'extra': extra}, unexpected + r'\s+' + missing),
        )
        for partial, pattern in failing_cases:
            x = torch.randn(1, 1)
            with self._ensure_module_unchanged(module, failed_msg):
                with self.assertRaisesRegex(RuntimeError, pattern):
                    functional_call(module, partial, x, strict=True)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_reparametrize_special(self, functional_call):
        """Edge cases: ``None`` values, non-tensor values, non-tensor attributes, missing submodules."""
        class NonTensor:
            def __repr__(self):
                return f'<{self.__class__.__name__}>'

        module = MockModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        non_tensor = NonTensor()

        # Substituting None for the bias removes it from the computation.
        with_none_bias = {'l1.weight': weight, 'l1.bias': None, 'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, with_none_bias, x)
            self.assertEqual(out, x * weight + buffer)

        failing_cases = (
            # A substituted value that is not a tensor.
            ({'l1.weight': non_tensor}, TypeError,
             re.escape("<NonTensor> is not an instance of torch.Tensor")),
            # A key naming an existing attribute (`foo`) that is not a tensor.
            ({'l1.weight': weight, 'foo': torch.tensor([1.0])}, TypeError,
             re.escape("attribute `foo`: 0.0 is not an instance of torch.Tensor")),
            # A key routed through a submodule that does not exist.
            ({'l1.weight': weight, 'l2.bias': bias}, AttributeError,
             re.escape("MockModule has no attribute `l2`")),
        )
        for replacements, error_type, pattern in failing_cases:
            x = torch.randn(1, 1)
            with self._ensure_module_unchanged(
                module,
                'the module should not have been modified by a failed call',
            ):
                with self.assertRaisesRegex(error_type, pattern):
                    functional_call(module, replacements, x)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_tied_weights_warns(self, functional_call):
        """functional_call on a MockModule that gained extra state after construction.

        ``tied_bias`` aliases the Parameter ``l1.bias``.  ``tied_buffer`` is
        registered by wrapping ``buffer`` in a new ``torch.nn.Buffer``.  The
        test checks that the supplied values drive the output and that the
        module's own state is left untouched.
        """
        module = MockModule()
        module.tied_bias = module.l1.bias
        module.tied_buffer = torch.nn.Buffer(module.buffer)

        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        # Previously the test only built the module and never invoked
        # functional_call, so it could not fail.
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, parameters, x)
            # MockModule.forward reads only l1 and buffer.
            self.assertEqual(out, x * weight + bias + buffer)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_reparametrize_tie_weights(self, functional_call):
        """With ``tie_weights=True`` a value given for a tensor also replaces its tied alias."""
        module = MockTiedModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])

        base = {'l1.weight': weight, 'l1.bias': bias, 'buffer': buffer}
        # The second mapping adds an unknown key, which is ignored when not strict.
        for replacements in (base, {**base, 'extra': extra}):
            x = torch.randn(1, 1)
            out = functional_call(module, replacements, x, tie_weights=True)
            # bias and buffer each appear twice: once directly, once via the alias.
            self.assertEqual(out, x * weight + bias + bias + buffer + buffer)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_tie_some_weights(self, functional_call):
        """Tying applies only to the tensors actually supplied.

        ``buffer`` is replaced together with its alias ``tied_buffer``.
        ``l1.bias`` and ``tied_bias`` are not supplied, so the module's own
        bias is used for both.
        """
        module = MockTiedModule()
        weight = torch.tensor([[2.0]])
        buffer = torch.tensor([3.0])

        parameters = {'l1.weight': weight,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        # Use the parametrized implementation so each subtest exercises its own
        # functional_call (this used to hard-code stateless.functional_call).
        out = functional_call(module, parameters, x, tie_weights=True)
        self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless._functional_call, "stateless"),
        ],
    )
    def test_tied_weights_errors(self, functional_call):
        """Tied tensors may be given once, or repeated with the same object, but not with different values."""
        module = MockTiedModule()
        base = {
            'l1.weight': torch.tensor([[1.0]]),
            'l1.bias': torch.tensor([0.0]),
            'buffer': torch.tensor([0.0]),
        }
        x = torch.randn(1, 1)

        def call_tied(params):
            return functional_call(module, params, x, tie_weights=True)

        self.assertNotWarn(lambda: call_tied(base))

        # Repeating a tied tensor under its alias is fine when it is the same object.
        same_objects = {**base, 'tied_bias': base['l1.bias'], 'tied_buffer': base['buffer']}
        self.assertNotWarn(lambda: call_tied(same_objects))

        # Different values for the two names of one tied tensor are rejected.
        with self.assertRaisesRegex(
            ValueError,
            re.escape("functional_call got multiple values for keys ['l1.bias', 'tied_bias']"),
        ):
            call_tied({**base, 'tied_bias': torch.tensor([5.0])})

        with self.assertRaisesRegex(
            ValueError,
            re.escape("functional_call got multiple values for keys ['buffer', 'tied_buffer']"),
        ):
            call_tied({**base, 'tied_buffer': torch.tensor([5.0])})

    def test_tied_weights_no_error_without_flag(self):
        """With ``tie_weights=False``, differing values for tied aliases are accepted without warnings."""
        module = MockTiedModule()
        base = {
            'l1.weight': torch.tensor([[1.0]]),
            'l1.bias': torch.tensor([0.0]),
            'buffer': torch.tensor([0.0]),
        }
        x = torch.randn(1, 1)
        variants = (
            base,
            {**base, 'tied_bias': torch.tensor([5.0])},
            {**base, 'tied_buffer': torch.tensor([5.0])},
        )
        for params in variants:
            # Bind `params` now; the lambda must not see a later loop value.
            self.assertNotWarn(
                lambda params=params: stateless._functional_call(module, params, x, tie_weights=False)
            )

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_tie_weights_strict(self, functional_call):
        """``strict=True`` combined with ``tie_weights`` on a module with tied state.

        With ``tie_weights=True``, a value supplied for ``l1.bias`` or ``buffer``
        also covers its tied alias, so the aliases need not be listed.  With
        ``tie_weights=False`` they count as missing keys.  Every case goes
        through the parametrized ``functional_call``, so both implementations
        are exercised, and the module must be unmodified after each call.
        """
        module = MockTiedModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])

        # Tie weights no error
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, parameters, x, tie_weights=True, strict=True)
            self.assertEqual(out, x * weight + bias + bias + buffer + buffer)

        # Tie weights without flag: the aliases are reported as missing.
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=False, strict=True)

        # Tie some weights
        parameters = {'l1.weight': weight,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
            ):
                # Previously hard-coded stateless.functional_call; use the parameter.
                out = functional_call(module, parameters, x, tie_weights=True, strict=True)

        # Tie weights with extra keys
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=True, strict=True)

        # Tie weights with extra keys and without flag
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=False, strict=True)

        # Tie some weights with extra keys
        parameters = {'l1.weight': weight,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=True, strict=True)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_setattr(self, functional_call):
        """Rebinding a substituted attribute in forward updates the caller's dict, not the module."""
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Buffer(torch.tensor([0.0]))

            def forward(self, x):
                # Out-of-place on purpose: rebinds `foo` instead of mutating it.
                self.foo = self.foo + 1
                return x + self.foo

        supplied = torch.tensor([2.0])
        x = torch.randn(1)
        state = {'foo': supplied}
        module = Foo()
        functional_call(module, state, x)
        # The module keeps its own buffer value.
        self.assertEqual(module.foo, torch.tensor([0.0]))
        # The dict now holds the rebound result ...
        self.assertEqual(state['foo'], torch.tensor([3.0]))
        # ... while the originally supplied tensor is unchanged and was replaced.
        self.assertEqual(supplied, torch.tensor([2.0]))
        self.assertIsNot(state['foo'], supplied)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_in_place_operator(self, functional_call):
        """An in-place update in forward mutates the substituted tensor itself."""
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Buffer(torch.tensor([0.0]))

            def forward(self, x):
                # In place: mutates whichever tensor is currently bound to `foo`.
                self.foo.add_(1)
                return x + self.foo

        supplied = torch.tensor([2.0])
        x = torch.randn(1)
        state = {'foo': supplied}
        module = Foo()
        functional_call(module, state, x)
        # The module keeps its own buffer value.
        self.assertEqual(module.foo, torch.tensor([0.0]))
        # The caller's tensor was incremented in place and is still the dict entry.
        self.assertEqual(state['foo'], torch.tensor([3.0]))
        self.assertEqual(supplied, torch.tensor([3.0]))
        self.assertIs(state['foo'], supplied)

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_setattr_strict(self, functional_call):
        """A key for a non-existent attribute is usable when not strict and never persists on the module."""
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                assert not hasattr(self, 'extra')

            def forward(self, x):
                return x + self.extra

        mod = Bar()
        self.assertFalse(hasattr(mod, 'extra'))

        # Not strict: `extra` is temporarily provided and removed afterwards.
        out = functional_call(mod, {'extra': torch.zeros(())}, torch.ones(()))
        self.assertEqual(out, torch.ones(()))
        self.assertFalse(hasattr(mod, 'extra'))

        missing_attr = re.escape("'Bar' object has no attribute 'extra'")
        failing_cases = (
            # Strict: `extra` is not part of the module's state.
            ({'extra': torch.zeros(())}, {'strict': True}, RuntimeError,
             re.escape("Unexpected key(s): 'extra'.")),
            # Nothing supplied: forward itself fails, strict or not.
            ({}, {}, AttributeError, missing_attr),
            ({}, {'strict': True}, AttributeError, missing_attr),
        )
        for replacements, flags, error_type, pattern in failing_cases:
            with self.assertRaisesRegex(error_type, pattern):
                functional_call(mod, replacements, torch.ones(()), **flags)
            self.assertFalse(hasattr(mod, 'extra'))

    @parametrize(
        "functional_call",
        [
            subtest(torch.func.functional_call, "torch_func"),
            subtest(stateless.functional_call, "stateless"),
        ],
    )
    def test_functional_call_with_kwargs(self, functional_call):
        """Keyword arguments for forward are passed through the optional kwargs dict."""
        class Foo(torch.nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self, inp, *, other_inp):
                return inp * self.x + other_inp

        replacements = {'x': torch.zeros(2, 3)}
        mod = Foo(torch.randn(2, 3))
        inp = torch.randn(2, 3)
        other_inp = torch.randn(2, 3)

        # other_inp is keyword-only, so omitting the kwargs dict fails.
        with self.assertRaisesRegex(TypeError, "missing 1 required keyword-only argument: 'other_inp'"):
            functional_call(mod, replacements, inp)

        # x is replaced by zeros, so only other_inp contributes to the result.
        positional = functional_call(mod, replacements, inp, {'other_inp': other_inp})
        self.assertEqual(positional, other_inp)
        all_keyword = functional_call(mod, replacements, (), {'inp': inp, 'other_inp': other_inp})
        self.assertEqual(positional, all_keyword)

    def test_functional_call_tuple_dicts(self):
        # torch.func.functional_call accepts a tuple of dicts and merges their
        # keys; how the entries are split across the dicts must not matter.
        mod = MockModule()
        x = torch.rand((1, 1))
        ones_params = {name: torch.ones_like(p) for name, p in mod.named_parameters()}
        zero_buffers = {name: torch.zeros_like(b) for name, b in mod.named_buffers()}

        # Two dictionaries: weight=1, bias=1, buffer=0, so the result is x + 1.
        merged = torch.func.functional_call(mod, (ones_params, zero_buffers), x)
        self.assertEqual(merged, x + 1)

        # No dictionaries: nothing is replaced and the module's own state is used.
        untouched = torch.func.functional_call(mod, (), x)
        self.assertEqual(untouched, mod(x))

        # Three dictionaries, one entry each.
        split = (
            {'l1.weight': torch.ones(1, 1)},
            {'l1.bias': torch.ones(1)},
            {'buffer': torch.zeros(1)},
        )
        self.assertEqual(torch.func.functional_call(mod, split, x), x + 1)
760

761
    def test_functional_call_multiple_dicts_error(self):
        """A name supplied by more than one dict is rejected with a ValueError."""
        mod = MockModule()
        x = torch.rand((1, 1))
        # Shapes match MockModule's Linear(1, 1) (weight (1, 1), bias (1,)), so
        # the duplicated 'l1.weight' key is the only thing wrong with the input.
        parameters = {'l1.weight': torch.zeros((1, 1)), 'l1.bias': torch.zeros(1)}
        repeated_parameters = {'l1.weight': torch.ones((1, 1))}
        with self.assertRaisesRegex(
            ValueError,
            re.escape("['l1.weight'] appeared in multiple dictionaries"),
        ):
            torch.func.functional_call(mod, (parameters, repeated_parameters), x)
771

772
    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_functional_call_member_reference(self, functional_call):
        """Inside the call, ``parameters()``/``buffers()`` report the substitutes.

        ``Module.forward`` returns its output together with the tuples built
        from ``self.parameters()`` and ``self.buffers()`` during the call.
        Those tuples must:

        * hold the very tensor objects that were passed in, not copies;
        * fall back to the module's own tensors for names that were not
          supplied;
        * include an extra (unknown) name only when its value is an
          ``nn.Parameter``;
        * omit a name that was mapped to ``None``.
        """
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = torch.nn.Linear(1, 1)
                self.buffer = torch.nn.Buffer(torch.ones(1))

            def forward(self, x):
                parameters = tuple(self.parameters())
                buffers = tuple(self.buffers())
                return self.l1(x) + self.buffer, parameters, buffers

        def assert_same_tensors(actual, expected):
            # Equal values *and* identical objects: the module must expose the
            # tensors it was handed rather than copies of them.
            self.assertEqual(actual, expected)
            self.assertTrue(all(t1 is t2 for t1, t2 in zip(actual, expected)))

        module = Module()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])
        extra_p = torch.nn.Parameter(extra)

        # All weights
        params = {'l1.weight': weight,
                  'l1.bias': bias,
                  'buffer': buffer}
        x = torch.randn(1, 1)
        out, seen_params, seen_buffers = functional_call(module, params, x)
        self.assertEqual(out, x * weight + bias + buffer)
        assert_same_tensors(seen_params, (weight, bias))
        assert_same_tensors(seen_buffers, (buffer,))

        # Some weights: names not supplied keep the module's own tensors.
        params = {'l1.weight': weight}
        x = torch.randn(1, 1)
        out, seen_params, seen_buffers = functional_call(module, params, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
        assert_same_tensors(seen_params, (weight, module.l1.bias))
        assert_same_tensors(seen_buffers, (module.buffer,))

        # All weights with extra keys: a plain tensor under an unknown name is
        # reported by neither parameters() nor buffers().
        params = {'l1.weight': weight,
                  'l1.bias': bias,
                  'buffer': buffer,
                  'l1.extra': extra}
        x = torch.randn(1, 1)
        out, seen_params, seen_buffers = functional_call(module, params, x)
        self.assertEqual(out, x * weight + bias + buffer)
        assert_same_tensors(seen_params, (weight, bias))
        assert_same_tensors(seen_buffers, (buffer,))

        # All weights with extra keys with parameters: an nn.Parameter under an
        # unknown name is reported by parameters().
        params = {'l1.weight': weight,
                  'l1.bias': bias,
                  'buffer': buffer,
                  'l1.extra': extra_p}
        x = torch.randn(1, 1)
        out, seen_params, seen_buffers = functional_call(module, params, x)
        self.assertEqual(out, x * weight + bias + buffer)
        assert_same_tensors(seen_params, (weight, bias, extra_p))
        assert_same_tensors(seen_buffers, (buffer,))

        # Some weights with extra keys
        params = {'l1.weight': weight,
                  'l1.extra': extra}
        x = torch.randn(1, 1)
        out, seen_params, seen_buffers = functional_call(module, params, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
        assert_same_tensors(seen_params, (weight, module.l1.bias))
        assert_same_tensors(seen_buffers, (module.buffer,))

        # Some weights with extra keys with parameters
        params = {'l1.weight': weight,
                  'l1.extra': extra_p}
        x = torch.randn(1, 1)
        out, seen_params, seen_buffers = functional_call(module, params, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
        assert_same_tensors(seen_params, (weight, module.l1.bias, extra_p))
        assert_same_tensors(seen_buffers, (module.buffer,))

        # Set None: a parameter mapped to None is skipped by parameters() and
        # contributes nothing to the output.
        params = {'l1.weight': weight,
                  'l1.bias': None}
        x = torch.randn(1, 1)
        out, seen_params, seen_buffers = functional_call(module, params, x)
        self.assertEqual(out, x * weight + module.buffer)
        assert_same_tensors(seen_params, (weight,))
        assert_same_tensors(seen_buffers, (module.buffer,))
875

876

877
class TestStatelessDeprecation(TestCase):
    """Checks that the legacy stateless entry points emit warnings."""

    def test_private_stateless_warns(self):
        # Importing the private ``torch.nn.utils._stateless`` module must record
        # exactly one warning. The import runs in a fresh interpreter, and the
        # child prints how many warnings it recorded.
        #
        # The count goes to stdout rather than the exit status: an uncaught
        # exception (e.g. the import itself failing) also exits with status 1,
        # which would be indistinguishable from "exactly one warning".
        script = """
import torch
import warnings

with warnings.catch_warnings(record=True) as w:
    from torch.nn.utils import _stateless

print(len(w))
"""
        result = subprocess.run(
            [sys.executable, '-W', 'always', '-c', script],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),
        )
        stderr = result.stderr.decode(errors='replace')
        self.assertEqual(result.returncode, 0, f"Subprocess failed:\n{stderr}")

        # Only the last stdout line is the count; anything else is incidental.
        lines = result.stdout.decode(errors='replace').strip().splitlines()
        self.assertTrue(lines, f"Subprocess printed no warning count:\n{stderr}")
        num_warnings = int(lines[-1])
        if num_warnings == 0:
            self.fail("No warning was raised.")
        self.assertEqual(num_warnings, 1)

    def test_stateless_functional_call_warns(self):
        # The public stateless.functional_call points users at torch.func.
        m = torch.nn.Linear(1, 1)
        params = dict(m.named_parameters())
        x = torch.randn(3, 1)
        with self.assertWarnsRegex(FutureWarning, "Please use `torch.func.functional_call`"):
            stateless.functional_call(m, params, x)
906

907
class TestPythonOptimizeMode(TestCase):
    """Importing torch's deprecated functorch module must work under ``-OO``."""

    def test_runs_with_optimize_flag(self):
        # Under -OO, asserts are removed and docstrings become None, so this
        # catches import-time code that relies on either.
        script = "import torch; import torch._functorch.deprecated"
        try:
            subprocess.check_output(
                [sys.executable, "-OO", "-c", script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),
            )
        except subprocess.CalledProcessError as e:
            # CalledProcessError always means a non-zero exit status. Surface
            # the child's combined stdout/stderr so the failure can be diagnosed.
            output = e.output.decode(errors="replace") if e.output else ""
            self.fail(
                "Import failed while running python in optimized mode "
                f"(exit status {e.returncode}):\n{output}"
            )
919

920

921
# Expand the @parametrize-decorated methods into one concrete test per subtest.
instantiate_parametrized_tests(TestStatelessFunctionalAPI)

if __name__ == "__main__":
    run_tests()
927

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.