# Owner(s): ["module: inductor"]
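# Tests for inductor's compute/communication reordering passes
# (sink_waits, raise_comms, reorder_compute_for_overlap), GroupedSchedulerNode
# grouping via a pre-fusion custom pass, and the NCCL cost-model tables.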
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case

# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import ir, scheduler
from torch._inductor.comm_analysis import (
    baseLat,
    hwLat,
    llMaxBws,
    NCCL_ALGO,
    NCCL_HW,
    NCCL_PROTO,
    NVIDIA_GPU_TYPE,
)
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    at_least_x_gpu,
    DynamoDistributedMultiProcTestCase,
    requires_nccl,
)
from torch.utils._triton import has_triton


def get_snode_runtime_for_reorder_compute_test(snode):
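    """
    Custom cost model, installed below via torch._inductor.config.estimate_op_runtime,
    that returns a relative runtime (in arbitrary units) for each scheduler node.
    """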
    # NOTE: custom cost model to show that the compute reordering algorithm is working
    # Collective kernels
    if isinstance(snode.node, ir._CollectiveKernel):
        return 100
    elif isinstance(snode.node, ir._WaitKernel):
        return 0
    # High-arithmetic-intensity compute kernels
    elif isinstance(snode.node, ir.ExternKernel):
        return 5
    # All other kernels
    return 1


def create_grouped_node_for_allreduce_and_its_deps(snodes):
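    """
    Pre-fusion custom pass (installed below via
    torch._inductor.config._pre_fusion_custom_pass): groups the in-place all_reduce_
    node with its single dependency into one GroupedSchedulerNode and moves that
    group to the front of the schedule.
    """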
    name_to_snode = {snode.node.name: snode for snode in snodes}
    all_reduce_snodes = [
        snode
        for snode in snodes
        if isinstance(snode.node, ir._CollectiveKernel)
        and snode.node.op_overload == torch.ops._c10d_functional.all_reduce_.default
    ]
    assert len(all_reduce_snodes) == 1
    all_reduce_snode = all_reduce_snodes[0]
    all_reduce_dep_snodes = [
        name_to_snode[node.name] for node in all_reduce_snode.node.inputs
    ]
    assert len(all_reduce_dep_snodes) == 1
    all_reduce_dep_snode = all_reduce_dep_snodes[0]

    grouped_snode = scheduler.GroupedSchedulerNode.create(
        [all_reduce_dep_snode, all_reduce_snode]
    )
    new_snode_order = []
    new_snode_order.append(grouped_snode)
    for snode in snodes:
        if snode in grouped_snode.snodes:
            continue
        new_snode_order.append(snode)
    return new_snode_order


@requires_nccl()
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in a multi-proc runner; mark each test with the
    minimum # of GPUs it needs to run under.
    """

    def get_world_trs(self):
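        """
        Kwargs describing the default world process group; passed as **kwargs to
        funcs under test that take tag/ranks/group_size keyword-only arguments.
        """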
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
        # works around an issue with skipif<2 and workers with an unpredictable # of GPUs
        return 2

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
        ],
    )
    def test_sink_waits(self):
        def func(a):
            ar = _functional_collectives.all_reduce(a, "sum", "0")
            b = torch.matmul(a, a)
            return torch.matmul(ar, b)

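        # Use a fake process group when fewer than 2 GPUs are available.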
        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the wait_tensor is sunk below the 1st matmul but
            # above the 2nd matmul.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "raise_comms",
        ],
    )
    def test_raise_comms(self):
        def func(a):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            return torch.matmul(d, e)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            print(code)
            # Verify that the all_reduce_ has been raised above the 2nd matmul
            # but below the 1st matmul. Note that the all_reduce_ directly
            # writes to the output buffer of the 1st matmul, which is an input
            # to the relu. Therefore, the all_reduce_ should be scheduled
            # after the relu.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
            "raise_comms",
        ],
    )
    def test_sink_waits_raise_comms(self):
        def func(a, *, tag, ranks, group_size):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            f = torch.relu(d)
            g = torch.matmul(f, f)
            return torch.mm(e, g)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Things to verify:
            # - The clone prologue of the all_reduce_ should not be fused with
            #   any relus.
            # - The all_reduce_ and its prologue should be raised above the 2nd
            #   matmul but below the 1st matmul.
            # - The wait_tensor should be sunk below the 3rd matmul but above
            #   the 4th matmul.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_all_reduce_0")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    def test_reorder_compute_for_overlap(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second
            #    all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second
            #    all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second
            #    all_reduce and DO depend on the first all_reduce.
            # 4. finally, we schedule the second all_reduce, followed by all ops
            #    that depend on it.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    @patch.object(
        torch._inductor.config,
        "estimate_op_runtime",
        get_snode_runtime_for_reorder_compute_test,
    )
    def test_reorder_compute_for_overlap_custom_runtime_estimation(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second
            #    all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second
            #    all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second
            #    all_reduce and DO depend on the first all_reduce.
            # 4. finally, we schedule the second all_reduce, followed by all ops
            #    that depend on it.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(
        torch._inductor.config,
        "_pre_fusion_custom_pass",
        create_grouped_node_for_allreduce_and_its_deps,
    )
    def test_grouped_scheduler_node(self):
        def func(a, *, tag, ranks, group_size):
            add = a + a
            div = add / a
            ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
            # Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a`
            # together into a single fused op, but here in this unit test, we
            # intentionally put the `add`, `div` and `ar` computation into a
            # GroupedSchedulerNode, which prevents them from being fused with any other ops.
            mul = a * a
            mm = torch.matmul(mul, ar)
            return (mm,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Expectations:
            # 1. `add = a + a` and `div = add / a` are still fused, which means fusion
            #    still happens among nodes within a GroupedSchedulerNode.
            # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two
            #    are within the GroupedSchedulerNode and thus are prevented from being
            #    fused with any outside ops.
            FileCheck().check("triton_poi_fused_add_div_0.").check(
                "_c10d_functional.all_reduce_."
            ).check("triton_poi_fused_mul_1.").run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    def test_nccl_heuristics(self):
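        """
        Sanity-check that the NCCL latency/bandwidth tables in
        torch._inductor.comm_analysis match the sizes of the corresponding enums.
        """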
        assert len(baseLat) == len(NCCL_ALGO)
        assert all(len(x) == len(NCCL_PROTO) for x in baseLat)

        assert len(hwLat) == len(NCCL_HW)
        assert all(len(x) == len(NCCL_ALGO) for x in hwLat)
        assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x)

        assert len(llMaxBws) == len(NVIDIA_GPU_TYPE)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()