# Owner(s): ["module: nn"]
import pickle
import unittest

import torch
import torch.nn as nn
from torch.nn import Buffer, Parameter
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    run_tests,
    suppress_warnings,
    TEST_PRIVATEUSE1,
    TestCase,
)


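# LazyModule is the minimal concrete LazyModuleMixin subclass used throughout
# these tests: LazyModuleMixin defers allocating parameters and buffers until
# the first forward pass, so fresh instances hold UninitializedParameter /
# UninitializedBuffer placeholders, and has_uninitialized_params() reports
# whether any placeholder still lacks a shape.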
class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
    pass


class TestLazyModules(TestCase):
    @suppress_warnings
    def test_lazy_module_parameter(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict["test_param"], UninitializedParameter)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing parameter
        # with an uninitialized one
        new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5)))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            new_module.load_state_dict(state_dict)
        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5)))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_param, torch.ones((5, 5)))

        # Uninitialized parameters are left unchanged
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.register_parameter("test_param", UninitializedParameter())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict["test_buffer"], UninitializedBuffer)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing buffer
        # with an uninitialized one
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            new_module.load_state_dict(state_dict)
        # Uninitialized buffers are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_buffer, torch.ones((5, 5)))

        # Uninitialized buffers are left unchanged
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.test_buffer = UninitializedBuffer()
        module.load_state_dict(new_module.state_dict())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_jit_param(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "run a forward pass"):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_module_jit_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "run a forward pass"):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_share_memory_param(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"):
            module.share_memory()

    @suppress_warnings
    def test_lazy_share_memory_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"):
            module.share_memory()

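    # nn.LazyLinear only takes out_features; in_features is inferred from the
    # first input, after which the instance's class is swapped to nn.Linear
    # (hence the assertIsInstance/assertNotIsInstance checks below).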
    @suppress_warnings
    def test_linear(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)
        self.assertIsInstance(module, nn.Linear)
        self.assertNotIsInstance(module, nn.LazyLinear)
        self.assertTrue(module.weight.shape == (10, 5))
        self.assertTrue(module.bias.shape == (10,))
        y = module(input)
        self.assertTrue(
            torch.equal(
                torch.nn.functional.linear(input, module.weight, module.bias), y
            )
        )

    @suppress_warnings
    def test_lazy_linear_pickle(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, nn.LazyLinear)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, nn.Linear)
        self.assertNotIsInstance(new_module, nn.LazyLinear)
        self.assertTrue(new_module.weight.shape == (10, 5))
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        self.assertTrue(new_module.bias.shape == (10,))
        self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    @suppress_warnings
    def test_linear_state(self):
        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Linear one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertTrue(lazy_module.weight.shape == (10, 5))
        self.assertTrue(lazy_module.bias.shape == (10,))

        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

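    # Shared checker for the lazy convolution variants: builds lazy_cls(*init_args),
    # feeds a ones tensor of input_shape through it, and verifies that the module
    # has turned into cls with the expected weight/bias shapes and that its output
    # matches the functional equivalent `func`.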
    def _check_lazy_conv(
        self,
        cls,
        lazy_cls,
        func,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
        *forward_args,
        **forward_kwargs,
    ):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input, *forward_args, **forward_kwargs)
        self.assertIsInstance(module, cls)
        self.assertNotIsInstance(module, lazy_cls)
        self.assertEqual(module.weight.shape, expected_weight_shape)
        if module.bias is not None:
            self.assertEqual(module.bias.shape, expected_bias_shape)
        y = module(input)
        self.assertTrue(torch.equal(func(input, module.weight, module.bias), y))

    def _check_lazy_conv_pickle(
        self,
        cls,
        lazy_cls,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
    ):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, lazy_cls)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, cls)
        self.assertNotIsInstance(new_module, lazy_cls)
        self.assertEqual(new_module.weight.shape, expected_weight_shape)
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        if new_module.bias is not None:
            self.assertEqual(new_module.bias.shape, expected_bias_shape)
            self.assertNotIsInstance(new_module.bias, UninitializedParameter)

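    # gen_module/gen_lazy_module are zero-argument factories producing fresh
    # modules, so every state_dict round-trip below starts from a clean instance.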
    def _check_lazy_conv_state(
        self, gen_module, gen_lazy_module, expected_weight_shape, expected_bias_shape
    ):
        module = gen_module()
        lazy_module = gen_lazy_module()
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Conv one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, expected_weight_shape)
        if lazy_module.bias is not None:
            self.assertEqual(lazy_module.bias.shape, expected_bias_shape)

        module = gen_module()
        lazy_module = gen_lazy_module()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

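    # LazyModuleMixin itself relies on a forward pre-hook to infer shapes and
    # materialize parameters, so the two tests below make sure user-registered
    # forward (pre-)hooks still run as expected on lazy modules.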
    def test_lazy_pre_forward_hook(self):
        """
        Test that a lazy module can register additional forward pre-hook
        functions successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input):
            return input[0] + 1

        module = TestModule()
        module.register_forward_pre_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    def test_lazy_forward_hook(self):
        """
        Test that a lazy module can register additional forward hook
        functions successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input, output):
            return input[0] + 1

        module = TestModule()
        module.register_forward_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    @suppress_warnings
    def test_lazy_conv1d(self):
        self._check_lazy_conv(
            nn.Conv1d,
            nn.LazyConv1d,
            torch.nn.functional.conv1d,
            (32, 2),
            (192, 16, 50),
            (32, 16, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv1d, nn.LazyConv1d, (32, 2), (192, 16, 50), (32, 16, 2), (32,)
        )

    @suppress_warnings
    def test_lazy_conv1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv1d(16, 32, 2),
            lambda: nn.LazyConv1d(32, 2),
            (32, 16, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv2d(self):
        self._check_lazy_conv(
            nn.Conv2d,
            nn.LazyConv2d,
            torch.nn.functional.conv2d,
            (32, 2),
            (192, 16, 8, 6),
            (32, 16, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv2d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv2d, nn.LazyConv2d, (32, 2), (192, 16, 8, 6), (32, 16, 2, 2), (32,)
        )

    @suppress_warnings
    def test_lazy_conv2d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv2d(16, 32, 2),
            lambda: nn.LazyConv2d(32, 2),
            (32, 16, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d(self):
        self._check_lazy_conv(
            nn.Conv3d,
            nn.LazyConv3d,
            torch.nn.functional.conv3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv3d,
            nn.LazyConv3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv3d(16, 32, 2),
            lambda: nn.LazyConv3d(32, 2),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transposed1d(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
            output_size=(51,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose1d(16, 32, 2),
            lambda: nn.LazyConvTranspose1d(32, 2),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d(self):
        self._check_lazy_conv(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            torch.nn.functional.conv_transpose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            torch.nn.functional.conv_transpose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
            output_size=(9, 7),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose2d(16, 32, 2),
            lambda: nn.LazyConvTranspose2d(32, 2),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d(self):
        self._check_lazy_conv(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            torch.nn.functional.conv_transpose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            torch.nn.functional.conv_transpose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
            output_size=(9, 8, 7),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose3d(16, 32, 2),
            lambda: nn.LazyConvTranspose3d(32, 2),
            (16, 32, 2, 2, 2),
            (32,),
        )

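    # Shared checker for the lazy normalization layers: for every combination of
    # affine/track_running_stats it runs one forward pass through the lazy module,
    # then compares output and all materialized state (weight, bias, running
    # stats) against an eagerly constructed cls(num_features) reference.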
    def _check_lazy_norm(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )

                if affine:
                    self.assertIsInstance(lazy_module.weight, UninitializedParameter)
                    self.assertIsInstance(lazy_module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                lazy_output = lazy_module(input)
                self.assertIsInstance(lazy_module, cls)
                self.assertNotIsInstance(lazy_module, lazy_cls)

                num_features = input_shape[1]
                module = cls(
                    num_features, affine=affine, track_running_stats=track_running_stats
                )
                expected_output = module(input)

                self.assertEqual(lazy_output, expected_output)
                if module.weight is not None:
                    self.assertEqual(lazy_module.weight.shape, module.weight.shape)
                    self.assertEqual(lazy_module.weight, module.weight)
                if module.bias is not None:
                    self.assertEqual(lazy_module.bias.shape, module.bias.shape)
                    self.assertEqual(lazy_module.bias, module.bias)
                if module.running_mean is not None:
                    self.assertEqual(
                        lazy_module.running_mean.shape, module.running_mean.shape
                    )
                    self.assertEqual(lazy_module.running_mean, module.running_mean)
                if module.running_var is not None:
                    self.assertEqual(
                        lazy_module.running_var.shape, module.running_var.shape
                    )
                    self.assertEqual(lazy_module.running_var, module.running_var)
                if module.num_batches_tracked is not None:
                    self.assertEqual(
                        lazy_module.num_batches_tracked.shape,
                        module.num_batches_tracked.shape,
                    )
                    self.assertEqual(
                        lazy_module.num_batches_tracked, module.num_batches_tracked
                    )

    def _check_lazy_norm_pickle(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                module = pickle.loads(pickle.dumps(module))

                self.assertIsInstance(module, lazy_cls)
                if affine:
                    self.assertIsInstance(module.weight, UninitializedParameter)
                    self.assertIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                module(input)  # fully materialized
                module = pickle.loads(pickle.dumps(module))

                self.assertNotIsInstance(module, lazy_cls)
                self.assertIsInstance(module, cls)
                if affine:
                    self.assertNotIsInstance(module.weight, UninitializedParameter)
                    self.assertNotIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertNotIsInstance(module.running_var, UninitializedBuffer)

    def _check_lazy_batchnorm_state(self, cls, lazy_cls):
        module = cls(10)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # BatchNorm one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, (10,))
        self.assertEqual(lazy_module.bias.shape, (10,))
        self.assertEqual(lazy_module.running_mean.shape, (10,))
        self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10)
        lazy_module = lazy_cls()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_instancenorm_state(self, cls, lazy_cls):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = cls(10, affine=affine, track_running_stats=track_running_stats)
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                lazy_module.load_state_dict(module.state_dict())
                # Parameters have been initialized but the module won't become a full
                # InstanceNorm one until the first iteration. This is due to
                # limitations on the state_dict loading logic
                self.assertFalse(lazy_module.has_uninitialized_params())
                if affine:
                    self.assertEqual(lazy_module.weight.shape, (10,))
                    self.assertEqual(lazy_module.bias.shape, (10,))
                if track_running_stats:
                    self.assertEqual(lazy_module.running_mean.shape, (10,))
                    self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10, affine=True, track_running_stats=True)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

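    # Same idea as _check_lazy_norm, but the input tensor is passed as a keyword
    # argument to make sure shape inference also triggers for kwargs-only calls.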
    def _check_lazy_norm_with_dict_input(self, cls, lazy_cls, input_shape):
        input = {"input": torch.ones(*input_shape)}

        lazy_module = lazy_cls()
        lazy_output = lazy_module(**input)

        num_features = input_shape[1]
        module = cls(num_features)
        expected_output = module(**input)

        self.assertEqual(lazy_output, expected_output)

    def test_lazy_batchnorm1d(self):
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)

    def test_lazy_batchnorm2d(self):
        self._check_lazy_norm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)

    def test_lazy_batchnorm3d(self):
        self._check_lazy_norm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_batchnorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_batchnorm3d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)

    def test_lazy_instancenorm1d(self):
        self._check_lazy_norm(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6))

    def test_lazy_instancenorm1d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6)
        )

    def test_lazy_instancenorm1d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)

    def test_lazy_instancenorm2d(self):
        self._check_lazy_norm(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7))

    def test_lazy_instancenorm2d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7)
        )

    def test_lazy_instancenorm2d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)

    def test_lazy_instancenorm3d(self):
        self._check_lazy_norm(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)

    def test_lazy_batchnorm_with_dict_input(self):
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

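    # UninitializedParameter.materialize(shape) allocates the underlying tensor
    # with the parameter's current dtype/device, so conversions applied to the
    # module beforehand (e.g. .half() or .to(device)) determine what gets created.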
    @suppress_warnings
    def test_materialize_dtype(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.get_default_dtype())
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.half()
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.float16)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 not available"
    )
    @suppress_warnings
    def test_materialize_device(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == "cpu")
        if TEST_CUDA:
            device = "cuda"
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.to(device)
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == device)

    @suppress_warnings
    def test_chained_initialization(self):
        class MyNetwork(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_1 = torch.nn.LazyLinear(15)
                self.linear_2 = torch.nn.LazyLinear(10)

            def forward(self, x):
                y = self.linear_1(x)
                return self.linear_2(y)

        net = MyNetwork()
        net(torch.ones(5, 10))
        self.assertTrue(net.linear_1.weight.shape == (15, 10))
        self.assertTrue(net.linear_1.bias.shape == (15,))
        self.assertTrue(net.linear_2.weight.shape == (10, 15))
        self.assertTrue(net.linear_2.bias.shape == (10,))

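    # Optimizers can be constructed while parameters are still uninitialized:
    # materialize() fills in the same Parameter object in place, so the optimizer's
    # param_groups end up referencing the initialized tensor. Adagrad is special-
    # cased below, presumably because it allocates per-parameter state eagerly in
    # its constructor, which is not possible for uninitialized parameters.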
    @suppress_warnings
    def test_optimizer_pass(self):
        optimizers = [
            torch.optim.Adadelta,
            torch.optim.Adagrad,
            torch.optim.Adamax,
            torch.optim.Adam,
            torch.optim.AdamW,
            torch.optim.ASGD,
            torch.optim.SGD,
            torch.optim.Rprop,
            torch.optim.RMSprop,
            torch.optim.LBFGS,
            torch.optim.NAdam,
            torch.optim.RAdam,
        ]

        def run_step(module, optim):
            self.assertIsInstance(
                optim.param_groups[0]["params"][0], UninitializedParameter
            )
            module.test_param.materialize(10)
            self.assertIsInstance(optim.param_groups[0]["params"][0], Parameter)
            self.assertNotIsInstance(
                optim.param_groups[0]["params"][0], UninitializedParameter
            )
            for p in module.parameters():
                p.grad = torch.rand_like(p)
            if isinstance(optim, torch.optim.LBFGS):
                optim.step(lambda: 1.0)
            else:
                optim.step()

        for optim_cls in optimizers:
            module = LazyModule()
            module.register_parameter("test_param", UninitializedParameter())
            if optim_cls is torch.optim.SGD:
                optim = optim_cls(module.parameters(), lr=0.0)
            elif optim_cls is torch.optim.Adagrad:
                with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
                    optim = optim_cls(module.parameters())
                continue
            else:
                optim = optim_cls(module.parameters())
            run_step(module, optim)

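    # weight_norm / spectral_norm need concrete weight shapes to set up their
    # reparametrization, so applying them to a lazy module before its first
    # forward pass is rejected.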
    @suppress_warnings
    def test_weight_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.weight_norm(m)

    @suppress_warnings
    def test_spectral_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.spectral_norm(m)

    @suppress_warnings
    def test_invalid_functions(self):
        param = torch.nn.parameter.UninitializedParameter()
        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.empty_like(param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.add(param, param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            param + param


if __name__ == "__main__":
    run_tests()