# Owner(s): ["module: unknown"]

import unittest
from typing import Dict, Optional

import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from typing import List

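# Thin Python wrapper around the Static Runtime bindings: it converts a scripted
# module (or a bare graph) via torch._C._jit_to_static_module and forwards calls,
# async execution (runAsync) and benchmarking to the resulting static module.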
class StaticModule:
    def __init__(self, scripted):
        # this is an nn.Module
        if hasattr(scripted, "_c"):
            self.static_module = torch._C._jit_to_static_module(scripted._c)
        else:
            self.static_module = torch._C._jit_to_static_module(scripted.graph)

    def __call__(self, *args, **kwargs):
        return self.static_module(*args, **kwargs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)

    def runAsync(self, args, kwargs):
        return self.static_module.runAsync(args, kwargs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        return self.static_module.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
        )


def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    ret = output
    return ret


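# Monkeypatch F.linear with the matmul/add shim above so that modules scripted
# below decompose their linear layers into basic ops instead of aten::linear.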
torch.nn.functional.linear = linear_shim


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention


# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]

        LL = nn.Linear(int(n), int(m), bias=True)

        mean = 0.0  # std_dev = np.sqrt(variance)
        std_dev = np.sqrt(2 / (m + n))  # np.sqrt(1 / m) # np.sqrt(1 / n)
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)  # np.sqrt(2 / (m + 1))
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)

        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s


def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s

def elementwise_square_addition(input1, input2):
    return input1 * input1 + input2 * input2

def fork_wait_graph1(input1, input2):
    fut = torch.jit.fork(elementwise_square_addition, input1, input2)
    return torch.jit.wait(fut)

def fork_wait_graph2(input1, input2):
    fut = torch.jit.fork(loop_graph, input1, input2, 5)
    return torch.jit.wait(fut)

"""
   graph with multiple fork/wait operations
   :param input: torch.tensor input to forked subgraph
   :param iters: number of future/wait pairs to be created
"""
def fork_wait_graph3(input, iters: int):
    futures : List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(iters):
        futures.append(torch.jit.fork(torch.neg, input))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

"""
   graph with multi-level fork/wait operations
   :param input: torch.tensor input to forked subgraph
   :param num_forks: number of top level forks
   :param num_child_forks: number of child forks per parent fork
"""
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    futures : List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

def add_tensor(input1, input2):
    return input1 + input2

def fork_wait_graph_exception(input1, input2):
    fut = torch.jit.fork(add_tensor, input1, input2)
    return torch.jit.wait(fut)

def loop_graph(a, b, iters: int):
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c


def output_graph(a, b, c, iters: int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d: Dict[int, torch.Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d


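# Small modules whose int attributes are read (and, in SubModule2/TestModule,
# mutated) inside forward(); test_attr below uses them to exercise the
# prim::GetAttr/prim::SetAttr handling in Static Runtime.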
class SubModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 11
        self.b = 2

    def forward(self, x):
        return self.a + self.b + x


class SubModule2(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        self.b = 30
        return self.a + self.b + x


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)


class TestStaticModule(TestCase):

    """
    Test Case: To test a simple fork/wait operation in a graph;
    fork is called on a simple addition of the input tensors
    """
    def test_fork_wait_1(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a simple fork/wait operation with
    the StaticRuntime runAsync API returning a future
    """
    def test_fork_wait_1_async(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a fork/wait operation on
    a loop subgraph performing a mix of operations
    """
    def test_fork_wait_2(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a fork/wait operation on a loop
    subgraph with the StaticRuntime runAsync API returning a future
    """
    def test_fork_wait_2_async(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a graph
    having multiple fork/wait operations
    """
    def test_fork_wait_3(self):
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(input, num_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple fork/wait
    operations via the runAsync API returning a future
    """
    def test_fork_wait_3_async(self):
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((input, num_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a graph with
    multiple nested fork/wait operations
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4(self):
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module(input, num_forks, num_child_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple nested
    fork/wait operations via the runAsync API returning a future
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4_async(self):
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module.runAsync(
            (input, num_forks, num_child_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test exception handling in a fork/wait
    operation. The add.Tensor op is called on tensors with
    non-matching dims in the forked subgraph, and the
    exception raised by the subgraph is set on the future returned
    by prim::fork to the parent graph. The returned exception is
    checked for the substring expected_error_msg declared below
    """
    def test_fork_wait_exception(self):
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module(input1, input2)
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # test fails if error does not contain expected substr
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f"not contain expected substring: \"{expected_error_msg}\""
                ) from error

    """
    Test Case: To test exception handling in a fork/wait
    operation with the runAsync API. The add.Tensor op is called on
    tensors with non-matching dims in the forked subgraph,
    and the exception raised by the subgraph is set on the future returned
    by prim::fork to the parent graph. The returned exception is
    checked for the substring expected_error_msg declared below
    """
    def test_fork_wait_exception_async(self):
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module.runAsync(
                (input1, input2), {})
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # test fails if error does not contain expected substr
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f"not contain expected substring: \"{expected_error_msg}\""
                ) from error

    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticModule(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_close(a, b)

    def test_multihead_attention_layer_benchmark(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticModule(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        metrics = attention_a.benchmark_individual_ops(
            [src, src, src, src_mask], {}, 2, 2
        )

    def test_mlp(self):
        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticModule(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticModule(top_l)
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
        ref_bot = bot_l(bot_inp)
        acc_bot = bot_l_acc(bot_inp)
        torch.testing.assert_close(acc_bot, ref_bot)
        ref_top = top_l(top_inp)
        acc_top = top_l_acc(top_inp)
        torch.testing.assert_close(acc_top, ref_top)
        for _ in range(5):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)

    def test_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    def test_leaky_relu(self):
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)
        torch.testing.assert_close(o_ref, o_test)

    def test_attr(self):
        """
        TorchScript IR of TestModule() after freezing:
        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
              %x.1 : Tensor):
            %18 : int = prim::Constant[value=30]()
            %30 : int = prim::Constant[value=13]()
            %3 : int = prim::Constant[value=20]()
            %2 : int = prim::Constant[value=1]()
            %self.sub2.a : int = prim::Constant[value=12]()
            %self.a : int = prim::Constant[value=3]()
            = prim::SetAttr[name="b"](%self, %3)
            %17 : Tensor = aten::add(%x.1, %30, %2)
            %7 : Tensor = aten::add(%17, %self.a, %2)
            %b.1 : int = prim::GetAttr[name="b"](%self)
            %9 : Tensor = aten::add(%7, %b.1, %2)
            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
            = prim::SetAttr[name="b"](%sub2, %18)
            %b : int = prim::GetAttr[name="b"](%sub2)
            %22 : int = aten::add(%self.sub2.a, %b)
            %23 : Tensor = aten::add(%x.1, %22, %2)
            %12 : Tensor = aten::add(%9, %23, %2)
            return (%12)
        """
        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
        m = TestModule()

        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)

    @unittest.skip("Temporarily disabled")
    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_module(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

    @unittest.skip("Temporarily disabled")
    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_module(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_module(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_close(o_ref[i], o_test[i])

    def test_create_object(self):
        class Foo:  # noqa: B903
            def __init__(self, x: torch.Tensor) -> None:
                self.x = x

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y: torch.Tensor) -> torch.Tensor:
                foo = Foo(y)
                return y * foo.x

        mod = torch.jit.script(Mod()).eval()
        y = torch.randn((1, ))
        expected = mod(y)

        static_mod = StaticModule(torch.jit.freeze(mod))
        actual = static_mod(y)

        self.assertEqual(expected, actual)

if __name__ == "__main__":
    run_tests()