# pytorch
#
# Fork
# 0
# /
# test_inductor_collectives.py
# 1134 lines · 43.5 KB
1
# Owner(s): ["module: dynamo"]
2
import functools
3
import unittest
4
from unittest.mock import patch
5

6
import torch
7
import torch._dynamo
8
import torch._dynamo.logging
9
import torch._dynamo.test_case
10

11
# for some reason importing functional collectives after dynamo breaks collectives handling!
12
import torch.distributed._functional_collectives as _functional_collectives
13
from torch._C import FileCheck
14
from torch._dynamo.testing import CompileCounter
15
from torch._dynamo.utils import same
16
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
17
from torch._inductor.utils import run_and_get_triton_code
18
from torch.distributed.distributed_c10d import GroupMember
19
from torch.fx.experimental.proxy_tensor import make_fx
20
from torch.testing._internal.common_distributed import (
21
    _dynamo_dist_per_rank_init,
22
    DynamoDistributedMultiProcTestCase,
23
    DynamoDistributedSingleProcTestCase,
24
    requires_nccl,
25
    skip_if_lt_x_gpu,
26
)
27
from torch.testing._internal.common_utils import (
28
    instantiate_parametrized_tests,
29
    parametrize,
30
    requires_cuda,
31
)
32
from torch.utils._triton import has_triton
33

34

35
def _tolist_with_constrain_as_size(tensor):
36
    lst = tensor.tolist()
37
    for elem in lst:
38
        torch._check_is_size(elem)
39
    return lst
40

41

