# Owner(s): ["oncall: jit"]

import os
import sys
from collections import OrderedDict
from typing import Any, List, Tuple

import torch
import torch.nn as nn
from torch.testing._internal.jit_utils import JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestModuleContainers(JitTestCase):
    def test_sequential_intermediary_types(self):
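        """
        Test that a scripted Sequential handles intermediate values whose types
        differ between submodules (here, a Tensor flowing into a module that
        returns a dict).
        """
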
        class A(torch.nn.Module):
            def forward(self, x):
                return x + 3

        class B(torch.nn.Module):
            def forward(self, x):
                return {"1": x}

        class C(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Sequential(A(), B())

            def forward(self, x):
                return self.foo(x)

        self.checkModule(C(), (torch.tensor(1),))

    def test_moduledict(self):
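        """
        Test iteration over an nn.ModuleDict and its keys(), values(), and
        items() views under scripting, with and without enumerate and zip.
        """
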
        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class Inner2(torch.nn.Module):
            def forward(self, x):
                return x * 2

        class Inner3(torch.nn.Module):
            def forward(self, x):
                return (x - 4) * 3

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                modules = OrderedDict(
                    [
                        ("one", Inner()),
                        ("two", Inner2()),
                        ("three", Inner3()),
                    ]
                )
                self.moduledict = nn.ModuleDict(modules)

            def forward(self, x, skip_name):
                # type: (Tensor, str) -> Tuple[Tensor, List[str]]
                names = torch.jit.annotate(List[str], [])
                values = []
                for name in self.moduledict:
                    names.append(name)

                for name, mod in self.moduledict.items():
                    if name != skip_name:
                        names.append(name)
                        x = mod(x)
                        values.append(x)

                for mod in self.moduledict.values():
                    x = mod(x)
                    values.append(x)

                for key in self.moduledict.keys():
                    names.append(key)

                return x, names

        class M2(M):
            def forward(self, x, skip_name):
                # type: (Tensor, str) -> Tuple[Tensor, Tensor, List[str], int]
                names = torch.jit.annotate(List[str], [])
                values = []
                x2 = x
                iter = 0
                for name in self.moduledict:
                    names.append(name)

                for i, (name, mod) in enumerate(self.moduledict.items()):
                    iter += i
                    if name != skip_name:
                        names.append(name)
                        x = mod(x)
                        values.append(x)

                for i, mod in enumerate(self.moduledict.values()):
                    iter += i
                    x = mod(x)
                    values.append(x)

                for i, key in enumerate(self.moduledict.keys()):
                    iter += i
                    names.append(key)

                # note: both zip targets bind to the same name `mod`, and `i`
                # here is the value left over from the previous loop
                for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
                    iter += i
                    x2 = mod(mod(x2))

                return x, x2, names, iter

        for name in ["", "one", "two", "three"]:
            inp = torch.tensor(1)
            self.checkModule(M(), (inp, name))
            self.checkModule(M2(), (inp, name))

    def test_custom_container_forward(self):
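        """
        Test that subclasses of Sequential, ModuleList, and ModuleDict that
        define their own forward methods can be scripted.
        """
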
        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class CustomSequential(nn.Sequential):
            def __init__(self) -> None:
                super().__init__(nn.ReLU(), Inner())

            def forward(self, x):
                x = x + 3
                for mod in self:
                    x = mod(x)
                return x - 5

        self.checkModule(CustomSequential(), (torch.tensor(0.5),))

        class CustomModuleList(nn.ModuleList):
            def __init__(self) -> None:
                super().__init__([nn.ReLU(), Inner()])

            def forward(self, x):
                x = x + 3
                for mod in self:
                    x = mod(x)
                return x - 5

        self.checkModule(CustomModuleList(), (torch.tensor(0.5),))

        class CustomModuleDict(nn.ModuleDict):
            def __init__(self) -> None:
                super().__init__(
                    OrderedDict(
                        [
                            ("one", Inner()),
                            ("two", nn.ReLU()),
                            ("three", Inner()),
                        ]
                    )
                )

            def forward(self, x):
                x = x + 3
                names = torch.jit.annotate(List[str], [])
                for name, mod in self.items():
                    x = mod(x)
                    names.append(name)
                return names, x - 5

        self.checkModule(CustomModuleDict(), (torch.tensor(0.5),))

    def test_script_module_list_sequential(self):
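        """
        Test that a Sequential stored on a ScriptModule can be iterated in a
        script method and survives an export/import round trip.
        """
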
        class M(torch.jit.ScriptModule):
            def __init__(self, mod_list):
                super().__init__()
                self.mods = mod_list

            @torch.jit.script_method
            def forward(self, v):
                for m in self.mods:
                    v = m(v)
                return v

        with torch.jit.optimized_execution(False):
            m = M(nn.Sequential(nn.ReLU()))
            self.assertExportImportModule(m, (torch.randn(2, 2),))

    def test_script_modulelist_index(self):
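        """
        Test static indexing into a scripted ModuleList, including negative
        indices, and the compile-time errors raised for out-of-range or
        non-literal indices.
        """
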
        class Sub(torch.nn.Module):
            def __init__(self, i):
                super().__init__()
                self.i = i

            def forward(self, thing):
                return thing - self.i

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods = nn.ModuleList([Sub(i) for i in range(10)])

            def forward(self, v):
                v = self.mods[4].forward(v)
                v = self.mods[-1].forward(v)
                v = self.mods[-9].forward(v)
                return v

        x = torch.tensor(1)
        self.checkModule(M(), (x,))

        class MForward(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods = nn.ModuleList([Sub(i) for i in range(10)])

            def forward(self, v):
                v = self.mods[4](v)
                v = self.mods[-1](v)
                v = self.mods[-9](v)
                return v

        self.checkModule(MForward(), (torch.tensor(1),))

        class M2(M):
            def forward(self, v):
                return self.mods[-11].forward(v)

        with self.assertRaisesRegexWithHighlight(
            Exception, "Index -11 out of range", "self.mods[-11]"
        ):
            torch.jit.script(M2())

        class M3(M):
            def forward(self, v):
                i = 3
                return self.mods[i].forward(v)

        with self.assertRaisesRegexWithHighlight(
            Exception, "Enumeration is supported", "self.mods[i]"
        ):
            torch.jit.script(M3())

        class M4(M):
            def forward(self, v):
                i = 3
                return self.mods[i].forward(v)

        with self.assertRaisesRegex(Exception, "will fail because i is not a literal"):
            torch.jit.script(M4())

    def test_module_interface_special_methods(self):
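        """
        Test that __getitem__, __len__, __iter__, and __contains__ (and the
        keys/items/values views for ModuleDict) work under scripting for
        custom container subclasses.
        """
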
        class CustomModuleInterface(torch.nn.Module):
            pass

        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleList.__init__(self, modules)

        class CustomSequential(CustomModuleInterface, torch.nn.Sequential):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.Sequential.__init__(self, modules)

        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleDict.__init__(self, modules)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.modulelist = CustomModuleList([self.submod])
                self.sequential = CustomSequential(self.submod)
                self.moduledict = CustomModuleDict({"submod": self.submod})

            def forward(self, inputs):
                assert (
                    self.modulelist[0] is self.submod
                ), "__getitem__ failing for ModuleList"
                assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
                for module in self.modulelist:
                    assert module is self.submod, "__iter__ failing for ModuleList"

                assert (
                    self.sequential[0] is self.submod
                ), "__getitem__ failing for Sequential"
                assert len(self.sequential) == 1, "__len__ failing for Sequential"
                for module in self.sequential:
                    assert module is self.submod, "__iter__ failing for Sequential"

                assert (
                    self.moduledict["submod"] is self.submod
                ), "__getitem__ failing for ModuleDict"
                assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"

                # note: unable to index moduledict with a string variable currently
                i = 0
                for key in self.moduledict:
                    i += 1
                assert i == len(self.moduledict), "iteration failing for ModuleDict"

                assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"

                for key in self.moduledict.keys():
                    assert key == "submod", "keys() fails for ModuleDict"

                for item in self.moduledict.items():
                    assert item[0] == "submod", "items() fails for ModuleDict"
                    assert item[1] is self.submod, "items() fails for ModuleDict"

                for value in self.moduledict.values():
                    assert value is self.submod, "values() fails for ModuleDict"

                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])

    def test_special_method_with_override(self):
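        """
        Test that a Python-level __len__ override on a ModuleList subclass
        takes precedence over the __len__ generated by the jit compiler.
        """
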
        class CustomModuleInterface(torch.nn.Module):
            pass

        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleList.__init__(self, modules)

            def __len__(self):
                # this is arbitrary, just to check that the overridden py __len__ from
                # CustomModuleList takes precedence over the automatically generated
                # __len__ added by the jit compiler
                return 2

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.modulelist = CustomModuleList([self.submod])

            def forward(self, inputs):
                assert len(self.modulelist) == 2, "__len__ failing for ModuleList"
                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])
        mm = torch.jit.script(m)

    def test_moduledict_getitem(self):
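        """
        Test that indexing a ModuleDict with a string literal returns the
        original submodule under scripting.
        """
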
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.jit.script(torch.nn.ReLU())
                self.tanh = torch.jit.script(torch.nn.Tanh())
                self.moduledict = torch.nn.ModuleDict(
                    {"relu": self.relu, "tanh": self.tanh}
                )

            def forward(self, input):
                assert self.moduledict["relu"] is self.relu
                assert self.moduledict["tanh"] is self.tanh
                return input

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])

    def test_moduledict_keyerror(self):
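        """
        Test the compile-time errors raised when a ModuleDict is indexed with
        a missing key or with a non-literal string.
        """
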
        class BadModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})

            def forward(self, input):
                assert self.moduledict["blah"] == "blah", "this is a keyerror"

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Key Error, blah", 'self.moduledict["blah"'
        ):
            b = BadModule()
            torch.jit.script(b)

        class AnotherBadModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})

            def forward(self, input):
                idx = "blah"
                assert self.moduledict[idx] == "blah", "this is a string literal error"

        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            "Unable to extract string literal index. "
            "ModuleDict indexing is only supported with string literals. "
            "For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail "
            "because i is not a literal.",
            "self.moduledict[idx]",
        ):
            b = AnotherBadModule()
            torch.jit.script(b)

    def test_normal_list_attribute_with_modules_error(self):
        """
        Test that an attempt to script a module with a regular list attribute
        containing other modules fails with a relevant error message.
        """

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [torch.nn.ReLU(), torch.nn.ReLU()]

            def forward(self):
                return len(self.a)

        error_msg = "Could not infer type of list element: Cannot infer concrete type of torch.nn.Module"
        with self.assertRaisesRegexWithHighlight(RuntimeError, error_msg, "self.a"):
            torch.jit.script(Mod())

    def test_empty_dict_override_contains(self):
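        """
        Test that __contains__ on an empty ModuleDict subclass correctly
        reports that a key is absent under scripting.
        """
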
        class CustomModuleInterface(torch.nn.Module):
            pass

        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleDict.__init__(self, modules)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.moduledict = CustomModuleDict()

            def forward(self, inputs):
                assert (
                    "submod" not in self.moduledict
                ), "__contains__ fails for ModuleDict"
                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])

    def test_typed_module_dict(self):
        """
        Test that a type annotation can be provided for a ModuleDict that allows
        non-static indexing.
        """

        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                pass

        class ImplementsInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                if isinstance(inp, torch.Tensor):
                    return torch.max(inp, dim=0)

                return inp

        class DoesNotImplementInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return torch.max(inp, dim=0)

        # Test annotation of submodule.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.d = torch.nn.ModuleDict({"module": ImplementsInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                value: ModuleInterface = self.d[key]
                return value.forward(x)

        m = Mod()
        self.checkModule(m, (torch.randn(2, 2), "module"))

        # Test annotation of self.
        class ModDict(torch.nn.ModuleDict):
            def __init__(self) -> None:
                super().__init__({"module": ImplementsInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                submodule: ModuleInterface = self[key]
                return submodule.forward(x)

        m = ModDict()
        self.checkModule(m, (torch.randn(2, 2), "module"))

        # Test error message thrown when annotated attribute does not comply with the
        # annotation.
        class ModWithWrongAnnotation(torch.nn.ModuleDict):
            def __init__(self) -> None:
                super().__init__()
                self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                submodule: ModuleInterface = self.d[key]
                return submodule.forward(x)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"
        ):
            torch.jit.script(ModWithWrongAnnotation())

    def test_typed_module_list(self):
        """
        Test that a type annotation can be provided for a ModuleList that allows
        non-static indexing.
        """

        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                pass

        class ImplementsInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                if isinstance(inp, torch.Tensor):
                    return torch.max(inp, dim=0)

                return inp

        class DoesNotImplementInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return torch.max(inp, dim=0)

        # Test annotation of submodule.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.ModuleList([ImplementsInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                value: ModuleInterface = self.l[idx]
                return value.forward(x)

        m = Mod()
        self.checkModule(m, (torch.randn(2, 2), 0))

        # Test annotation of self.
        class ModList(torch.nn.ModuleList):
            def __init__(self) -> None:
                super().__init__([ImplementsInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                submodule: ModuleInterface = self[idx]
                return submodule.forward(x)

        m = ModList()
        self.checkModule(m, (torch.randn(2, 2), 0))

        # Test error message thrown when annotated attribute does not comply with the
        # annotation.
        class ModWithWrongAnnotation(torch.nn.ModuleList):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.ModuleList([DoesNotImplementInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                submodule: ModuleInterface = self.l[idx]
                return submodule.forward(x)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"
        ):
            torch.jit.script(ModWithWrongAnnotation())

    def test_module_properties(self):
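        """
        Test scripting of modules with properties and property setters, and
        that properties listed in __jit_unused_properties__ or marked with
        @torch.jit.unused are left out of the scripted module.
        """
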
        class ModuleWithProperties(torch.nn.Module):
            __jit_unused_properties__ = ["ignored_attr"]

            def __init__(self, a: int):
                super().__init__()
                self.a = a

            def forward(self, a: int, b: int):
                self.attr = a + b
                return self.attr

            @property
            def attr(self):
                return self.a

            @property
            def ignored_attr(self):
                return sum([self.a])

            @torch.jit.unused
            @property
            def ignored_attr_2(self):
                return sum([self.a])

            @ignored_attr_2.setter
            def ignored_attr_2(self, value):
                self.a = sum([self.a])

            @attr.setter
            def attr(self, a: int):
                if a > 0:
                    self.a = a
                else:
                    self.a = 0

        class ModuleWithNoSetter(torch.nn.Module):
            def __init__(self, a: int):
                super().__init__()
                self.a = a

            def forward(self, a: int, b: int):
                self.attr + a + b

            @property
            def attr(self):
                return self.a + 1

        self.checkModule(
            ModuleWithProperties(5),
            (
                5,
                6,
            ),
        )
        self.checkModule(
            ModuleWithProperties(5),
            (
                -5,
                -6,
            ),
        )
        self.checkModule(
            ModuleWithNoSetter(5),
            (
                5,
                6,
            ),
        )
        self.checkModule(
            ModuleWithNoSetter(5),
            (
                -5,
                -6,
            ),
        )

        mod = ModuleWithProperties(3)
        scripted_mod = torch.jit.script(mod)

        with self.assertRaisesRegex(AttributeError, "has no attribute"):
            scripted_mod.ignored_attr

    def test_module_inplace_construct(self):
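        """
        Test that a scripted module can be rebuilt in place from another
        scripted module's C++ object via _reconstruct, and that its ignored
        methods remain callable afterwards.
        """
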
        class M(nn.Module):
            def __init__(self, start: int):
                super().__init__()
                self.linear = nn.Linear(3, 3)
                self.attribute = start
                self.parameter = nn.Parameter(torch.tensor(3, dtype=torch.float))

            def method(self) -> int:
                return self.attribute

            @torch.jit.unused
            def unused_method(self):
                return self.attribute + self.attribute

            def forward(self, x):
                return self.linear(self.linear(x))

        class N(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(4, 4)

            @torch.jit.ignore
            def ignored_method(self, x):
                return x

            def forward(self, x):
                return self.linear(x)

        m = torch.jit.script(M(3))
        n = torch.jit.script(N())

        n._reconstruct(m._c)

        inp = torch.rand((3))

        # Check that both modules produce the same output.
        with torch.no_grad():
            m_out = m(inp)
            n_out = n(inp)
            self.assertEqual(m_out, n_out)

        # Check that ignored method is still intact.
        self.assertEqual(inp, n.ignored_method(inp))

    def test_parameterlist_script_getitem(self):
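        """
        Test __getitem__ on ModuleList and ParameterList attributes under
        scripting.
        """
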
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
                self.parameter_list = nn.ParameterList(
                    [nn.Parameter(torch.zeros(1)) for _ in range(10)]
                )

            def forward(self, x):
                self.module_list[0]
                self.parameter_list[0]
                return x

        self.checkModule(MyModule(), (torch.zeros(1),))

    def test_parameterlist_script_iter(self):
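        """
        Test iteration (via enumerate) over a ParameterList under scripting.
        """
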
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
                self.parameter_list = nn.ParameterList(
                    [nn.Parameter(torch.zeros(1)) for _ in range(10)]
                )

            def forward(self, x):
                r = x
                for i, p in enumerate(self.parameter_list):
                    r = r + p + i
                return r

        self.checkModule(MyModule(), (torch.zeros(1),))

    def test_parameterdict_script_getitem(self):
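        """
        Test __getitem__ with string-literal keys on a ParameterDict under
        scripting.
        """
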
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.parameter_dict = nn.ParameterDict(
                    {k: nn.Parameter(torch.zeros(1)) for k in ["a", "b", "c"]}
                )

            def forward(self, x):
                return (
                    self.parameter_dict["a"] * x
                    + self.parameter_dict["b"] * self.parameter_dict["c"]
                )

        self.checkModule(MyModule(), (torch.ones(1),))