# Owner(s): ["module: fx.passes"]

from dataclasses import dataclass
import operator
import logging
import sys

import torch
from torch.fx._symbolic_trace import symbolic_trace

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from torch.testing._internal.jit_utils import JitTestCase

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

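# TestModule traces to a small graph mixing call_function (add), call_module
# (linear, linear2), call_method (relu), and get_attr (param) nodes. The
# fuser tests below look nodes up by the names symbolic_trace assigns, which
# mirror the local variable names used here.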
class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.param = torch.nn.Parameter(torch.rand(4, 4))

    def forward(self, a, b, c):
        add = a + b

        linear_1 = self.linear(add)

        add_1 = add + c
        add_2 = add_1 + self.param
        add_3 = add_1 + linear_1
        add_4 = add_2 + add_3

        linear_2 = self.linear2(add_4)

        add_5 = linear_2 + add_4
        add_6 = add_5 + a
        relu = add_6.relu()

        return add_4, add_6, relu

class TestDeepModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, a, b, c):
        o = a + b
        o = o + 1.0

        # Loop deeper than Python's max recursion depth, to catch passes
        # that traverse the graph with recursive DFS.
        for _ in range(sys.getrecursionlimit() + 1):
            o = o - c

        return o

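# Each forwardN below is a standalone graph exercising one partitioner
# scenario: horizontal fusion, branching, cyclic-dependency avoidance,
# multi-output ops, and bookended non-compute ops. The expected partitions
# are listed in the parametrize table of test_partitioner.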
class TestPartitionFunctions:
    @staticmethod
    def forward1(a, b, c):
        add = a + b
        add_1 = add + b
        add_2 = add_1 + c
        relu_1 = add_2.relu()
        add_3 = add_1 + add_2
        add_4 = add_1 + relu_1 + add_3
        relu_2 = add_4.relu()
        add_5 = relu_2 + add_4
        add_6 = add_5 + add_4
        return add_4, add_6

    @staticmethod
    def forward2(a, b, _):
        add = a + b
        add_1 = add + b
        relu_1 = add_1.relu()  # blocked by this
        add_3 = add_1 + relu_1
        add_4 = add_1 + add_3
        return add_4, add_1

    @staticmethod
    def forward3(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return add, add_1, add_2

    @staticmethod
    def forward4(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return torch.where(add > 0, add_1, add_2)

    @staticmethod
    def forward5(a, b, c):
        # add should be fused into the right branch, as the left branch is not supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        add_1 = add + 2
        return relu, add_1

    @staticmethod
    def forward6(a, b, c):
        # add should get its own partition, as neither branch is supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        relu_1 = add.relu()
        return relu, relu_1

    @staticmethod
    def forward7(a, b, c):
        # both branches are supported, so all adds should be fused together
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch is larger
        add_2 = add + 1
        add_3 = add_2 + 1
        return add_3, add_1

    @staticmethod
    def forward8(a, b, c):
        # both branches land in the same partition, so add should join it too
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch
        add_2 = add + 1
        # left and right branches merge
        add_3 = add_2 + add_1

        return add_3

    @staticmethod
    def forward9(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3
        add_3 = add + 1
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward10(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3: depends on branch 2
        add_3 = add + add_2
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward11(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add.relu()
        # branch 2 depends on branch 1
        add_2 = add + add_1
        # branch 3
        add_3 = add.relu()
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward12(a, b, c):
        b0 = a + 1.0
        c0 = a + 1.5
        x0 = b0.relu()
        x1 = c0.relu()
        b1 = b0 + x1
        c1 = c0 + 1.2
        # c2 depends on x0 and b0. When {c0, c1, c2} are merged, that
        # dependency must be transferred to the fusion group and reflected in
        # the decision not to also fuse b0 and b1, which would form a cyclic
        # dependency in the new graph.
        c2 = x0 + c0
        return b1, c2

    @staticmethod
    def forward13(a, b, c):
        a0, a1, a2, a3 = a.split(1, 0)
        b1 = a0 + b
        c1 = a1 + c
        return b1 + c1

    @staticmethod
    def forward14(a, b, c):
        a0, a1 = torch.ops.aten.std_mean(a)
        out = a0 + 1.0
        return out

    @staticmethod
    def forward15(a, b, c):
        a0 = torch.ops.aten.view(a, [2, 2])
        a1 = torch.ops.aten.permute(a0, [1, 0])
        a2 = a1 + 1.0
        a3 = torch.ops.aten.permute(a2, [1, 0])
        a4 = a3 + 1.0
        a5 = torch.ops.aten.permute(a4, [1, 0])
        return torch.ops.aten.permute(a5, [1, 0])

    @staticmethod
    def forward16(a, b, c):
        a0 = a - 1.0
        a1 = torch.ops.aten.view(a0, [2, 2])
        a2 = torch.ops.aten.permute(a1, [1, 0])
        a3 = a2 + 1.0
        a4 = torch.ops.aten.permute(a3, [1, 0])
        a5 = a4 + 1.0
        a6 = torch.ops.aten.permute(a5, [1, 0])
        a7 = torch.ops.aten.permute(a6, [1, 0])
        return a7 - 1.0

    @staticmethod
    def forward17(a, b, c, d, e, f):
        a0 = a + b
        a1 = c + d
        a2 = e + f
        return a0, a1, a2

    @staticmethod
    def forward18(a, b, c):
        a0, a1 = torch.ops.aten.var_mean(a)
        return a0

# A mock OperatorSupport class supporting only a small allowlist of ops:
# operator.add, operator.getitem, and a few aten view/permute/std_mean ops.
class MockOperatorSupport(OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return (node.op == "call_function" and
                node.target in {operator.add, operator.getitem,
                                torch.ops.aten.view,
                                torch.ops.aten.permute,
                                torch.ops.aten.std_mean})

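# CapabilityBasedPartitioner queries is_node_supported() for every node when
# proposing partitions. operator.getitem is included so that multi-output ops
# such as aten.std_mean stay in the same partition as the getitem nodes that
# unpack their results. The typical driver, used throughout this file:
#
#     partitioner = CapabilityBasedPartitioner(gm, MockOperatorSupport())
#     partitions = partitioner.propose_partitions()
#     fused_gm = partitioner.fuse_partitions(partitions)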
@instantiate_parametrized_tests
class TestFXGraphPasses(JitTestCase):

    @parametrize("fn, expected_partition, bookend_non_compute_pass", [
        (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False),

        # 1 horizontal fusion with common producer
        (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False),

        # 2 two-branch cases
        (TestPartitionFunctions.forward5, [["add_1", "add"]], False),
        (TestPartitionFunctions.forward6, [["add"]], False),
        (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False),
        (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False),

        # 3 three-branch cases
        (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False),
        (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False),
        (TestPartitionFunctions.forward11, [['add_1'], ['add']], False),

        # 4 not necessarily the only valid partitioning; just verifies that there is no cyclic dependency after partitioning
        (TestPartitionFunctions.forward12, [["add_2", "add_3", "add_4"], ["add", "add_1"]], False),

        # 5 getitem special case
        (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False),

        # 6 bookend non-compute pass
        (TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        # should be an empty partition list, not a partition with no nodes
        (TestPartitionFunctions.forward18, [], False),
    ])
    def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass):
        traced = symbolic_trace(fn)

        non_compute_ops = []
        if bookend_non_compute_pass:
            non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"]

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True,
                                                 non_compute_ops=non_compute_ops)
        partitions = partitioner.propose_partitions()
        if bookend_non_compute_pass:
            partitioner.remove_bookend_non_compute_ops(partitions)

        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c)
        result = fused_graph(a, b, c)
        torch.testing.assert_close(expected, result)

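    # forward17's three adds are mutually independent and feed separate
    # outputs; the partitioner should still group them into a single
    # partition.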
    @parametrize("fn, expected_partition", [
        (TestPartitionFunctions.forward17, [['add', 'add_1', 'add_2']]),
    ])
    def test_partitioner_independent_output(self, fn, expected_partition):
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()
        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c, d, e, f = torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c, d, e, f)
        result = fused_graph(a, b, c, d, e, f)
        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_5', 'add_6']],
        [['add', 'add_1', 'add_2']],  # vertical fusion
        [['add_2', 'add_3']],         # horizontal fusion
        [['add_3', 'add_4']],
        [['add_6', 'add_5']],     # arbitrary node order
        [['add_4', 'add_1', 'add_3', 'add_2']],           # arbitrary node order
        [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']],  # arbitrary partition order
        [['add_5', 'linear2']],   # includes call_function + call_module nodes
        [['add_6', 'relu']],   # includes call_function + call_method nodes
        [['param', 'add_2']],   # includes get_attr + call_function nodes
        [['param', 'add_1', 'linear']],   # includes get_attr + call_function + call_module nodes
        [["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]],  # full graph
    ])
    def test_fuser_util(self, partition):
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name : node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        fused_graph = fuse_by_partitions(gm, partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = m(a, b, c)
        result = fused_graph(a, b, c)

        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_1', 'add_5', 'add_6']],  # add_1 exists in multiple partitions
        [['add', 'add_1', 'add_3']],    # invalid partition: circular dependency
        [['add_4', 'add_5']],    # invalid partition: circular dependency
        [['relu', 'add_5']],    # invalid partition: circular dependency
    ])
    def test_fuser_util_xfail(self, partition):
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name : node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        with self.assertRaises(Exception):
            fuse_by_partitions(gm, partitions)

    def test_fuser_pass_deep_model(self):
        m = TestDeepModule()
        traced = symbolic_trace(m)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()
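        # No assertion needed: the test passes if proposing partitions on a
        # graph deeper than Python's recursion limit completes without
        # overflowing the stack.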

@dataclass
class TestCase:
    match_output: bool
    match_placeholder: bool
    num_matches: int
    remove_overlapping_matches: bool = True

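# Each pattern class below bundles a target graph (`forward`), a pattern
# graph (`pattern`), and the expected match counts for every combination of
# matcher flags. match_output and match_placeholder control whether the
# pattern's output and placeholder nodes are treated as part of the pattern
# (and so must line up with the target graph's own outputs and placeholders)
# or are ignored during matching; remove_overlapping_matches drops matches
# that share nodes with another match.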
class SingleNodePattern:
    @staticmethod
    def forward(x):
        val = torch.neg(x)
        return torch.add(val, val)

    @staticmethod
    def pattern(a):
        return torch.neg(a)

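    # neg feeds an add rather than the graph output, so requiring
    # match_output yields no matches; matching placeholders still succeeds
    # because the pattern input `a` maps directly onto the placeholder `x`.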
    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class SimplePattern:
    @staticmethod
    def forward(x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w2, w1]).sum()
        m3 = torch.cat([m1, m2]).sum()
        return x + torch.max(m1) + torch.max(m2) + m3

    @staticmethod
    def pattern(a, b):
        return torch.cat([a, b]).sum()

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3),
        TestCase(True, False, 0),
        TestCase(False, True, 2),
        TestCase(True, True, 0)
    ]

class SimpleFullGraphMatching:
    @staticmethod
    def forward(x):
        a = torch.neg(x)
        return torch.add(a, a)

    @staticmethod
    def pattern(x):
        a = torch.neg(x)
        return torch.add(a, a)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 1)
    ]

class DiamondShapePatternTestCase:
    @staticmethod
    def forward(x):
        a = torch.neg(x)

        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right

        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right
        return out

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class NonFullyContainedMatches:
    @staticmethod
    def forward(x, w1, w2, b1, b2):
        # fully contained matched subgraph
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        t0 = torch.addmm(b1, m1, m2.t())
        t0_sum = torch.sum(t0)   # this use of t0 does not leak out of the match

        # leaking matched subgraph: m3 is leaked
        m3 = torch.cat([w1, w2])
        m4 = torch.cat([x, b2])
        t1 = torch.addmm(b1, m3, m4.t())
        m3_sum = torch.sum(m3)

        return t0_sum, m3_sum

    @staticmethod
    def pattern(x, w1, w2, b1, b2):
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        return torch.addmm(b1, m1, m2.t())

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),

        TestCase(True, False, 0),

        TestCase(False, True, 1),     # a leaked use of a placeholder does not count as leaking
    ]

class ChainRepeatedPattern:
    @staticmethod
    def forward(x):
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        return torch.sigmoid(x)

    @staticmethod
    def pattern(x):
        return torch.sigmoid(torch.sigmoid(x))

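    # With four chained sigmoids, the two-sigmoid pattern can anchor at three
    # offsets; removing overlaps leaves two disjoint matches.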
    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3, remove_overlapping_matches=False),
        TestCase(False, False, 2, remove_overlapping_matches=True),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class QuantizationModel:
    @staticmethod
    def forward(x):
        x += 3
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    @staticmethod
    def pattern(x):
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultipleOutputsWithDependency:
    @staticmethod
    def forward(x):
        y = x.relu()
        z = y.sigmoid()
        return z, y

    @staticmethod
    def pattern(a):
        b = a.relu()
        c = b.sigmoid()
        return b, c     # outputs have data dependency

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class MultipleOutputsWithoutDependency:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        out = y.sigmoid() + z.sum()
        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultipleOutputsMultipleOverlappingMatches:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        z1 = x.sum()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return a, b, c

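    # Two sum() users and two sigmoid() users of the same relu yield
    # 2 * 2 = 4 overlapping candidate matches; deduplication keeps one.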
    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 4, remove_overlapping_matches=False),
        TestCase(False, False, 1, remove_overlapping_matches=True),
    ]

class MultipleOutputsMultipleNonOverlappingMatches:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        x = x.relu()
        z1 = x.sum()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]

class MultipleOutputsIdenticalAnchor:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return y, y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        b1 = a.sigmoid()
        return b, b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        # (False, False, 2),  # FIXME: currently still matches to 2, should fix to 1
        TestCase(True, False, 1),
        TestCase(False, True, 0),
    ]

class MultipleOutputsHorizontalPattern:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        y1 = x.relu()
        y2 = x.sigmoid()

        return y1, y2

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        b2 = a.sigmoid()

        return b1, b2

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultiOutputWithInvalidMatches:
    @staticmethod
    def forward(x):
        res0 = torch.nn.functional.linear(x, torch.rand(3, 3))
        res1 = torch.sigmoid(res0)
        res2 = res0 * res1
        res3 = torch.sum(res2, dim=1)
        return res3

    @staticmethod
    def pattern(a, b, c):
        lin_res = torch.nn.functional.linear(a, b)
        mul_res = lin_res * c
        return lin_res, mul_res

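    # The only candidate match would map pattern input `c` onto
    # sigmoid(res0), a node that itself consumes the matched linear output,
    # so the matcher rejects the match as invalid.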
    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
    ]

class QuantizationFp8Pattern:
    @classmethod
    def setup(cls):
        cls.quantization = torch.library.Library("fp8_quantization", "DEF")  # noqa: TOR901
        cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
        cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")

    @classmethod
    def tearDown(cls):
        del cls.quantization

    @staticmethod
    def forward(self, arg0_1, arg1_1):
        qt = torch.ops.fp8_quantization
        _scale_0 = self._scale_0
        quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
        dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
        _scale_1 = self._scale_0
        quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
        dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
        add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
        _scale_2 = self._scale_0
        quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
        dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
        return dequantize_per_tensor_affine_fp8_2

    @staticmethod
    def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
        qt = torch.ops.fp8_quantization
        a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
        b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
        output = torch.ops.aten.add.Tensor(a, b)
        output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
        return output

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]

class NoAnchorFound:
    # This test case covers patterns for which no matching anchor exists in
    # the target graph. An anchor is the starting point of pattern matching,
    # usually one of the nodes feeding the pattern's output.
    @staticmethod
    def forward(x):
        x = x + 1
        return x

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        return b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):

    @parametrize("test_model", [
        SingleNodePattern,
        SimplePattern,
        SimpleFullGraphMatching,
        DiamondShapePatternTestCase,
        NonFullyContainedMatches,
        ChainRepeatedPattern,
        QuantizationModel,
        MultipleOutputsWithDependency,
        MultipleOutputsWithoutDependency,
        MultipleOutputsMultipleOverlappingMatches,
        MultipleOutputsMultipleNonOverlappingMatches,
        MultipleOutputsIdenticalAnchor,
        MultipleOutputsHorizontalPattern,
        MultiOutputWithInvalidMatches,
        QuantizationFp8Pattern,
        NoAnchorFound,
    ])
    def test_subgraph_matcher(self, test_model):

        setup = getattr(test_model, "setup", None)
        if callable(setup):
            setup()

        traced = symbolic_trace(test_model.forward)
        pattern_traced = symbolic_trace(test_model.pattern)

        for test_case in test_model.test_cases:

            matcher = SubgraphMatcher(pattern_traced.graph,
                                      match_output=test_case.match_output,
                                      match_placeholder=test_case.match_placeholder,
                                      remove_overlapping_matches=test_case.remove_overlapping_matches)
            matches = matcher.match(traced.graph)

            assert len(matches) == test_case.num_matches

            for match in matches:
                for node in pattern_traced.graph.nodes:
                    if not test_case.match_placeholder and node.op == "placeholder":
                        continue
                    if not test_case.match_output and node.op == "output":
                        continue
                    assert node in match.nodes_map

        tearDown = getattr(test_model, "tearDown", None)
        if callable(tearDown):
            tearDown()


if __name__ == "__main__":
    run_tests()