42
@requires_nccl()
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
    """

    def get_world_trs(self):
        """Return the ``tag``/``ranks``/``group_size`` kwargs covering every rank."""
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
        # works around issue with skipif<2 and workers with unpredictable #s gpu
        return 2

    @staticmethod
    def _compile_with_inductor(func, example_inputs):
        """Trace ``func`` with ``make_fx`` on ``example_inputs`` and compile it with inductor.

        Shared by the tests below so that each one does not need its own copy
        of this helper.
        """
        graph = make_fx(func)(*example_inputs)
        return inductor_compile_fx(graph, example_inputs)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_broadcast_inductor(self):
        """
        Testing if broadcast works correctly when using inductor
        """

        def example(tensor, src, *, tag, ranks, group_size):
            res = torch.ops.c10d_functional.broadcast(
                tensor, src, tag, ranks, group_size
            )
            res = torch.ops.c10d_functional.wait_tensor(res)
            return res

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            # NOTE(review): comparing `t` against the broadcast result on every
            # rank assumes all ranks seed the RNG identically, so each rank's
            # `t` equals rank 0's — confirm against the test harness.
            t = torch.randn(4, 4, device="cuda")
            inputs = (t if self.rank == 0 else torch.zeros(4, 4, device="cuda"), 0)
            eager_out = example(*inputs)
            self.assertTrue(same(t, eager_out))

            compiled_func = self._compile_with_inductor(example, inputs)
            compiled_out = compiled_func(*inputs)
            self.assertTrue(same(eager_out, compiled_out))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor(self):
        """
        matmul/cat/allreduce is a pattern we aim to optimize.
        """

        def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            matmul_cat_col = functools.partial(
                matmul_cat_col,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6

            eager_out = matmul_cat_col(*inputs)
            compiled_matmul_cat_col = self._compile_with_inductor(
                matmul_cat_col, inputs
            )
            inductor_out = compiled_matmul_cat_col(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor_cudagraph_trees(self):
        """
        Tests whether cudagraph trees support all_reduce from nccl
        """
        import torch.distributed as dist

        # dist.all_reduce is an inplace op in eager mode but a functionalized op in compiled mode.
        # so we define eager_func and func separately for the same semantic.
        def eager_func(x):
            y = x * x
            dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        def func(x):
            y = x * x
            y = dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        options = {
            "triton.cudagraphs": True,
            "triton.cudagraph_trees": True,
        }

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            compiled_func = torch.compile(
                func, backend="inductor", fullgraph=True, options=options, dynamic=None
            )

            for nelem in [1024, 2048, 4096]:
                x = torch.randn(nelem, device="cuda", dtype=torch.bfloat16)
                golden_out = eager_func(x)

                for _ in range(3):
                    compiled_out = compiled_func(x)
                    self.assertEqual(golden_out, compiled_out)

    def test_c10d_functional_tagged_pt2_compliant(self):
        """Both the ``_c10d_functional`` and ``c10d_functional`` all_reduce ops carry pt2_compliant_tag."""
        op = torch.ops._c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
        op = torch.ops.c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_eager_allreduce_inductor_wait(self):
        """An all_reduce launched in eager mode is waited on inside inductor-compiled code."""

        def eager_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def inductor_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            eager_func = functools.partial(
                eager_func,
                **self.get_world_trs(),
            )
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = inductor_func(eager_func(*eager_inputs), *inductor_inputs)
            # Each eager_func call launches a fresh all_reduce: one result is the
            # example input for tracing, another feeds the compiled function.
            compiled_inductor_func = self._compile_with_inductor(
                inductor_func, [eager_func(*eager_inputs)] + list(inductor_inputs)
            )
            inductor_out = compiled_inductor_func(
                eager_func(*eager_inputs), *inductor_inputs
            )
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_inductor_allreduce_eager_wait(self):
        """An all_reduce launched by inductor-compiled code is waited on in eager mode."""

        def inductor_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def eager_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inductor_func = functools.partial(
                inductor_func,
                **self.get_world_trs(),
            )
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = eager_func(inductor_func(*inductor_inputs), *eager_inputs)
            compiled_inductor_func = self._compile_with_inductor(
                inductor_func, inductor_inputs
            )
            inductor_out = eager_func(
                compiled_inductor_func(*inductor_inputs), *eager_inputs
            )
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allreduce_input_buffer_reuse(self):
        """The all_reduce input is read again afterwards; with buffer reuse on, results must match eager."""

        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            e = d + ar
            return (e,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = torch.ones(4, 4, device="cuda") + self.rank
            compiled = torch.compile(func)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_permute_tensor(self):
        """permute_tensor with pairs [1, 0] swaps tensors between the two ranks."""

        def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
            return _functional_collectives.permute_tensor(
                tensor, src_dst_pairs, ranks, tag
            )

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                # rank0: [0., 1.], rank1: [2., 3.]
                torch.arange(2, dtype=torch.float32, device="cuda") + 2 * self.rank,
                [1, 0],
            )
            compiled = torch.compile(func)
            out = compiled(*inputs, **self.get_world_trs())
            correct = func(*inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device="cuda") + 2 * (
                (self.rank - 1 + self.world_size) % self.world_size
            )
            self.assertEqual(out, expected)
            self.assertEqual(correct, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allgather_output_buffer_reuse(self):
        """all_gather output consumed by chunk/cat stays correct with buffer reuse enabled."""

        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = torch.cat(torch.chunk(res, world_size, dim=0), dim=last_dim)
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_contiguous_input(self):
        """all_gather of a transposed-then-contiguous tensor compiles and matches eager."""

        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                y = y.transpose_(0, last_dim).contiguous()
                # NOTE(review): `res` is never used, so the returned tensor does
                # not depend on the all_gather; the test appears to only check
                # that this pattern compiles and matches eager — confirm intent.
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = y.transpose_(0, last_dim).contiguous()
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_into_tensor_inductor(self):
        """
        matmul followed by all_gather_into_tensor + wait compiles with inductor
        and matches eager.
        """

        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            ag = torch.ops.c10d_functional.all_gather_into_tensor(
                c, tag, ranks, group_size
            )
            ag = torch.ops.c10d_functional.wait_tensor(ag)
            return (ag,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_fn = self._compile_with_inductor(example, inputs)
            inductor_out = compiled_fn(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_inductor(self):
        """matmul followed by reduce_scatter_tensor + wait compiles with inductor and matches eager."""

        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            rs = torch.ops.c10d_functional.reduce_scatter_tensor(
                c, "sum", tag, ranks, group_size
            )
            rs = torch.ops.c10d_functional.wait_tensor(rs)
            return (rs,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_fn = self._compile_with_inductor(example, inputs)
            inductor_out = compiled_fn(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_all_to_all_single_inductor(self):
        """all_to_all_single with data-dependent split sizes uses unbacked symints (u*) in codegen."""

        def example(
            inp,
            input_split_sizes_tensor,
            output_split_sizes_tensor,
            *,
            tag,
            ranks,
            group_size,
        ):
            input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
            output_split_sizes = _tolist_with_constrain_as_size(
                output_split_sizes_tensor
            )
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                output_split_sizes,
                input_split_sizes,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size
        ), torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            # Total input rows = sum of this rank's input splits
            # = (rank + 1) * world_size * (world_size + 1) / 2. The product
            # world_size * (world_size + 1) is always even, so integer division
            # is exact and no float round-trip is needed.
            row = (self.rank + 1) * self.world_size * (self.world_size + 1) // 2
            input_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            output_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            inputs = (
                torch.ones(row, 5, device="cuda") * (self.rank + 1),
                input_split_sizes_tensor,
                output_split_sizes_tensor,
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[u\\d+, u\\d+\\], "
                    "\\[u\\d+, u\\d+\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_all_to_all_single_inductor_split_sizes_none(self):
        """all_to_all_single with ``None`` split sizes derives even splits from the (dynamic) input size."""

        def example(inp, *, tag, ranks, group_size):
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                None,
                None,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                torch.ones(self.world_size, self.world_size, device="cuda")
                * (self.rank + 1),
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))
504

505

506
@instantiate_parametrized_tests
507
@requires_nccl()
508
@requires_cuda
509
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
510
    """
