test_c10d_functional_native.py
# Owner(s): ["module: c10d"]
import unittest
from typing import List, Tuple

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._C import FileCheck
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import (
    all_gather_into_tensor_coalesced,
    all_gather_tensor,
    all_reduce,
    all_reduce_coalesced,
    all_to_all_single,
    AsyncCollectiveTensor,
    reduce_scatter_tensor,
    reduce_scatter_tensor_coalesced,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    run_with_native_funcol,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton


def load_test_module(name):
    """Load a sibling test module by dotted name from the test directory."""
    import sys
    from importlib.machinery import SourceFileLoader
    from pathlib import Path
    from unittest import mock

    testdir = Path(__file__).absolute().parent.parent
    with mock.patch("sys.path", [*sys.path, str(testdir)]):
        return SourceFileLoader(
            name, str(testdir / f"{name.replace('.', '/')}.py")
        ).load_module()


AOTIRunnerUtil = load_test_module("inductor.test_aot_inductor_utils").AOTIRunnerUtil
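# AOTIRunnerUtil (loaded above) compiles a callable with AOT Inductor and runs
# the resulting artifact; the compile tests below use it to exercise the aoti
# path after torch.compile.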

import sys

if not dist.is_available():
    print("distributed package not available, skipping tests", file=sys.stderr)
    sys.exit(0)


@requires_nccl()
class C10DFunctionalNativeTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def ranks(self) -> List[int]:
        return list(range(self.world_size))

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process_group(self) -> None:
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
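        # Native _c10d_functional ops look up process groups by registered
        # name, so the collectives below can pass "default" as the group.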

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_reduce_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_reduce(
            input,
            "avg",
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
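        # AsyncCollectiveTensor defers wait_tensor(); the first real use of
        # the tensor (the eq() below) triggers the wait and flips `completed`.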
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_reduce_single_(self) -> None:
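        # The trailing underscore denotes the in-place variant: the result is
        # written into the input tensor, so id(output) == id(input) below.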
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce_(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) == id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_reduce_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
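        # The coalesced variant issues a single fused collective over the
        # whole list rather than one call per tensor.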
        outputs = torch.ops._c10d_functional.all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) != id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            assert not output.completed
            assert output.eq(sum(self.ranks) / self.world_size * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_reduce_coalesced_(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced_(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) == id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_gather_into_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor(
            input,
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((10, 10), float(rank), device=self.device)
                for rank in self.ranks
            ]
        )
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_gather_tensor(
            input,
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_gather_into_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((10, 10), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
            inputs,
            self.world_size,
            "default",
        )
        expect = [
            torch.cat(
                [
                    torch.full((10, 10), float(rank) * i, device=self.device)
                    for rank in self.ranks
                ]
            )
            for i in range(10)
        ]
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(expect[i]).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_gather_into_tensor_coalesced(
            inputs,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(expect[i]).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_reduce_scatter_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.tensor(self.ranks, device=self.device)
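        # Every rank contributes [0, 1, ..., world_size - 1]; averaging
        # identical inputs is a no-op, and rank r receives the chunk holding
        # the value r, hence the eq(self.rank) checks below.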
        output = torch.ops._c10d_functional.reduce_scatter_tensor(
            input,
            "avg",
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert output.eq(self.rank).all()

        # Test Python API and AsyncCollectiveTensor
        output = reduce_scatter_tensor(
            input,
            "avg",
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(self.rank).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_reduce_scatter_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [torch.tensor(self.ranks, device=self.device) * i for i in range(10)]
        outputs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            self.world_size,
            "default",
        )
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(self.rank * i).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            [0] * 10,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(self.rank * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_to_all_single(self) -> None:
        self._init_process_group()
        torch.cuda.set_device(self.device)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))
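        # The shared seed makes the size matrix identical on all ranks: row r
        # holds rank r's send split sizes, column r its receive split sizes.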

        input_split_sizes = send_sz_matrix[self.rank].tolist()
        output_split_sizes = send_sz_matrix[:, self.rank].tolist()
        input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda()

        output = torch.ops._c10d_functional.all_to_all_single(
            input,
            output_split_sizes,
            input_split_sizes,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((sz,), float(rank)).cuda()
                for rank, sz in enumerate(output_split_sizes)
            ]
        )
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_to_all_single(
            input, output_split_sizes, input_split_sizes, "default"
        )
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_broadcast(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.broadcast(
            input,
            1,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = 1
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = funcol.broadcast(
            input,
            1,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_unwaited(self) -> None:
        # Verify that the process can terminate gracefully
        # even with unwaited tensors
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )


class C10DFunctionalNativeCompileTest(TestCase):
    def setUp(self):
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        self.rank = 0
        self.world_size = 2
        torch.cuda.set_device("cuda:0")

        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
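        # The "fake" backend performs no real communication, so these compile
        # tests can run in a single process while still tracing collectives.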

    def tearDown(self):
        dist.destroy_process_group()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_reduce_single(self):
        def func(arg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", "0")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1
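        # Inductor rewrites functional collectives into in-place
        # _c10d_functional ops: buffers it allocates are mutated directly,
        # while graph inputs are cloned first so caller tensors stay intact.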

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone)
            .check("torch.ops._c10d_functional.all_reduce_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_reduce_coalesced(self):
        def func(
            args: List[torch.Tensor],
        ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
            bufs = [arg + 42 for arg in args]
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0")
            ar0 = [funcol.wait_tensor(out) for out in ar0]
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce_coalesced(args, "avg", "0")
            ar1 = [funcol.wait_tensor(out) for out in ar1]
            return ar0, ar1

        args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf5 = empty")
            .check("buf1 = empty")
            .check("buf6 = empty")
            # Expect in-place with inductor allocated buf
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf0, buf1]"
            )
            # Expect no in-place with graph input (buf5, buf6 are clones)
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf5, buf6]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf6")
            # Expect no extra copy on return
            .check("return (buf0, buf1, buf5, buf6, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()
486

487
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
488
    @fresh_inductor_cache()
489
    @run_with_native_funcol
490
    def test_inductor_inplace_op_on_view(self):
491
        def func(arg: torch.Tensor) -> torch.Tensor:
492
            buf0 = (arg + 10)[:2]
493
            ar0 = funcol.all_reduce(buf0, "avg", "0")
494
            ar0 = funcol.wait_tensor(ar0)
495
            return ar0
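        # The slice makes the collective input a view; the generated code is
        # expected to address it as reinterpret_tensor(buf0, ...) so the
        # in-place op mutates the underlying buffer without a copy.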

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            # Ensure the all_reduce_ input is a view
            .check(
                "torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf0"
            )
            .check(
                "torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf0"
            )
            .check("return (reinterpret_tensor(buf0")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_reuse_buffer_after_inplace_collective(self):
        def func(arg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            # Expect allocation
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect allocation
            buf1 = torch.mm(arg, ar0)
            # Expect buf0 to be reused
            buf2 = torch.mm(arg, buf1)
            return buf1, buf2
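        # Once the collective has been waited on, buf0 is no longer live, so
        # Inductor's memory planner can hand it to a later kernel as an output.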

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            # Expect allocation
            .check("buf0 = empty")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect allocation
            .check("buf7 = empty")
            .check("extern_kernels.mm(arg0_1, buf0, out=buf7")
            # Expect buf0 to be reused
            .check("buf8 = buf0; del buf0  # reuse")
            .check("extern_kernels.mm(arg0_1, buf7, out=buf8")
            # Expect no extra copy on return
            .check("return (buf7, buf8, )")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_gather_into_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = funcol.all_gather_tensor(arg, 0, "0")
            ag0 = funcol.wait_tensor(ag0)
            return ag0
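        # all_gather's output is larger than its input, so there is no
        # in-place rewrite here; the op allocates its own output buffer.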

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_gather_into_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> List[torch.Tensor]:
            ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
            ag0 = [funcol.wait_tensor(out) for out in ag0]
            return ag0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_reduce_scatter_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0")
            rs0 = funcol.wait_tensor(rs0)
            return rs0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_reduce_scatter_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> List[torch.Tensor]:
            rs0 = funcol.reduce_scatter_tensor_coalesced(
                args, "avg", [0] * len(args), "0"
            )
            rs0 = [funcol.wait_tensor(out) for out in rs0]
            return rs0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_to_all_single(self):
        def _tolist_with_constrain_as_size(tensor):
            lst = tensor.tolist()
            for elem in lst:
                torch._constrain_as_size(elem)
            return lst
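        # torch._constrain_as_size marks each data-dependent value as a valid
        # size, letting the compiler trace the dynamic split sizes below.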

        def func(
            input: torch.Tensor,
            output_split_sizes: torch.Tensor,
            input_split_sizes: torch.Tensor,
        ) -> torch.Tensor:
            output = funcol.all_to_all_single(
                input,
                _tolist_with_constrain_as_size(output_split_sizes),
                _tolist_with_constrain_as_size(input_split_sizes),
                "0",
            )
            return funcol.wait_tensor(output)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        input_split_sizes = send_sz_matrix[self.rank]
        output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
        input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()

        with torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            compiled = torch.compile(func, dynamic=True)
            code = run_and_get_triton_code(
                compiled, input, output_split_sizes, input_split_sizes
            )
        (
            FileCheck()
            .check_regex(
                "torch.ops._c10d_functional.all_to_all_single.default\\("
                "arg\\d+_\\d+, \\[u\\d+, u\\d+\\], \\[u\\d+, u\\d+\\]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_broadcast(self):
        def func(arg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            br0 = funcol.broadcast(buf0, 1, "0")
            br0 = funcol.wait_tensor(br0)
            # Expect no in-place with graph input
            br1 = funcol.broadcast(arg, 0, "0")
            br1 = funcol.wait_tensor(br1)
            return br0, br1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.broadcast_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone)
            .check("torch.ops._c10d_functional.broadcast_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_ranks_and_tag(self):
        def func(arg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", [0, 1], "")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", [0, 1], "")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func, fullgraph=True)
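        # The legacy (ranks, tag) overload is expected to resolve to a
        # registered group name, so the generated call references '0'.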
778

779
        code = run_and_get_triton_code(compiled, arg)
780
        (FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code))
781

782

783
if __name__ == "__main__":
784
    run_tests()
785
