1
# Owner(s): ["module: fx.passes"]
3
from dataclasses import dataclass
9
from torch.fx._symbolic_trace import symbolic_trace
11
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
12
from torch.fx.passes.operator_support import OperatorSupport
13
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
14
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
16
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
17
from torch.testing._internal.jit_utils import JitTestCase
19
logging.basicConfig(level=logging.WARNING)
20
logger = logging.getLogger(__name__)
22
class TestModule(torch.nn.Module):
23
def __init__(self) -> None:
25
self.linear = torch.nn.Linear(4, 4)
26
self.linear2 = torch.nn.Linear(4, 4)
27
self.param = torch.nn.Parameter(torch.rand(4, 4))
29
def forward(self, a, b, c):
32
linear_1 = self.linear(add)
35
add_2 = add_1 + self.param
36
add_3 = add_1 + linear_1
39
linear_2 = self.linear2(add_4)
41
add_5 = linear_2 + add_4
45
return add_4, add_6, relu
47
class TestDeepModule(torch.nn.Module):
48
def __init__(self) -> None:
50
self.linear = torch.nn.Linear(4, 4)
52
def forward(self, a, b, c):
56
# Testing that passes avoid DFS, since Python has a maximum recursion depth.
57
for _ in range(sys.getrecursionlimit() + 1):
63
class TestPartitionFunctions:
65
def forward1(a, b, c):
71
add_4 = add_1 + relu_1 + add_3
73
add_5 = relu_2 + add_4
78
def forward2(a, b, _):
81
relu_1 = add_1.relu() # blocked by this
82
add_3 = add_1 + relu_1
87
def forward3(a, b, c):
91
return add, add_1, add_2
94
def forward4(a, b, c):
98
return torch.where(add > 0, add_1, add_2)
101
def forward5(a, b, c):
102
# add should be fused right branch, as left branch is not supported
111
def forward6(a, b, c):
112
# add should have its own partition, as neither branch is supported
121
def forward7(a, b, c):
122
# both branches are supported, all adds should be fused together
126
# right branch is larger
132
def forward8(a, b, c):
133
# both branches are in the same partition, add should join the same partition
139
# left and right branch merges
140
add_3 = add_2 + add_1
145
def forward9(a, b, c):
153
out = torch.stack([add_1, add_2, add_3])
157
def forward10(a, b, c):
163
# branch 3: depends on branch 2
165
out = torch.stack([add_1, add_2, add_3])
169
def forward11(a, b, c):
173
# branch 2 depends on branch 1
177
out = torch.stack([add_1, add_2, add_3])
181
def forward12(a, b, c):
188
# c2 has dependency on x0 & b0, when we merge {c0, c1, c2}
189
# this dependency should be updated to the fusion group and reflected
190
# on the decision to not fuse b0 & b1, which forms a cyclic dependency in
196
def forward13(a, b, c):
197
a0, a1, a2, a3 = a.split(1, 0)
203
def forward14(a, b, c):
204
a0, a1 = torch.ops.aten.std_mean(a)
209
def forward15(a, b, c):
210
a0 = torch.ops.aten.view(a, [2, 2])
211
a1 = torch.ops.aten.permute(a0, [1, 0])
213
a3 = torch.ops.aten.permute(a2, [1, 0])
215
a5 = torch.ops.aten.permute(a4, [1, 0])
216
return torch.ops.aten.permute(a5, [1, 0])
219
def forward16(a, b, c):
221
a1 = torch.ops.aten.view(a0, [2, 2])
222
a2 = torch.ops.aten.permute(a1, [1, 0])
224
a4 = torch.ops.aten.permute(a3, [1, 0])
226
a6 = torch.ops.aten.permute(a5, [1, 0])
227
a7 = torch.ops.aten.permute(a6, [1, 0])
231
def forward17(a, b, c, d, e, f):
238
def forward18(a, b, c):
239
a0, a1 = torch.ops.aten.var_mean(a)
242
# A mock OperatorSupport class, where only operator.add is supported
243
class MockOperatorSupport(OperatorSupport):
244
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
245
return (node.op == "call_function" and
246
node.target in {operator.add, operator.getitem,
248
torch.ops.aten.permute,
249
torch.ops.aten.std_mean})
251
@instantiate_parametrized_tests
252
class TestFXGraphPasses(JitTestCase):
254
@parametrize("fn, expected_partition, bookend_non_compute_pass", [
255
(TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False),
256
(TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False),
258
# 1 horizontal fusion with common producer
259
(TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False),
260
(TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False),
263
(TestPartitionFunctions.forward5, [["add_1", "add"]], False),
264
(TestPartitionFunctions.forward6, [["add"]], False),
265
(TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False),
266
(TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False),
269
(TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False),
270
(TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False),
271
(TestPartitionFunctions.forward11, [['add_1'], ['add']], False),
273
# 4 not necessarily the only partition, just to verify that there's no cyclic dependency after partition
274
(TestPartitionFunctions.forward12, [["add_2", "add_3", "add_4"], ["add", "add_1"]], False),
276
# 5 getitem special case
277
(TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False),
278
(TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False),
280
# 6 bookend non_compute pass
281
(TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True),
282
(TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
283
(TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True),
284
(TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
285
# should be empty partition, not a partition with empty nodes
286
(TestPartitionFunctions.forward18, [], False),
288
def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass):
289
traced = symbolic_trace(fn)
292
if bookend_non_compute_pass:
293
non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"]
295
supported_ops = MockOperatorSupport()
296
partitioner = CapabilityBasedPartitioner(traced,
298
allows_single_node_partition=True,
299
non_compute_ops=non_compute_ops)
300
partitions = partitioner.propose_partitions()
301
if bookend_non_compute_pass:
302
partitioner.remove_bookend_non_compute_ops(partitions)
304
partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
305
assert len(partitions_name) == len(expected_partition)
306
for i in range(len(partitions_name)):
307
assert set(partitions_name[i]) == set(expected_partition[i])
309
fused_graph = partitioner.fuse_partitions(partitions)
311
a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)
313
expected = fn(a, b, c)
314
result = fused_graph(a, b, c)
315
torch.testing.assert_close(expected, result)
317
@parametrize("fn, expected_partition", [
318
(TestPartitionFunctions.forward17, [['add', 'add_1', 'add_2']]),
320
def test_partitioner_independent_output(self, fn, expected_partition):
321
traced = symbolic_trace(fn)
323
supported_ops = MockOperatorSupport()
324
partitioner = CapabilityBasedPartitioner(traced,
326
allows_single_node_partition=True)
327
partitions = partitioner.propose_partitions()
328
partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
329
assert len(partitions_name) == len(expected_partition)
330
for i in range(len(partitions_name)):
331
assert set(partitions_name[i]) == set(expected_partition[i])
333
fused_graph = partitioner.fuse_partitions(partitions)
335
a, b, c, d, e, f = torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)
337
expected = fn(a, b, c, d, e, f)
338
result = fused_graph(a, b, c, d, e, f)
339
torch.testing.assert_close(expected, result)
341
@parametrize("partition", [
342
[['add', 'add_1'], ['add_5', 'add_6']],
343
[['add', 'add_1', 'add_2']], # vertical fusion
344
[['add_2', 'add_3']], # horizontal fusion
345
[['add_3', 'add_4']],
346
[['add_6', 'add_5']], # arbitrary node order
347
[['add_4', 'add_1', 'add_3', 'add_2']], # arbitrary node order
348
[['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']], # arbitrary partition order
349
[['add_5', 'linear2']], # includes call_function + call_module node
350
[['add_6', 'relu']], # includes call_function + call_module node
351
[['param', 'add_2']], # includes get_attr + call_module nodes
352
[['param', 'add_1', 'linear']], # includes get_attr + call_function + call_module nodes
353
[["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]], # full graph
355
def test_fuser_util(self, partition):
357
gm = symbolic_trace(m)
359
nodes_by_name = {node.name : node for node in gm.graph.nodes}
362
for node_names in partition:
363
partitions.append([nodes_by_name[name] for name in node_names])
365
fused_graph = fuse_by_partitions(gm, partitions)
367
a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)
369
expected = m(a, b, c)
370
result = fused_graph(a, b, c)
372
torch.testing.assert_close(expected, result)
374
@parametrize("partition", [
375
[['add', 'add_1'], ['add_1', 'add_5', 'add_6']], # add_1 exists in multiple partitions
376
[['add', 'add_1', 'add_3']], # invalid partition: circular dependency
377
[['add_4', 'add_5']], # invalid partition: circular dependency
378
[['relu', 'add_5']], # invalid partition: circular dependency
380
def test_fuser_util_xfail(self, partition):
382
gm = symbolic_trace(m)
384
nodes_by_name = {node.name : node for node in gm.graph.nodes}
387
for node_names in partition:
388
partitions.append([nodes_by_name[name] for name in node_names])
390
with self.assertRaises(Exception):
391
fuse_by_partitions(gm, partitions)
393
def test_fuser_pass_deep_model(self):
395
traced = symbolic_trace(m)
397
supported_ops = MockOperatorSupport()
398
partitioner = CapabilityBasedPartitioner(traced,
400
allows_single_node_partition=True)
401
partitions = partitioner.propose_partitions()
406
match_placeholder: bool
408
remove_overlapping_matches: bool = True
410
class SingleNodePattern:
414
return torch.add(val, val)
421
# match_output, match_placeholder, num_matches
422
TestCase(False, False, 1),
423
TestCase(True, False, 0),
424
TestCase(False, True, 1),
425
TestCase(True, True, 0)
429
def forward(x, w1, w2):
430
m1 = torch.cat([w1, w2]).sum()
431
m2 = torch.cat([w2, w1]).sum()
432
m3 = torch.cat([m1, m2]).sum()
433
return x + torch.max(m1) + torch.max(m2) + m3
437
return torch.cat([a, b]).sum()
440
# match_output, match_placeholder, num_matches
441
TestCase(False, False, 3),
442
TestCase(True, False, 0),
443
TestCase(False, True, 2),
444
TestCase(True, True, 0)
447
class SimpleFullGraphMatching:
451
return torch.add(a, a)
456
return torch.add(a, a)
459
# match_output, match_placeholder, num_matches
460
TestCase(False, False, 1),
461
TestCase(True, False, 1),
462
TestCase(False, True, 1),
463
TestCase(True, True, 1)
466
class DiamondShapePatternTestCase:
487
# match_output, match_placeholder, num_matches
488
TestCase(False, False, 1),
489
TestCase(True, False, 1),
490
TestCase(False, True, 0),
491
TestCase(True, True, 0)
494
class NonFullyContainedMatches:
496
def forward(x, w1, w2, b1, b2):
497
# fully contained matched subgraph
498
m1 = torch.cat([w1, w2])
499
m2 = torch.cat([x, b2])
500
t0 = torch.addmm(b1, m1, m2.t())
501
t0_sum = torch.sum(t0) # use of t0 is not leaking
503
# leaking matched subgraph, m3 is leaked
504
m3 = torch.cat([w1, w2])
505
m4 = torch.cat([x, b2])
506
t1 = torch.addmm(b1, m3, m4.t())
507
m3_sum = torch.sum(m3)
509
return t0_sum, m3_sum
512
def pattern(x, w1, w2, b1, b2):
513
m1 = torch.cat([w1, w2])
514
m2 = torch.cat([x, b2])
515
return torch.addmm(b1, m1, m2.t())
518
# match_output, match_placeholder, num_matches
519
TestCase(False, False, 1),
521
TestCase(True, False, 0),
523
TestCase(False, True, 1), # leaked use of placeholder is not leaking
526
class ChainRepeatedPattern:
532
return torch.sigmoid(x)
536
return torch.sigmoid(torch.sigmoid(x))
539
# match_output, match_placeholder, num_matches
540
TestCase(False, False, 3, remove_overlapping_matches=False),
541
TestCase(False, False, 2, remove_overlapping_matches=True),
542
TestCase(True, False, 1),
543
TestCase(False, True, 1),
544
TestCase(True, True, 0)
547
class QuantizationModel:
553
x = x.to(torch.float16)
560
x = x.to(torch.float16)
564
# match_output, match_placeholder, num_matches
565
TestCase(False, False, 1),
566
TestCase(True, False, 1),
567
TestCase(False, True, 0),
568
TestCase(True, True, 0)
571
class MultipleOutputsWithDependency:
582
return b, c # outputs have data dependency
585
# match_output, match_placeholder, num_matches
586
TestCase(False, False, 1),
587
TestCase(True, False, 0),
588
TestCase(False, True, 1),
589
TestCase(True, True, 0)
592
class MultipleOutputsWithoutDependency:
597
# target subgraph to match
602
out = y.sigmoid() + z.sum()
613
# match_output, match_placeholder, num_matches
614
TestCase(False, False, 1),
615
TestCase(True, False, 0),
616
TestCase(False, True, 0),
617
TestCase(True, True, 0)
620
class MultipleOutputsMultipleOverlappingMatches:
625
# target subgraph to match
632
return z + z1 + y + y1
642
# match_output, match_placeholder, num_matches
643
TestCase(False, False, 4, remove_overlapping_matches=False),
644
TestCase(False, False, 1, remove_overlapping_matches=True),
647
class MultipleOutputsMultipleNonOverlappingMatches:
652
# target subgraph to match
661
return z + z1 + y + y1
671
# match_output, match_placeholder, num_matches
672
TestCase(False, False, 1),
675
class MultipleOutputsIdenticalAnchor:
680
# target subgraph to match
695
# match_output, match_placeholder, num_matches
696
# (False, False, 2), # FIXME: currently still matches to 2, should fix to 1
697
TestCase(True, False, 1),
698
TestCase(False, True, 0),
702
class MultipleOutputsHorizontalPattern:
707
# target subgraph to match
721
# match_output, match_placeholder, num_matches
722
TestCase(False, False, 1),
723
TestCase(True, False, 1),
724
TestCase(False, True, 0),
725
TestCase(True, True, 0)
728
class MultiOutputWithWithInvalidMatches:
731
res0 = torch.nn.functional.linear(x, torch.rand(3, 3))
732
res1 = torch.sigmoid(res0)
734
res3 = torch.sum(res2, dim=1)
738
def pattern(a, b, c):
739
lin_res = torch.nn.functional.linear(a, b)
740
mul_res = lin_res * c
741
return lin_res, mul_res
744
# match_output, match_placeholder, num_matches
745
TestCase(False, False, 0),
746
TestCase(True, False, 0),
747
TestCase(False, True, 0),
750
class QuantizationFp8Pattern:
753
cls.quantization = torch.library.Library("fp8_quantization", "DEF") # noqa: TOR901
754
cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
755
cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
762
def forward(self, arg0_1, arg1_1):
763
qt = torch.ops.fp8_quantization
764
_scale_0 = self._scale_0
765
quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
766
dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
767
_scale_1 = self._scale_0
768
quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
769
dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
770
add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
771
_scale_2 = self._scale_0
772
quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
773
dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
774
return dequantize_per_tensor_affine_fp8_2
777
def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
778
qt = torch.ops.fp8_quantization
779
a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
780
b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
781
output = torch.ops.aten.add.Tensor(a, b)
783
qt.dequantize_per_tensor_affine_fp8
785
output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
789
# match_output, match_placeholder, num_matches
790
TestCase(False, False, 1),
794
# This test case is for pattern where no matching anchor is found in the target graph
795
# `anchor` is the starting point of the pattern matching, it's usually the boundary returning nodes
807
# match_output, match_placeholder, num_matches
808
TestCase(False, False, 0),
809
TestCase(True, False, 0),
810
TestCase(False, True, 0),
811
TestCase(True, True, 0)
814
@instantiate_parametrized_tests
815
class TestFXMatcherUtils(JitTestCase):
817
@parametrize("test_model", [
820
SimpleFullGraphMatching,
821
DiamondShapePatternTestCase,
822
NonFullyContainedMatches,
823
ChainRepeatedPattern,
825
MultipleOutputsWithDependency,
826
MultipleOutputsWithoutDependency,
827
MultipleOutputsMultipleOverlappingMatches,
828
MultipleOutputsMultipleNonOverlappingMatches,
829
MultipleOutputsIdenticalAnchor,
830
MultipleOutputsHorizontalPattern,
831
MultiOutputWithWithInvalidMatches,
832
QuantizationFp8Pattern,
835
def test_subgraph_matcher(self, test_model):
837
setup = getattr(test_model, "setup", None)
841
traced = symbolic_trace(test_model.forward)
842
pattern_traced = symbolic_trace(test_model.pattern)
844
for test_case in test_model.test_cases:
846
matcher = SubgraphMatcher(pattern_traced.graph,
847
match_output=test_case.match_output,
848
match_placeholder=test_case.match_placeholder,
849
remove_overlapping_matches=test_case.remove_overlapping_matches)
850
matches = matcher.match(traced.graph)
852
assert len(matches) == test_case.num_matches
854
for match in matches:
855
for node in pattern_traced.graph.nodes:
856
if not test_case.match_placeholder and node.op == "placeholder":
858
if not test_case.match_output and node.op == "output":
860
assert node in match.nodes_map
862
tearDown = getattr(test_model, "tearDown", None)
867
if __name__ == "__main__":