511
    Prefer single-proc test runner for basic tests as it is easier to work with.
512
    """
513

514
    def get_world_trs(self, world_size=1):
        """Build the ``tag``/``ranks``/``group_size`` kwargs for ``world_size`` ranks.

        Defaults to a one-rank group, matching the single-process runner.
        """
        ranks = list(range(world_size))
        return dict(tag="", ranks=ranks, group_size=world_size)
520

521
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(debug=True)
    def test_inductor_single_op(self):
        """A lone all_reduce + wait compiles to an in-place collective on ``buf0``.

        The generated code must return ``buf0`` directly, without an unnecessary
        copy of the wait_tensor output before it is returned from the graph.
        """

        def func(inp, *, tag, ranks, group_size):
            reduced = torch.ops.c10d_functional.all_reduce(
                inp, "sum", tag, ranks, group_size
            )
            return torch.ops.c10d_functional.wait_tensor(reduced)

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        out = compiled(inputs, **self.get_world_trs())
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())

        # Patterns must appear in this order in the generated wrapper code.
        expected_sequence = (
            "buf0 = empty_strided",
            ".run(arg0_1, buf0, 16",
            "torch.ops._c10d_functional.all_reduce_.default(buf0",
            "torch.ops._c10d_functional.wait_tensor.default(buf0",
            "return (buf0",
        )
        checker = FileCheck()
        for pattern in expected_sequence:
            checker = checker.check(pattern)
        checker.run(code)

        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))
549

550
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(debug=True)
    def test_inductor_steal_buffer(self):
        """
        Inductor's all_reduce may mutate the buffer of an intermediate in place
        (it is fine and optimal) when that intermediate is not used again.
        """

        def func(inp, *, tag, ranks, group_size):
            shifted = inp + 1
            reduced = torch.ops.c10d_functional.all_reduce(
                shifted, "sum", tag, ranks, group_size
            )
            reduced = torch.ops.c10d_functional.wait_tensor(reduced)
            # an unrelated output, to make sure it does not alias the reduced buffer
            unrelated = torch.ones_like(inp) + 22
            return reduced, unrelated

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())

        # Patterns must appear in this order in the generated wrapper code.
        expected_sequence = (
            "buf0 = empty_strided",
            ".run(arg0_1, buf0",
            "torch.ops._c10d_functional.all_reduce_.default(buf0",
            "torch.ops._c10d_functional.wait_tensor.default(buf0",
            "buf5 = empty_strided",
            ".run(buf5, 16",
            "return (buf0, buf5",
        )
        checker = FileCheck()
        for pattern in expected_sequence:
            checker = checker.check(pattern)
        checker.run(code)

        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))
584

585
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_doesnt_mutate_shared(self):
        """
        An intermediate that is going to be reused must not be mutated by the
        in-place all_reduce unless it is copied first.
        """

        def func(inp, *, tag, ranks, group_size):
            shared = inp + 1
            reduced = torch.ops.c10d_functional.all_reduce(
                shared, "sum", tag, ranks, group_size
            )
            reused = shared + 2
            reduced = torch.ops.c10d_functional.wait_tensor(reduced)
            # an unrelated output, to make sure it does not alias the reduced buffer
            unrelated = torch.ones_like(inp) + 22
            return reduced, reused, unrelated

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())

        # The first kernel writes both buf0 (the all_reduce operand) and buf5
        # (the reused value) before all_reduce_ mutates buf0 in place; the
        # remaining patterns must follow in this order.
        expected_sequence = (
            "buf0 = empty_strided",
            "buf5 = empty_strided",
            ".run(arg0_1, buf0, buf5, 16",
            "torch.ops._c10d_functional.all_reduce_.default(buf0",
            "torch.ops._c10d_functional.wait_tensor.default(buf0",
            "buf6 = empty_strided",
            ".run(buf6, 16",
            "return (buf0, buf5, buf6",
        )
        checker = FileCheck()
        for pattern in expected_sequence:
            checker = checker.check(pattern)
        checker.run(code)

        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))
622

623
    def test_dynamo_trace_allreduce(self):
        """Dynamo captures a functional all_reduce in a single frame."""

        def func(inp):
            # group passed as the string "0" — presumably the default group's
            # name; confirm
            return _functional_collectives.all_reduce(inp, "sum", "0")

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        inputs = torch.ones(4, 4, device="cuda")
        out = compiled(inputs)
        correct = func(inputs)

        self.assertEqual(counter.frame_count, 1)
        # should test more precisely, but the 2 is supposed to be (all_reduce, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))
638

639
    def test_dynamo_trace_all_gather_tensor(self):
        """Dynamo captures a functional all_gather_tensor in a single frame."""

        def func(inp):
            return _functional_collectives.all_gather_tensor(inp, 0, "0")

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        inputs = torch.ones(4, 4, device="cuda")
        out = compiled(inputs)
        correct = func(inputs)

        self.assertEqual(counter.frame_count, 1)
        # should test more precisely, but the 2 is supposed to be (all_gather, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))
654

655
    def test_dynamo_trace_all_gather_tensor_pg(self):
        """A process-group object passed as a kwarg traces under fullgraph=True."""

        def func(inp, *, pg):
            return _functional_collectives.all_gather_tensor(inp, 0, pg)

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        inputs = torch.ones(4, 4, device=self.device)
        out = compiled(inputs, pg=GroupMember.WORLD)
        correct = func(inputs, pg=GroupMember.WORLD)

        self.assertEqual(counter.frame_count, 1)
        # should test more precisely, but the 2 is supposed to be (all_gather, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))
670

671
    def test_dynamo_rewrite_dist_all_gather(self):
        """Dynamo traces in-place ``dist.all_gather_into_tensor`` (positional args) in one graph."""

        def func(inp, out, *, pg):
            torch.distributed.all_gather_into_tensor(
                out,
                inp,
                pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        # unittest assertions (not bare `assert`) so the checks survive `python -O`
        # and report the mismatching values on failure.
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
        self.assertEqual(counter.op_count, 3)
        self.assertTrue(same(outputs, correct_outputs))
695

696
    def test_dynamo_rewrite_dist_all_gather_list(self):
        """Dynamo traces list-output ``dist.all_gather`` in one graph and matches eager."""

        def func(inp, out, *, pg):
            torch.distributed.all_gather(
                out,
                inp,
                pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = [torch.empty(global_size, device=self.device)]
        correct_outputs = [torch.empty(global_size, device=self.device)]
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        # unittest assertions (not bare `assert`) so the checks survive `python -O`
        # and report the mismatching values on failure.
        self.assertEqual(counter.frame_count, 1)
        self.assertTrue(same(outputs, correct_outputs))
717

718
    def test_dynamo_rewrite_dist_all_gather_args_match(self):
        """Same as the positional all_gather_into_tensor test, but every argument is passed by keyword."""
        # Duplicated most of the structure from test_dynamo_rewrite_dist_all_gather
        # except uses kwargs to ensure rewrite has matching arg names
        def func(inp, out, *, pg):
            torch.distributed.all_gather_into_tensor(
                output_tensor=out,
                input_tensor=inp,
                group=pg,
                async_op=False,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        # unittest assertions (not bare `assert`) so the checks survive `python -O`
        # and report the mismatching values on failure.
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
        self.assertEqual(counter.op_count, 3)
        self.assertTrue(same(outputs, correct_outputs))
745

746
    def test_dynamo_rewrite_dist_reduce_scatter(self):
        """``reduce_scatter_tensor`` compiles in one frame and matches eager."""

        def scatter(inp, out, *, pg):
            torch.distributed.reduce_scatter_tensor(out, inp, group=pg)

        # Single-process test: the output shape is taken to equal the local shape.
        shape = [4, 4]
        dev = self.device

        src = torch.ones(shape, device=dev)
        compiled_out = torch.empty(shape, device=dev)
        eager_out = torch.empty(shape, device=dev)

        counter = CompileCounter()
        opt_scatter = torch.compile(scatter, backend=counter, fullgraph=True)
        opt_scatter(src, compiled_out, pg=GroupMember.WORLD)
        scatter(src, eager_out, pg=GroupMember.WORLD)

        assert counter.frame_count == 1
        # Loose check: the 3 ops are expected to be (reduce_scatter, wait, copy_).
        assert counter.op_count == 3
        assert same(compiled_out, eager_out)
770

771
    @parametrize(
        "pg_mode",
        [
            "positional",
            "positional_none",
            "kwargs",
            "kwargs_none",
            "unspecified",
        ],
    )
    def test_dynamo_rewrite_dist_allreduce(self, pg_mode):
        """``all_reduce`` is rewritten however the process group is supplied.

        ``pg_mode`` picks the calling convention. The group (or ``None``) can be
        passed positionally after a ``ReduceOp.MAX`` op, passed through the
        ``group`` keyword, or omitted entirely.
        """

        def func(tensor, *args, **kwargs):
            torch.distributed.all_reduce(tensor, *args, **kwargs)

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)

        # pg_mode -> (positional extras, keyword extras) passed after the tensor.
        call_variants = {
            "positional": ([torch.distributed.ReduceOp.MAX, GroupMember.WORLD], {}),
            "positional_none": ([torch.distributed.ReduceOp.MAX, None], {}),
            "kwargs": ([], {"group": GroupMember.WORLD}),
            "kwargs_none": ([], {"group": None}),
            "unspecified": ([], {}),
        }
        assert pg_mode in call_variants
        args, kwargs = call_variants[pg_mode]

        inputs_compiled = torch.ones(2, device=self.device)
        inputs_eager = torch.ones(2, device=self.device)

        compiled(inputs_compiled, *args, **kwargs)
        func(inputs_eager, *args, **kwargs)

        assert counter.frame_count == 1
        # Loose check: the 3 ops are expected to be (all_reduce, wait, copy_).
        assert counter.op_count == 3
        assert same(inputs_compiled, inputs_eager)
818

819
    def test_dynamo_rewrite_dist_all_to_all_single(self):
        """``all_to_all_single`` compiles in one frame and matches eager."""

        def exchange(output, inp, pg):
            torch.distributed.all_to_all_single(output, inp, group=pg)

        counter = CompileCounter()
        opt_exchange = torch.compile(exchange, backend=counter, fullgraph=True)

        dev = self.device
        src_compiled = torch.ones(2, device=dev)
        src_eager = torch.ones(2, device=dev)
        dst_compiled = torch.empty(2, device=dev)
        dst_eager = torch.empty(2, device=dev)

        opt_exchange(dst_compiled, src_compiled, GroupMember.WORLD)
        exchange(dst_eager, src_eager, GroupMember.WORLD)

        assert counter.frame_count == 1
        assert same(dst_compiled, dst_eager)
836

837
    @parametrize(
        "reduce_op",
        [
            torch.distributed.ReduceOp.SUM,
            torch.distributed.ReduceOp.AVG,
            torch.distributed.ReduceOp.PRODUCT,
            torch.distributed.ReduceOp.MIN,
            torch.distributed.ReduceOp.MAX,
        ],
    )
    def test_dynamo_rewrite_dist_allreduce_reduce_op(self, reduce_op):
        """Each ``ReduceOp`` is rewritten to its matching reduce-op string.

        The backend inspects the captured graph. It requires exactly one
        functional ``all_reduce`` node whose second argument equals
        ``REDUCE_OP_TO_STR[reduce_op]``.
        """
        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

        def verify_rewrite(gm, _):
            allreduce_targets = [
                torch.ops.c10d_functional.all_reduce,
                torch.ops._c10d_functional.all_reduce,
            ]
            found = [n for n in gm.graph.nodes if n.target in allreduce_targets]
            self.assertEqual(len(found), 1)
            self.assertEqual(REDUCE_OP_TO_STR[reduce_op], found[0].args[1])
            return gm

        compiled = torch.compile(
            torch.distributed.all_reduce, backend=verify_rewrite, fullgraph=True
        )
        tensor = torch.ones(2, device=self.device)
        compiled(tensor, reduce_op, GroupMember.WORLD)
874

875
    @parametrize(
        "source",
        [
            "GroupMember.WORLD",
            "group.WORLD",
            "_get_default_group",
        ],
    )
    def test_dynamo_get_world_group(self, source):
        """Every spelling of the world group traces to one ``all_reduce`` node.

        ``source`` selects how the traced function obtains the group. The
        backend asserts that exactly one functional ``all_reduce`` node ends up
        in the captured graph.
        """

        def func(tensor):
            if source == "GroupMember.WORLD":
                group = torch.distributed.GroupMember.WORLD
            elif source == "group.WORLD":
                group = torch.distributed.group.WORLD
            else:
                assert source == "_get_default_group"
                group = torch.distributed.distributed_c10d._get_default_group()

            torch.distributed.all_reduce(tensor, group=group)

        def count_allreduce_nodes(gm, _):
            targets = [
                torch.ops.c10d_functional.all_reduce,
                torch.ops._c10d_functional.all_reduce,
            ]
            hits = sum(1 for node in gm.graph.nodes if node.target in targets)
            self.assertEqual(hits, 1)
            return gm

        opt_func = torch.compile(func, backend=count_allreduce_nodes, fullgraph=True)
        tensor = torch.ones(2, device=self.device)
        opt_func(tensor)
912

913
    def test_dynamo_support_collective_op_with_async_op_False(self):
        """An explicit ``async_op=False`` must not cause a graph break.

        ``fullgraph=True`` makes dynamo raise on any graph break. The
        "no graph break" claim is therefore enforced directly, as in the sibling
        rewrite tests, rather than being inferred from the frame/op counts alone.
        """

        def func(inp, out, *, pg):
            # user explicitly set the attribute `async_op` to False,
            # there should be no graph break
            torch.distributed.reduce_scatter_tensor(out, inp, group=pg, async_op=False)

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        # presumably (reduce_scatter, wait, copy_) — not checked op by op
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)
933

934
    def test_dynamo_graphbreaks_unsupported_async_op(self):
        """``async_op=True`` is not captured: no frame may be compiled.

        Even so, the result of calling the compiled wrapper must match eager.
        """

        def scatter_async(inp, out, *, pg):
            handle = torch.distributed.reduce_scatter_tensor(
                out, inp, group=pg, async_op=True
            )
            handle.wait()

        # Single-process test: the output shape is taken to equal the local shape.
        shape = [4, 4]
        dev = self.device

        src = torch.ones(shape, device=dev)
        compiled_out = torch.empty(shape, device=dev)
        eager_out = torch.empty(shape, device=dev)

        counter = CompileCounter()
        opt_scatter = torch.compile(scatter_async, backend=counter)
        opt_scatter(src, compiled_out, pg=GroupMember.WORLD)
        scatter_async(src, eager_out, pg=GroupMember.WORLD)

        assert counter.frame_count == 0
        assert counter.op_count == 0
        assert same(compiled_out, eager_out)
955

956
    def test_dynamo_pg_var(self):
        """Dynamo traces ``rank()``/``size()`` calls on a process-group argument.

        Compiled and eager results must agree. Capture must be a single frame
        containing a single op.
        """

        def func(inp, *, pg):
            # Parenthesized to compute the next rank modulo the world size.
            # Without the parentheses ``%`` binds tighter than ``+``, and the
            # expression would be ``rank + (1 % size)``.
            x = (pg.rank() + 1) % pg.size()
            return inp + x

        local_size = [4, 4]
        inputs = torch.ones(local_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        outputs = compiled(inputs, pg=GroupMember.WORLD)
        correct_outputs = func(inputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        # presumably the single op is the ``inp + x`` add — TODO confirm
        assert counter.op_count == 1
        assert same(outputs, correct_outputs)
971

972
    def test_dynamo_trace_reduce_scatter_tensor(self):
        """Functional ``reduce_scatter_tensor`` is traced into a single frame."""

        def func(inp):
            return _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, "0")

        x = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        opt_func = torch.compile(func, backend=counter)
        out = opt_func(x)
        correct = func(x)
        self.assertEqual(counter.frame_count, 1)

        # Loose check: the 2 ops are expected to be (reduce_scatter, wait).
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))
987

988
    def test_dynamo_trace_allgather_coalesced(self):
        """``c10d_functional.all_gather_into_tensor_coalesced`` traces in one frame."""

        def func(inp, *, tag, ranks, group_size):
            return torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                inp, tag, ranks, group_size
            )

        inputs = [torch.ones(4, 4, device="cuda"), torch.ones(6, 6, device="cuda")]
        counter = CompileCounter()
        opt_func = torch.compile(func, backend=counter)
        out = opt_func(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())

        assert counter.frame_count == 1
        # One op for the collective plus 2 getattr nodes that unpack its list result.
        assert counter.op_count == 3
        assert same(out, correct)
1003

1004
    def test_backwards(self):
        """
        Backward through a functional collective is not supported yet.

        Backward support for collectives is probably rarely needed, but it was
        worth checking whether it could be supported as a design goal. Until it
        is implemented, backward is expected to fail with the "does not require
        grad and does not have a grad_fn" error asserted below.
        """

        def func(inp):
            return _functional_collectives.all_reduce(inp, "sum", "0")

        x = torch.ones(4, 4, device="cuda", requires_grad=True)
        # TODO implement backwards
        expected_msg = (
            "element 0 of tensors does not require grad and does not have a grad_fn"
        )
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            # aot_eager instead of inductor: inductor bug with single-op allreduce graph
            compiled = torch.compile(func, backend="aot_eager")
            out = compiled(x)
            out.sum().backward()

            # Reference comparison for once backward is implemented. It is not
            # reached today because backward() above is expected to raise.
            ref_x = x.clone().detach().requires_grad_()
            ref_out = func(ref_x)
            ref_out.sum().backward()
            self.assertTrue(same(out, ref_out))
            self.assertTrue(same(x.grad, ref_x.grad))
1032

1033
    def test_meta(self):
        """``all_reduce`` on a meta tensor returns a tensor of the same size."""
        meta_in = torch.rand((2, 3, 4), device="meta")
        world = self.get_world_trs()
        reduced = torch.ops.c10d_functional.all_reduce(meta_in, "sum", **world)
        self.assertEqual(meta_in.size(), reduced.size())
1037

1038
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_all_gather_coalesced(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated unless copied.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            tensor_list = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                [x, inp], tag, ranks, group_size
            )
            y = x + 2
            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar0, y, other, ar1

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        # The checked return line maps (buf2, buf6, buf7, buf3) onto
        # (ar0, y, other, ar1). buf7 is expected to reuse buf0's storage only
        # after the collective has been issued on buf0.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf6 = empty_strided")
            .check(".run(arg0_1, buf0, buf6, 16")
            .check(
                "buf1 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced.default([buf0, arg0_1]"
            )
            .check("buf2 = buf1[0]")
            .check("buf3 = buf1[1]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("buf7 = buf0; del buf0  # reuse")
            .check(".run(buf7, 16")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("return (buf2, buf6, buf7, buf3")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        assert same(out, correct), f"{out} vs {correct}"
1083

1084
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_reduce_scatter_coalesced(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated unless copied.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced(
                [x, inp], "sum", tag, ranks, group_size
            )
            y = x + 2
            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar0, y, other, ar1

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: The first return value should be the output of the first wait_tensor.
        # We want to make sure no unnecessary copy is made.
        # The checked return line maps (buf2, buf6, buf7, buf3) onto
        # (ar0, y, other, ar1). buf7 is expected to reuse buf0's storage only
        # after the collective has been issued on buf0.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf6 = empty_strided")
            .check(".run(arg0_1, buf0, buf6, 16")
            .check(
                "buf1 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced.default([buf0, arg0_1]"
            )
            .check("buf2 = buf1[0]")
            .check("buf3 = buf1[1]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("buf7 = buf0; del buf0  # reuse")
            .check(".run(buf7, 16")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("return (buf2, buf6, buf7, buf3")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        assert same(out, correct), f"{out} vs {correct}"
1129

1130

1131
if __name__ == "__main__":
    # Run this module's tests through dynamo's test-case runner.
    from torch._dynamo import test_case as dynamo_test_case

    dynamo_test_case.run_tests()
1135

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.