# Owner(s): ["module: dynamo"]
import functools
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case

# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.testing import CompileCounter
from torch._dynamo.utils import same
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed.distributed_c10d import GroupMember
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    DynamoDistributedMultiProcTestCase,
    DynamoDistributedSingleProcTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
)
from torch.utils._triton import has_triton


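# torch._check_is_size tells the compiler that each element coming out of
# .tolist() is a valid (non-negative) size, so the resulting unbacked ints can
# be used as all_to_all split sizes without data-dependent guard errors.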
def _tolist_with_constrain_as_size(tensor):
    lst = tensor.tolist()
    for elem in lst:
        torch._check_is_size(elem)
    return lst


@requires_nccl()
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in the multi-proc runner; mark each test with the
    minimum # of GPUs it needs to run under.
    """

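    # (tag, ranks, group_size) is the legacy addressing scheme the
    # torch.ops.c10d_functional ops use to resolve a process group; an empty
    # tag plus all ranks selects the default (world) group.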
    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
        # works around an issue with skip_if_lt_x_gpu(2) and workers that have
        # an unpredictable number of GPUs
        return 2

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_broadcast_inductor(self):
        """
        Testing if broadcast works correctly when using inductor
        """

        def example(tensor, src, *, tag, ranks, group_size):
            res = torch.ops.c10d_functional.broadcast(
                tensor, src, tag, ranks, group_size
            )
            res = torch.ops.c10d_functional.wait_tensor(res)
            return res

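        # make_fx traces the function into an FX graph that is handed directly
        # to inductor's compile_fx; this bypasses dynamo and exercises the
        # inductor lowering of the collective in isolation.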
        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            t = torch.randn(4, 4, device="cuda")
            inputs = (t if self.rank == 0 else torch.zeros(4, 4, device="cuda"), 0)
            eager_out = example(*inputs)
            self.assertTrue(same(t, eager_out))

            compiled_func = compile(example, inputs)
            compiled_out = compiled_func(*inputs)
            self.assertTrue(same(eager_out, compiled_out))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor(self):
        """
        The matmul/cat/allreduce sequence here is a pattern we aim to optimize.
        """

        def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            matmul_cat_col = functools.partial(
                matmul_cat_col,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6

            eager_out = matmul_cat_col(*inputs)
            compiled_matmul_cat_col = compile(matmul_cat_col, inputs)
            inductor_out = compiled_matmul_cat_col(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor_cudagraph_trees(self):
        """
        Tests whether cudagraph trees support all_reduce from nccl
        """
        import torch.distributed as dist

        # dist.all_reduce is an in-place op in eager mode but a functionalized op
        # in compiled mode, so we define eager_func and func separately to express
        # the same semantics.
        def eager_func(x):
            y = x * x
            dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        def func(x):
            y = x * x
            y = dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        options = {
            "triton.cudagraphs": True,
            "triton.cudagraph_trees": True,
        }

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            compiled_func = torch.compile(
                func, backend="inductor", fullgraph=True, options=options, dynamic=None
            )

            for nelem in [1024, 2048, 4096]:
                x = torch.randn(nelem, device="cuda", dtype=torch.bfloat16)
                golden_out = eager_func(x)

                for _ in range(3):
                    compiled_out = compiled_func(x)
                    self.assertEqual(golden_out, compiled_out)

    def test_c10d_functional_tagged_pt2_compliant(self):
        op = torch.ops._c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
        op = torch.ops.c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

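    # The next two tests check that a functional collective issued on one side of
    # the eager/compiled boundary can be waited on from the other side.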
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_eager_allreduce_inductor_wait(self):
        def eager_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def inductor_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            eager_func = functools.partial(
                eager_func,
                **self.get_world_trs(),
            )
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = inductor_func(eager_func(*eager_inputs), *inductor_inputs)
            compiled_inductor_func = compile(
                inductor_func, [eager_func(*eager_inputs)] + list(inductor_inputs)
            )
            inductor_out = compiled_inductor_func(
                eager_func(*eager_inputs), *inductor_inputs
            )
            print(f"eager_out, {eager_out}")
            print(f"inductor_out, {inductor_out}")
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_inductor_allreduce_eager_wait(self):
        def inductor_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def eager_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inductor_func = functools.partial(
                inductor_func,
                **self.get_world_trs(),
            )
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = eager_func(inductor_func(*inductor_inputs), *eager_inputs)
            compiled_inductor_func = compile(inductor_func, inductor_inputs)
            inductor_out = eager_func(
                compiled_inductor_func(*inductor_inputs), *eager_inputs
            )
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allreduce_input_buffer_reuse(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            e = d + ar
            return (e,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = torch.ones(4, 4, device="cuda") + self.rank
            compiled = torch.compile(func)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_permute_tensor(self):
        def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
            return _functional_collectives.permute_tensor(
                tensor, src_dst_pairs, ranks, tag
            )

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                # rank0: [0., 1.], rank1: [2., 3.]
                torch.arange(2, dtype=torch.float32, device="cuda") + 2 * self.rank,
                [1, 0],
            )
            compiled = torch.compile(func)
            out = compiled(*inputs, **self.get_world_trs())
            correct = func(*inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device="cuda") + 2 * (
                (self.rank - 1 + self.world_size) % self.world_size
            )
            self.assertEqual(out, expected)
            self.assertEqual(correct, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allgather_output_buffer_reuse(self):
        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = torch.cat(torch.chunk(res, world_size, dim=0), dim=last_dim)
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_contiguous_input(self):
        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                y = y.transpose_(0, last_dim).contiguous()
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = y.transpose_(0, last_dim).contiguous()
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_into_tensor_inductor(self):
        """
        matmul followed by all_gather_into_tensor is a pattern we aim to optimize.
        """

        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            ag = torch.ops.c10d_functional.all_gather_into_tensor(
                c, tag, ranks, group_size
            )
            ag = torch.ops.c10d_functional.wait_tensor(ag)
            return (ag,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_matmul_cat_col = compile(example, inputs)
            inductor_out = compiled_matmul_cat_col(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_inductor(self):
        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            ag = torch.ops.c10d_functional.reduce_scatter_tensor(
                c, "sum", tag, ranks, group_size
            )
            ag = torch.ops.c10d_functional.wait_tensor(ag)
            return (ag,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_fn = compile(example, inputs)
            inductor_out = compiled_fn(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_all_to_all_single_inductor(self):
        def example(
            inp,
            input_split_sizes_tensor,
            output_split_sizes_tensor,
            *,
            tag,
            ranks,
            group_size,
        ):
            input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
            output_split_sizes = _tolist_with_constrain_as_size(
                output_split_sizes_tensor
            )
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                output_split_sizes,
                input_split_sizes,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size
        ), torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
            input_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            output_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            inputs = (
                torch.ones(int(row), 5, device="cuda") * (self.rank + 1),
                input_split_sizes_tensor,
                output_split_sizes_tensor,
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
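            # The split sizes come from .tolist() on tensors, so they should show
            # up in the generated call as unbacked symints (u0, u1, ...) rather
            # than as concrete integers.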
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[u\\d+, u\\d+\\], "
                    "\\[u\\d+, u\\d+\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_all_to_all_single_inductor_split_sizes_none(self):
        def example(inp, *, tag, ranks, group_size):
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                None,
                None,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                torch.ones(self.world_size, self.world_size, device="cuda")
                * (self.rank + 1),
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
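            # With None split sizes every rank exchanges an equal chunk, so the
            # sizes should appear as symbolic expressions (s0 // world_size)
            # derived from the dynamic input shape rather than as literal values.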
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))


@instantiate_parametrized_tests
@requires_nccl()
@requires_cuda
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
    """
    Prefer single-proc test runner for basic tests as it is easier to work with.
    """

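    # NOTE: with the default world_size=1 these collectives run over a single-rank
    # group, so ops like all_reduce are effectively identity; the tests below focus
    # on the code inductor generates rather than on multi-rank numerics.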
    def get_world_trs(self, world_size=1):
        return {
            "tag": "",
            "ranks": list(range(world_size)),
            "group_size": world_size,
        }

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(debug=True)
    def test_inductor_single_op(self):
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.c10d_functional.all_reduce(
                inp, "sum", tag, ranks, group_size
            )
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            return ar

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        out = compiled(inputs, **self.get_world_trs())
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check(".run(arg0_1, buf0, 16")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("return (buf0")
            .run(code)
        )
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(debug=True)
    def test_inductor_steal_buffer(self):
        """
        It's OK (and optimal) if the inductor all_reduce mutates the buffer of an
        intermediate that isn't going to be used again.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, other

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check(".run(arg0_1, buf0")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("buf5 = empty_strided")
            .check(".run(buf5, 16")
            .check("return (buf0, buf5")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_doesnt_mutate_shared(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated
        unless it is copied first.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
            y = x + 2
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, y, other

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf5 = empty_strided")
            .check(".run(arg0_1, buf0, buf5, 16")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("buf6 = empty_strided")
            .check(".run(buf6, 16")
            .check("return (buf0, buf5, buf6")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_allreduce(self):
        def func(inp):
            ar = _functional_collectives.all_reduce(inp, "sum", "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (all_reduce, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_all_gather_tensor(self):
        def func(inp):
            ar = _functional_collectives.all_gather_tensor(inp, 0, "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (all_gather, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_all_gather_tensor_pg(self):
        def func(inp, *, pg):
            ar = _functional_collectives.all_gather_tensor(inp, 0, pg)
            return ar

        inputs = torch.ones(4, 4, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        out = compiled(inputs, pg=GroupMember.WORLD)
        correct = func(inputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (all_gather, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_rewrite_dist_all_gather(self):
        def func(inp, out, *, pg):
            torch.distributed.all_gather_into_tensor(
                out,
                inp,
                pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1

        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    def test_dynamo_rewrite_dist_all_gather_list(self):
        def func(inp, out, *, pg):
            torch.distributed.all_gather(
                out,
                inp,
                pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = [torch.empty(global_size, device=self.device)]
        correct_outputs = [torch.empty(global_size, device=self.device)]
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        assert same(outputs, correct_outputs)

    def test_dynamo_rewrite_dist_all_gather_args_match(self):
        # Duplicates most of the structure from test_dynamo_rewrite_dist_all_gather,
        # except it uses kwargs to ensure the rewrite has matching arg names
        def func(inp, out, *, pg):
            torch.distributed.all_gather_into_tensor(
                output_tensor=out,
                input_tensor=inp,
                group=pg,
                async_op=False,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1

        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    def test_dynamo_rewrite_dist_reduce_scatter(self):
        def func(inp, out, *, pg):
            torch.distributed.reduce_scatter_tensor(
                out,
                inp,
                group=pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1

        # should test more precisely, but the 3 is supposed to be (reduce_scatter, wait, copy_)
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    @parametrize(
        "pg_mode",
        [
            "positional",
            "positional_none",
            "kwargs",
            "kwargs_none",
            "unspecified",
        ],
    )
    def test_dynamo_rewrite_dist_allreduce(self, pg_mode):
        def func(tensor, *args, **kwargs):
            torch.distributed.all_reduce(
                tensor,
                *args,
                **kwargs,
            )

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)

        args = []
        kwargs = {}

        if pg_mode == "positional":
            args.append(torch.distributed.ReduceOp.MAX)
            args.append(GroupMember.WORLD)
        elif pg_mode == "positional_none":
            args.append(torch.distributed.ReduceOp.MAX)
            args.append(None)
        elif pg_mode == "kwargs":
            kwargs["group"] = GroupMember.WORLD
        elif pg_mode == "kwargs_none":
            kwargs["group"] = None
        else:
            assert pg_mode == "unspecified"

        inputs_compiled = torch.ones(2, device=self.device)
        inputs_eager = torch.ones(2, device=self.device)

        compiled(inputs_compiled, *args, **kwargs)
        func(inputs_eager, *args, **kwargs)

        assert counter.frame_count == 1
        # should test more precisely, but the 3 is supposed to be (all_reduce, wait, copy_)
        assert counter.op_count == 3
        assert same(inputs_compiled, inputs_eager)

    def test_dynamo_rewrite_dist_all_to_all_single(self):
        def func(output, input, pg):
            torch.distributed.all_to_all_single(output, input, group=pg)

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)

        input_compiled = torch.ones(2, device=self.device)
        input_eager = torch.ones(2, device=self.device)
        output_compiled = torch.empty(2, device=self.device)
        output_eager = torch.empty(2, device=self.device)

        compiled(output_compiled, input_compiled, GroupMember.WORLD)
        func(output_eager, input_eager, GroupMember.WORLD)

        assert counter.frame_count == 1
        assert same(output_compiled, output_eager)

    @parametrize(
        "reduce_op",
        [
            torch.distributed.ReduceOp.SUM,
            torch.distributed.ReduceOp.AVG,
            torch.distributed.ReduceOp.PRODUCT,
            torch.distributed.ReduceOp.MIN,
            torch.distributed.ReduceOp.MAX,
        ],
    )
    def test_dynamo_rewrite_dist_allreduce_reduce_op(self, reduce_op):
        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

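        # A torch.compile backend is just a callable that receives the traced
        # fx.GraphModule, so we can use one to inspect how dynamo rewrote
        # dist.all_reduce into a functional collective with the right reduce op.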
        def verify_rewrite(gm, _):
            ar_nodes = []
            for node in gm.graph.nodes:
                if node.target in [
                    torch.ops.c10d_functional.all_reduce,
                    torch.ops._c10d_functional.all_reduce,
                ]:
                    ar_nodes.append(node)
            self.assertEqual(len(ar_nodes), 1)
            reduce_op_str = ar_nodes[0].args[1]
            self.assertEqual(REDUCE_OP_TO_STR[reduce_op], reduce_op_str)
            return gm

        compiled = torch.compile(
            torch.distributed.all_reduce,
            backend=verify_rewrite,
            fullgraph=True,
        )
        inputs = (
            torch.ones(2, device=self.device),
            reduce_op,
            GroupMember.WORLD,
        )
        compiled(*inputs)

    @parametrize(
        "source",
        [
            "GroupMember.WORLD",
            "group.WORLD",
            "_get_default_group",
        ],
    )
    def test_dynamo_get_world_group(self, source):
        def func(tensor):
            if source == "GroupMember.WORLD":
                group = torch.distributed.GroupMember.WORLD
            elif source == "group.WORLD":
                group = torch.distributed.group.WORLD
            else:
                assert source == "_get_default_group"
                group = torch.distributed.distributed_c10d._get_default_group()

            torch.distributed.all_reduce(
                tensor,
                group=group,
            )

        def verify(gm, _):
            ar_nodes = []
            for node in gm.graph.nodes:
                if node.target in [
                    torch.ops.c10d_functional.all_reduce,
                    torch.ops._c10d_functional.all_reduce,
                ]:
                    ar_nodes.append(node)
            self.assertEqual(len(ar_nodes), 1)
            return gm

        compiled = torch.compile(func, backend=verify, fullgraph=True)
        input = torch.ones(2, device=self.device)
        compiled(input)

    def test_dynamo_support_collective_op_with_async_op_False(self):
        def func(inp, out, *, pg):
            # the user explicitly sets async_op=False, so there should be no
            # graph break
            torch.distributed.reduce_scatter_tensor(out, inp, group=pg, async_op=False)

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    def test_dynamo_graphbreaks_unsupported_async_op(self):
        def func(inp, out, *, pg):
            work = torch.distributed.reduce_scatter_tensor(
                out, inp, group=pg, async_op=True
            )
            work.wait()

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
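        # async_op=True returns a Work handle that dynamo can't trace, so it
        # graph-breaks and the frame runs eagerly; hence no compiled frames or ops.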
        assert counter.frame_count == 0
        assert counter.op_count == 0
        assert same(outputs, correct_outputs)

    def test_dynamo_pg_var(self):
        def func(inp, *, pg):
            x = pg.rank() + 1 % pg.size()
            return inp + x

        local_size = [4, 4]
        inputs = torch.ones(local_size, device=self.device)
        correct_outputs = torch.empty(local_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        outputs = compiled(inputs, pg=GroupMember.WORLD)
        correct_outputs = func(inputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        assert counter.op_count == 1
        assert same(outputs, correct_outputs)

    def test_dynamo_trace_reduce_scatter_tensor(self):
        def func(inp):
            ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (reduce_scatter, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_allgather_coalesced(self):
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                inp, tag, ranks, group_size
            )
            return ar

        inputs = [torch.ones(4, 4, device="cuda"), torch.ones(6, 6, device="cuda")]
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        assert counter.frame_count == 1
        assert counter.op_count == 3  # It generates 2 getattr to unpack the array
        assert same(out, correct)

    def test_backwards(self):
        """
        It's probably not that common to need backwards support for collectives.

        However, I wanted to at least see if it was possible to support it as a design goal.
        """

        def func(inp):
            ar = _functional_collectives.all_reduce(inp, "sum", "0")
            return ar

        input = torch.ones(4, 4, device="cuda", requires_grad=True)
        # TODO implement backwards
        with self.assertRaisesRegex(
            RuntimeError,
            "element 0 of tensors does not require grad and does not have a grad_fn",
        ):
            compiled = torch.compile(
                func, backend="aot_eager"
            )  # inductor bug with single-op allreduce graph
            out = compiled(input)
            out.sum().backward()

            correct_input = input.clone().detach().requires_grad_()
            correct = func(correct_input)
            correct.sum().backward()
            self.assertTrue(same(out, correct))
            self.assertTrue(same(input.grad, correct_input.grad))

    def test_meta(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs())
        self.assertEqual(x.size(), out.size())
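        # The functional collective ops have meta implementations, so shape
        # propagation works on meta tensors without a real device or an
        # initialized process group.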

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_all_gather_coalesced(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated
        unless it is copied first.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            tensor_list = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                [x, inp], tag, ranks, group_size
            )
            y = x + 2
            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar0, y, other, ar1

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf6 = empty_strided")
            .check(".run(arg0_1, buf0, buf6, 16")
            .check(
                "buf1 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced.default([buf0, arg0_1]"
            )
            .check("buf2 = buf1[0]")
            .check("buf3 = buf1[1]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("buf7 = buf0; del buf0 # reuse")
            .check(".run(buf7, 16")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("return (buf2, buf6, buf7, buf3")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_reduce_scatter_coalesced(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated
        unless it is copied first.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced(
                [x, inp], "sum", tag, ranks, group_size
            )
            y = x + 2
            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar0, y, other, ar1

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: The first return value should be the output of the first wait_tensor.
        # We want to make sure no unnecessary copy is made.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf6 = empty_strided")
            .check(".run(arg0_1, buf0, buf6, 16")
            .check(
                "buf1 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced.default([buf0, arg0_1]"
            )
            .check("buf2 = buf1[0]")
            .check("buf3 = buf1[1]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("buf7 = buf0; del buf0 # reuse")
            .check(".run(buf7, 16")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("return (buf2, buf6, buf7, buf3")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        assert same(out, correct), f"{out} vs {correct}"


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()