# Owner(s): ["module: nn"]
import pickle
import unittest

import torch
import torch.nn as nn
from torch.nn import Buffer, Parameter
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    run_tests,
    suppress_warnings,
    TEST_PRIVATEUSE1,
    TestCase,
)


class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
    pass

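
# LazyModule above is a minimal LazyModuleMixin fixture with no shape inference
# of its own; the tests below attach UninitializedParameter / UninitializedBuffer
# instances to it by hand. For reference, a typical lazy module materializes on
# its first forward pass, roughly:
#
#   m = torch.nn.LazyLinear(10)  # weight and bias start uninitialized
#   m(torch.ones(5, 5))          # first call infers in_features=5
#   assert isinstance(m, torch.nn.Linear)  # class is swapped once materialized
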
class TestLazyModules(TestCase):
    @suppress_warnings
    def test_lazy_module_parameter(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict["test_param"], UninitializedParameter)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing parameter
        # with an uninitialized one
        new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5)))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            new_module.load_state_dict(state_dict)
        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5)))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_param, torch.ones((5, 5)))

        # Uninitialized parameters are left unchanged
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.register_parameter("test_param", UninitializedParameter())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict["test_buffer"], UninitializedBuffer)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing buffer
        # with an uninitialized one
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            new_module.load_state_dict(state_dict)
        # Uninitialized buffers are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_buffer, torch.ones((5, 5)))

        # Uninitialized buffers are left unchanged
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.test_buffer = UninitializedBuffer()
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_jit_param(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "run a forward pass"):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_module_jit_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "run a forward pass"):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_share_memory_param(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"):
            module.share_memory()

    @suppress_warnings
    def test_lazy_share_memory_buffer(self):
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"):
            module.share_memory()

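    # The tests above poke the LazyModuleMixin machinery directly through the
    # bare LazyModule fixture; the tests below exercise the concrete lazy
    # modules shipped in torch.nn (LazyLinear, LazyConvXd, lazy norm layers).
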
    @suppress_warnings
    def test_linear(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)
        self.assertIsInstance(module, nn.Linear)
        self.assertNotIsInstance(module, nn.LazyLinear)
        self.assertTrue(module.weight.shape == (10, 5))
        self.assertTrue(module.bias.shape == (10,))
        y = module(input)
        self.assertTrue(
            torch.equal(
                torch.nn.functional.linear(input, module.weight, module.bias), y
            )
        )

    @suppress_warnings
    def test_lazy_linear_pickle(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, nn.LazyLinear)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, nn.Linear)
        self.assertNotIsInstance(new_module, nn.LazyLinear)
        self.assertTrue(new_module.weight.shape == (10, 5))
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        self.assertTrue(new_module.bias.shape == (10,))
        self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    @suppress_warnings
    def test_linear_state(self):
        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized, but the module won't become a full
        # Linear one until the first forward pass; this is a limitation of the
        # state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertTrue(lazy_module.weight.shape == (10, 5))
        self.assertTrue(lazy_module.bias.shape == (10,))

        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

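    # The _check_lazy_conv* helpers below factor out the assertions shared by
    # the LazyConvXd and LazyConvTransposeXd tests: materialization on the
    # first forward pass, pickling before and after materialization, and
    # state_dict interop with the corresponding non-lazy modules.
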
    def _check_lazy_conv(
        self,
        cls,
        lazy_cls,
        func,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
        *forward_args,
        **forward_kwargs,
    ):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input, *forward_args, **forward_kwargs)
        self.assertIsInstance(module, cls)
        self.assertNotIsInstance(module, lazy_cls)
        self.assertEqual(module.weight.shape, expected_weight_shape)
        if module.bias is not None:
            self.assertEqual(module.bias.shape, expected_bias_shape)
        y = module(input)
        self.assertTrue(torch.equal(func(input, module.weight, module.bias), y))

    def _check_lazy_conv_pickle(
        self,
        cls,
        lazy_cls,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
    ):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, lazy_cls)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, cls)
        self.assertNotIsInstance(new_module, lazy_cls)
        self.assertEqual(new_module.weight.shape, expected_weight_shape)
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        if new_module.bias is not None:
            self.assertEqual(new_module.bias.shape, expected_bias_shape)
            self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    def _check_lazy_conv_state(
        self, gen_module, gen_lazy_module, expected_weight_shape, expected_bias_shape
    ):
        module = gen_module()
        lazy_module = gen_lazy_module()
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized, but the module won't become a full
        # Conv one until the first forward pass; this is a limitation of the
        # state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, expected_weight_shape)
        if lazy_module.bias is not None:
            self.assertEqual(lazy_module.bias.shape, expected_bias_shape)

        module = gen_module()
        lazy_module = gen_lazy_module()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def test_lazy_pre_forward_hook(self):
        """
        Test that a lazy module can register additional pre-forward hooks
        successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input):
            return input[0] + 1

        module = TestModule()
        module.register_forward_pre_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    def test_lazy_forward_hook(self):
        """
        Test that a lazy module can register additional forward hooks
        successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input, output):
            return output + 1

        module = TestModule()
        module.register_forward_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    def test_lazy_conv1d(self):
        self._check_lazy_conv(
            nn.Conv1d,
            nn.LazyConv1d,
            torch.nn.functional.conv1d,
            (32, 2),
            (192, 16, 50),
            (32, 16, 2),
            (32,),
        )

    def test_lazy_conv1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv1d, nn.LazyConv1d, (32, 2), (192, 16, 50), (32, 16, 2), (32,)
        )

    def test_lazy_conv1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv1d(16, 32, 2),
            lambda: nn.LazyConv1d(32, 2),
            (32, 16, 2),
            (32,),
        )

def test_lazy_conv2d(self):
313
self._check_lazy_conv(
316
torch.nn.functional.conv2d,
324
def test_lazy_conv2d_pickle(self):
325
self._check_lazy_conv_pickle(
326
nn.Conv2d, nn.LazyConv2d, (32, 2), (192, 16, 8, 6), (32, 16, 2, 2), (32,)
330
def test_lazy_conv2d_state(self):
331
self._check_lazy_conv_state(
332
lambda: nn.Conv2d(16, 32, 2),
333
lambda: nn.LazyConv2d(32, 2),
339
def test_lazy_conv3d(self):
340
self._check_lazy_conv(
343
torch.nn.functional.conv3d,
351
def test_lazy_conv3d_pickle(self):
352
self._check_lazy_conv_pickle(
362
def test_lazy_conv3d_state(self):
363
self._check_lazy_conv_state(
364
lambda: nn.Conv3d(16, 32, 2),
365
lambda: nn.LazyConv3d(32, 2),
    def test_lazy_conv_transposed1d(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    def test_lazy_conv_transpose1d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
            output_size=(51,),
        )

    def test_lazy_conv_transpose1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    def test_lazy_conv_transpose1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose1d(16, 32, 2),
            lambda: nn.LazyConvTranspose1d(32, 2),
            (16, 32, 2),
            (32,),
        )

def test_lazy_conv_transpose2d(self):
417
self._check_lazy_conv(
419
nn.LazyConvTranspose2d,
420
torch.nn.functional.conv_transpose2d,
428
def test_lazy_conv_transpose2d_kwargs(self):
429
self._check_lazy_conv(
431
nn.LazyConvTranspose2d,
432
torch.nn.functional.conv_transpose2d,
441
def test_lazy_conv_transpose2d_pickle(self):
442
self._check_lazy_conv_pickle(
444
nn.LazyConvTranspose2d,
452
def test_lazy_conv_transpose2d_state(self):
453
self._check_lazy_conv_state(
454
lambda: nn.ConvTranspose2d(16, 32, 2),
455
lambda: nn.LazyConvTranspose2d(32, 2),
461
def test_lazy_conv_transpose3d(self):
462
self._check_lazy_conv(
464
nn.LazyConvTranspose3d,
465
torch.nn.functional.conv_transpose3d,
473
def test_lazy_conv_transpose3d_kwargs(self):
474
self._check_lazy_conv(
476
nn.LazyConvTranspose3d,
477
torch.nn.functional.conv_transpose3d,
482
output_size=(9, 8, 7),
486
def test_lazy_conv_transpose3d_pickle(self):
487
self._check_lazy_conv_pickle(
489
nn.LazyConvTranspose3d,
497
def test_lazy_conv_transpose3d_state(self):
498
self._check_lazy_conv_state(
499
lambda: nn.ConvTranspose3d(16, 32, 2),
500
lambda: nn.LazyConvTranspose3d(32, 2),
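    # The _check_lazy_norm* helpers below sweep every combination of affine and
    # track_running_stats, since those flags determine which parameters
    # (weight, bias) and buffers (running_mean, running_var) a lazy norm module
    # has to materialize.
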
    def _check_lazy_norm(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )

                if affine:
                    self.assertIsInstance(lazy_module.weight, UninitializedParameter)
                    self.assertIsInstance(lazy_module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                lazy_output = lazy_module(input)
                self.assertIsInstance(lazy_module, cls)
                self.assertNotIsInstance(lazy_module, lazy_cls)

                num_features = input_shape[1]
                module = cls(
                    num_features, affine=affine, track_running_stats=track_running_stats
                )
                expected_output = module(input)

                self.assertEqual(lazy_output, expected_output)
                if module.weight is not None:
                    self.assertEqual(lazy_module.weight.shape, module.weight.shape)
                    self.assertEqual(lazy_module.weight, module.weight)
                if module.bias is not None:
                    self.assertEqual(lazy_module.bias.shape, module.bias.shape)
                    self.assertEqual(lazy_module.bias, module.bias)
                if module.running_mean is not None:
                    self.assertEqual(
                        lazy_module.running_mean.shape, module.running_mean.shape
                    )
                    self.assertEqual(lazy_module.running_mean, module.running_mean)
                if module.running_var is not None:
                    self.assertEqual(
                        lazy_module.running_var.shape, module.running_var.shape
                    )
                    self.assertEqual(lazy_module.running_var, module.running_var)
                if module.num_batches_tracked is not None:
                    self.assertEqual(
                        lazy_module.num_batches_tracked.shape,
                        module.num_batches_tracked.shape,
                    )
                    self.assertEqual(
                        lazy_module.num_batches_tracked, module.num_batches_tracked
                    )

    def _check_lazy_norm_pickle(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                module = pickle.loads(pickle.dumps(module))

                self.assertIsInstance(module, lazy_cls)
                if affine:
                    self.assertIsInstance(module.weight, UninitializedParameter)
                    self.assertIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                module(input)  # fully materialized
                module = pickle.loads(pickle.dumps(module))

                self.assertNotIsInstance(module, lazy_cls)
                self.assertIsInstance(module, cls)
                if affine:
                    self.assertNotIsInstance(module.weight, UninitializedParameter)
                    self.assertNotIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertNotIsInstance(module.running_var, UninitializedBuffer)

    def _check_lazy_batchnorm_state(self, cls, lazy_cls):
        module = cls(10)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized, but the module won't become a full
        # BatchNorm one until the first forward pass; this is a limitation of
        # the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, (10,))
        self.assertEqual(lazy_module.bias.shape, (10,))
        self.assertEqual(lazy_module.running_mean.shape, (10,))
        self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10)
        lazy_module = lazy_cls()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_instancenorm_state(self, cls, lazy_cls):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = cls(10, affine=affine, track_running_stats=track_running_stats)
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                lazy_module.load_state_dict(module.state_dict())
                # Parameters have been initialized, but the module won't become
                # a full InstanceNorm one until the first forward pass; this is
                # a limitation of the state_dict loading logic
                self.assertFalse(lazy_module.has_uninitialized_params())
                if affine:
                    self.assertEqual(lazy_module.weight.shape, (10,))
                    self.assertEqual(lazy_module.bias.shape, (10,))
                if track_running_stats:
                    self.assertEqual(lazy_module.running_mean.shape, (10,))
                    self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10, affine=True, track_running_stats=True)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_norm_with_dict_input(self, cls, lazy_cls, input_shape):
        input = {"input": torch.ones(*input_shape)}

        lazy_module = lazy_cls()
        lazy_output = lazy_module(**input)

        num_features = input_shape[1]
        module = cls(num_features)
        expected_output = module(**input)

        self.assertEqual(lazy_output, expected_output)

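    # Note on _check_lazy_norm_with_dict_input above: invoking the module as
    # module(**input) verifies that lazy initialization also fires when the
    # input arrives as a keyword argument rather than positionally.
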
    def test_lazy_batchnorm1d(self):
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)

    def test_lazy_batchnorm2d(self):
        self._check_lazy_norm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)

    def test_lazy_batchnorm3d(self):
        self._check_lazy_norm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_batchnorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_batchnorm3d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)

    def test_lazy_instancenorm1d(self):
        self._check_lazy_norm(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6))

    def test_lazy_instancenorm1d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6)
        )

    def test_lazy_instancenorm1d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)

    def test_lazy_instancenorm2d(self):
        self._check_lazy_norm(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7))

    def test_lazy_instancenorm2d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7)
        )

    def test_lazy_instancenorm2d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)

    def test_lazy_instancenorm3d(self):
        self._check_lazy_norm(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)

    def test_lazy_batchnorm_with_dict_input(self):
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

    @suppress_warnings
    def test_materialize_dtype(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.get_default_dtype())
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.half()
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.float16)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 not available"
    )
    @suppress_warnings
    def test_materialize_device(self):
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == "cpu")
        if TEST_CUDA:
            device = "cuda"
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.to(device)
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == device)

    @suppress_warnings
    def test_chained_initialization(self):
        class MyNetwork(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_1 = torch.nn.LazyLinear(15)
                self.linear_2 = torch.nn.LazyLinear(10)

            def forward(self, x):
                y = self.linear_1(x)
                return self.linear_2(y)

        net = MyNetwork()
        net(torch.ones(5, 10))
        # linear_2's in_features (15) is inferred from linear_1's output during
        # the same forward pass, so one dry run materializes the whole chain
        self.assertTrue(net.linear_1.weight.shape == (15, 10))
        self.assertTrue(net.linear_1.bias.shape == (15,))
        self.assertTrue(net.linear_2.weight.shape == (10, 15))
        self.assertTrue(net.linear_2.bias.shape == (10,))

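    # test_optimizer_pass below constructs each optimizer while the parameter
    # is still an UninitializedParameter, then checks that materialize() leaves
    # the optimizer's param_groups pointing at the now-initialized Parameter,
    # so step() can run without rebuilding the optimizer.
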
    @suppress_warnings
    def test_optimizer_pass(self):
        # Note: this optimizer list is reconstructed from the surrounding
        # logic; SGD, Adagrad, and LBFGS must be present since the loop and
        # run_step below handle them specially
        optimizers = [
            torch.optim.Adadelta,
            torch.optim.Adagrad,
            torch.optim.Adam,
            torch.optim.AdamW,
            torch.optim.Adamax,
            torch.optim.ASGD,
            torch.optim.SGD,
            torch.optim.Rprop,
            torch.optim.RMSprop,
            torch.optim.LBFGS,
            torch.optim.NAdam,
            torch.optim.RAdam,
        ]

        def run_step(module, optim):
            self.assertIsInstance(
                optim.param_groups[0]["params"][0], UninitializedParameter
            )
            module.test_param.materialize(10)
            self.assertIsInstance(optim.param_groups[0]["params"][0], Parameter)
            self.assertNotIsInstance(
                optim.param_groups[0]["params"][0], UninitializedParameter
            )
            for p in module.parameters():
                p.grad = torch.rand_like(p)
            if isinstance(optim, torch.optim.LBFGS):
                optim.step(lambda: 1.0)
            else:
                optim.step()

        for optim_cls in optimizers:
            module = LazyModule()
            module.register_parameter("test_param", UninitializedParameter())
            if optim_cls is torch.optim.SGD:
                optim = optim_cls(module.parameters(), lr=0.0)
            elif optim_cls is torch.optim.Adagrad:
                with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
                    optim = optim_cls(module.parameters())
                continue
            else:
                optim = optim_cls(module.parameters())
            run_step(module, optim)

    @suppress_warnings
    def test_weight_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.weight_norm(m)

    @suppress_warnings
    def test_spectral_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.spectral_norm(m)

    @suppress_warnings
    def test_invalid_functions(self):
        param = torch.nn.parameter.UninitializedParameter()
        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.empty_like(param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.add(param, param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            param + param


if __name__ == "__main__":
    run_tests()