# pytorch / Форк: 0
# test_functional_api.py · 844 строки · 28.9 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import os
4
import sys
5
import unittest
6
from functools import partial, wraps
7

8
import torch
9
import torch.distributed as dist
10
import torch.distributed._functional_collectives as ft_c
11
import torch.distributed._tensor as dt
12
import torch.distributed.distributed_c10d as c10d
13
from functorch import make_fx
14
from torch._inductor.utils import run_and_get_code
15
from torch.testing import FileCheck
16
from torch.testing._internal.distributed.fake_pg import FakeStore
17
from torch.utils._triton import has_triton
18

19

20
if not dist.is_available():
21
    print("Distributed not available, skipping tests", file=sys.stderr)
22
    sys.exit(0)
23

24
from torch.testing._internal.common_distributed import (
25
    MultiProcessTestCase,
26
    MultiThreadedTestCase,
27
    requires_nccl,
28
    TEST_SKIPS,
29
)
30
from torch.testing._internal.common_utils import (
31
    instantiate_parametrized_tests,
32
    parametrize,
33
    run_tests,
34
    TestCase,
35
)
36

37

38
def new_subgroups(group_size: int, pg_tag=None):
39
    world_size = dist.get_world_size()
40
    subgroups = []
41
    cur_subgroup = None
42

43
    for subgroup_id in range(world_size // group_size):
44
        start_rank = subgroup_id * group_size
45
        end_rank = start_rank + group_size
46
        ranks_in_subgroup = list(range(start_rank, end_rank))
47
        subgroup = c10d._new_group_with_tag(
48
            ranks=ranks_in_subgroup,
49
            pg_tag=pg_tag,
50
        )
51
        subgroups.append(subgroup)
52

53
        rank = dist.get_rank()
54
        if rank in ranks_in_subgroup:
55
            cur_subgroup = subgroup
56

57
    return cur_subgroup, subgroups
58

59

60
class TestExpand(MultiThreadedTestCase):
    """``ft_c._expand_group`` resolution of every supported group spec.

    ``_expand_group`` turns a group spec into ``(tag, rankset, group_size)``.
    Covered specs: flat and nested rank lists, process groups, device meshes
    and ``(mesh, dim)`` tuples, each with and without an explicit tag.
    """

    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        # An explicit tag is returned as-is.
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)

    def test_expand_2d_rank_list(self):
        # A nested list describes several equally sized groups: the rankset is
        # the flattened list and group_size is the length of one inner list.
        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])

    def test_expand_process_group(self):
        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        # The 2-rank subgroup that contains this rank.
        my_pg, _ = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        # One singleton group per rank, all tagged "my_pg"; each rank creates
        # all of them and keeps only its own.
        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)

    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        # A second mesh built over the same ranks must expand the same way.
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

    def test_expand_device_mesh_tuple(self):
        # 2x2 mesh [[0, 1], [2, 3]].
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        # A bare multi-dimensional mesh is ambiguous and must be rejected.
        with self.assertRaisesRegex(AssertionError, "Only 1D mesh"):
            tag, rankset, group_size = ft_c._expand_group(mesh)

        # dim 0 groups the columns: {0, 2} and {1, 3}.
        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        expected_rankset = [0, 2] if dist.get_rank() in [0, 2] else [1, 3]
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)

        # dim 1 groups the rows: {0, 1} and {2, 3}.
        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        expected_rankset = [0, 1] if dist.get_rank() in [0, 1] else [2, 3]
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=1)), tag)
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)
147

148

149
class TestPgTag(MultiThreadedTestCase):
    """Tag-based process group creation and lookup.

    The behavior we want is as follows:

    - rankset + tag always results in the same PG. Rather than failing the
      creation of a new PG for an existing rankset + tag, the existing PG is
      returned.
    - The default tag keeps the existing behavior, which means duplicate PGs
      are created.
    - ``_expand_group`` on a default-tagged PG must always resolve back to it,
      which means lookups cannot rely on empty tag + rankset alone.
    """

    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_pg_creation_with_tag(self):
        # Same tag + ranks -> the existing PG is reused.
        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(my_group, my_group2)

        # A different tag -> a distinct PG.
        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(my_group, my_group3)

        # The default tag never reuses a PG, not even another untagged one.
        my_group4, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group, my_group4)

        my_group5, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group4, my_group5)

    def test_pg_lookup_roundtrip(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
        pg_notag0, _ = new_subgroups(group_size=2)
        pg_notag1, _ = new_subgroups(group_size=2)

        def roundtrip(pg):
            # PG -> (tag, rankset) -> PG must be the identity.
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
        self.assertEqual(pg_notag1, roundtrip(pg_notag1))

    def test_pg_lookup_with_tag(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
        pg_notag0, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        # An explicit tag redirects the lookup to the PG with that tag.
        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))

    def test_find_or_create_pg(self):
        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(pg, pg_tag0)

    def test_find_root_pg(self):
        # Empty tag + the full rankset resolves to the default world group.
        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, pg)
222

223

224
@instantiate_parametrized_tests
class TestTraceableCollectives(MultiThreadedTestCase):
    """Eager functional collectives on 4 threads, parametrized over cpu/cuda."""

    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def _maybe_setup_cuda(self, device):
        """For ``device == "cuda"``: skip unless every rank gets its own GPU,
        then make this rank's GPU current. No-op for other devices."""
        if device != "cuda":
            return
        if torch.cuda.device_count() < self.world_size:
            self.skipTest("Not enough CUDA devices")
        torch.cuda.set_device(dist.get_rank())

    @parametrize("device", ["cpu", "cuda"])
    def test_broadcast(self, device):
        self._maybe_setup_cuda(device)

        # Only rank 0 holds ones; the broadcast must copy them everywhere.
        if dist.get_rank() == 0:
            tensor = torch.ones([4], device=device)
        else:
            tensor = torch.zeros([4], device=device)

        mesh = dt.DeviceMesh(device, torch.arange(4))
        res = ft_c.broadcast(tensor, 0, mesh)
        self.assertEqual(res, torch.ones([4], device=device))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_eager(self, device):
        self._maybe_setup_cuda(device)

        tensor = torch.ones([4], device=device)
        mesh = dt.DeviceMesh(device, torch.arange(4))

        # Summing ones over all 4 ranks.
        res = ft_c.all_reduce(tensor, "sum", mesh)
        self.assertEqual(res, torch.full([4], 4.0, device=device))

        # Mesh dim 1 of the 2x2 mesh groups 2 ranks, so the sum is 2.
        mesh = dt.DeviceMesh(device, torch.arange(4).view(2, 2))
        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
        self.assertEqual(res2, torch.full([4], 2.0, device=device))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_coalesced_eager(self, device):
        self._maybe_setup_cuda(device)

        t0 = torch.ones([4], device=device)
        t1 = torch.ones([6], device=device) + 2
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce_coalesced([t0, t1], "sum", mesh)
        self.assertEqual(res[0], t0 * 4)
        self.assertEqual(res[1], t1 * 4)

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_tensor(self, device):
        self._maybe_setup_cuda(device)

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            for dim in [0, 1, 2]:
                output_size = [3, 3, 3]
                output_size[dim] *= mesh.size(0)
                # Each rank has its own tensor; all_gather concatenates them
                # along `dim` into a bigger tensor.
                local_tensor = torch.ones([3, 3, 3], device=device)
                gathered_tensor = ft_c.all_gather_tensor(
                    local_tensor, gather_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(
                    gathered_tensor, torch.ones(output_size, device=device)
                )

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_into_tensor_coalesced(self, device):
        self._maybe_setup_cuda(device)

        tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()], device=device), res[0])
        self.assertEqual(
            torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1]
        )

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_tensor(self, device):
        self._maybe_setup_cuda(device)

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            for dim in [0, 1]:
                group_size = mesh.size(0)
                # Each rank's input is group_size times larger than a 3x3 result
                # along `dim`; after the scatter every rank keeps one 3x3 slice.
                full_size = [3, 3]
                full_size[dim] *= group_size
                input_tensor = torch.ones(full_size, device=device)
                rs_tensor = ft_c.reduce_scatter_tensor(
                    input_tensor, "sum", scatter_dim=dim, group=(mesh, 0)
                )
                # Summing ones over group_size ranks gives group_size everywhere.
                self.assertEqual(
                    rs_tensor, torch.full([3, 3], float(group_size), device=device)
                )

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_into_tensor_coalesced(self, device):
        self._maybe_setup_cuda(device)

        tensors = [
            torch.ones([4], dtype=torch.int64, device=device),
            torch.ones([4], dtype=torch.int64, device=device) + 1,
        ]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        # 4 elements over 4 ranks: each rank keeps one element summed 4 times.
        res = ft_c.reduce_scatter_tensor_coalesced(tensors, "sum", [0, 0], mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.tensor([4], device=device), res[0])
        self.assertEqual(torch.tensor([8], device=device), res[1])
361

362

363
class TestMetaCollectives(TestCase):
364
    def test_all_reduce(self):
365
        x = torch.rand((2, 3, 4), device="meta")
366
        out = ft_c.all_reduce(x, "sum", "0")
367
        self.assertEqual(x.size(), out.size())
368

369

370
class TestGradCollectives(MultiThreadedTestCase):
    """``ft_c.all_reduce`` must not backpropagate into its input (2 threads)."""

    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce(self):
        x = torch.rand([4], requires_grad=True)
        y = torch.rand([4], requires_grad=True)
        out = ft_c.all_reduce(x, "sum", dist.group.WORLD)
        (out + y).sum().backward()
        # Gradient flows to y only; the collective cuts x out of the graph.
        self.assertIsNone(x.grad)
385

386

387
class TestMakeFx(TestCase):
388
    def setUp(self):
389
        # make_fx is not thread-safe due to patching nd mutating global states
390
        # so create a fake_pg.
391
        self.rank = 0
392
        self.world_size = 2
393
        store = FakeStore()
394
        dist.init_process_group(
395
            backend="fake",
396
            world_size=self.world_size,
397
            rank=self.rank,
398
            store=store,
399
        )
400

401
    def tearDown(self):
402
        super().tearDown()
403

404
        self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing())
405

406
    def test_all_reduce_tracing(self):
407
        def allred(input):
408
            return ft_c.all_reduce(input, "sum", group=dist.group.WORLD) + 1
409

410
        graph = make_fx(allred)(torch.rand(4))
411
        FileCheck().check("all_reduce").check("wait_tensor").run(str(graph.graph))
412

413
        mesh = dt.DeviceMesh("cpu", torch.arange(self.world_size))
414

415
        def allred_mesh(input):
416
            return ft_c.all_reduce(input, "sum", mesh) + 1
417

418
        mesh_graph = make_fx(allred_mesh)(torch.rand(4))
419
        FileCheck().check_not("get_attr").check("wait_tensor").run(
420
            str(mesh_graph.graph)
421
        )
422

423
        def allred_mesh_dim(input):
424
            return ft_c.all_reduce(input, "sum", (mesh, 0)) + 1
425

426
        mesh_dim_graph = make_fx(allred_mesh_dim)(torch.rand(4))
427
        FileCheck().check_not("get_attr").check("wait_tensor").run(
428
            str(mesh_dim_graph.graph)
429
        )
430

431

432
# Default backend for the multi-process tests: NCCL whenever CUDA is usable,
# otherwise Gloo. `with_comms` may override it from the BACKEND env variable.
if torch.cuda.is_available():
    BACKEND = dist.Backend.NCCL
else:
    BACKEND = dist.Backend.GLOO
# Number of worker processes used by TestCollectivesWithNCCL.
WORLD_SIZE = 2
434

435

436
def exit_if_lt_x_gpu(x):
437
    if torch.cuda.device_count() < x:
438
        sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
439

440

441
def with_comms(func=None):
442
    if func is None:
443
        return partial(
444
            with_comms,
445
        )
446

447
    @wraps(func)
448
    def wrapper(self, *args, **kwargs):
449
        global BACKEND
450

451
        if "BACKEND" in os.environ:
452
            BACKEND = os.environ["BACKEND"]
453
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
454
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
455
        self.dist_init()
456
        func(self)
457
        self.destroy_comms()
458

459
    return wrapper
460

461

462
class TestCollectivesWithNCCL(MultiProcessTestCase):
    """Functional collectives across real worker processes.

    ``setUp`` requests the NCCL backend through the ``BACKEND`` environment
    variable; ``with_comms`` then initializes and tears down the default PG
    around each test. World size is ``WORLD_SIZE`` unless overridden.
    """

    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        # `with_comms` reads this variable and updates the module-level
        # BACKEND from it. (Assigning `BACKEND = ...` here would only bind a
        # local name, so that dead assignment was dropped.)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        # The CUDA device whose index equals this process's rank.
        return torch.device(self.rank)

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        """Initialize the default PG for this rank through the shared file store."""
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def _collective_device(self):
        """Device type for test tensors: ``"cuda"`` under NCCL, else ``"cpu"``."""
        return "cuda" if BACKEND == dist.Backend.NCCL else "cpu"

    @requires_nccl()
    @with_comms()
    def test_all_gather_into_tensor_coalesced(self):
        exit_if_lt_x_gpu(self.world_size)

        device = f"cuda:{self.rank}"
        tensors = [
            torch.ones([4], device=device),
            torch.ones([4], device=device) + 1,
        ]
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()]), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])

    @with_comms()
    def test_all_to_all_single(self):
        device = self._collective_device()
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        # Rank r sends (i + 1) * (r + 1) rows to rank i, so its input holds
        # sum(split_sizes) rows, all equal to r + 1.
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        x = torch.ones(sum(split_sizes), 5, device=device) * (rank + 1)
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        # Rank r receives (j + 1) * (r + 1) rows valued j + 1 from each rank j,
        # i.e. exactly the layout described by this rank's split_sizes.
        expected = torch.cat(
            [
                torch.full_like(chunk, idx + 1)
                for idx, chunk in enumerate(torch.split(x, split_sizes))
            ]
        )
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_1d_input(self):
        device = self._collective_device()
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        # Same exchange pattern as test_all_to_all_single, on a 1-D input.
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        x = torch.ones(sum(split_sizes), device=device) * (rank + 1)
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = torch.cat(
            [
                torch.full_like(chunk, idx + 1)
                for idx, chunk in enumerate(torch.split(x, split_sizes))
            ]
        )
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_split_sizes_none(self):
        device = self._collective_device()
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        # Without split sizes the input is divided evenly: one row per rank.
        x = torch.ones(self.world_size, self.world_size, device=device) * (rank + 1)
        y = ft_c.all_to_all_single(
            x, output_split_sizes=None, input_split_sizes=None, group=mesh
        )
        expected = torch.cat(
            [
                torch.full_like(chunk, idx + 1)
                for idx, chunk in enumerate(torch.chunk(x, self.world_size))
            ]
        )
        self.assertEqual(y, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing(self):
        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        compiled_allreduce(torch.randn(8, device=self.device), self.process_group)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_tracing_with_fakepg(self):
        exit_if_lt_x_gpu(self.world_size)

        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        # NOTE(review): compiled_allreduce is built but never called, so this
        # test only exercises eager all_reduce on the fake PG. Confirm whether
        # the compiled path was meant to run here.
        compiled_allreduce = torch.compile(allreduce, fullgraph=True)  # noqa: F841
        dist.init_process_group(
            backend="fake",
            rank=0,
            world_size=8,
            store=FakeStore(),
        )
        try:
            allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)
        finally:
            # Not wrapped by with_comms, so release the fake default PG here.
            dist.destroy_process_group()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing_with_dce_code(self):
        # The [1, 0] permutation below only describes a 2-rank group.
        if self.world_size > 2:
            return

        def func(batch, group, rank):
            ret = ft_c.permute_tensor(batch, [1, 0], group)
            if hasattr(ret, "wait"):
                ret = ret.wait()
            # On ranks other than 0 the collective's result is unused, which
            # exercises compiling a collective whose output is dead code.
            if rank == 0:
                return ret
            else:
                return batch * 5

        compiled_func = torch.compile(func)
        compiled_func(torch.ones((100,), device="cuda"), self.process_group, self.rank)
        dist.barrier()
614

615

616
class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
    """Re-runs the inherited NCCL tests with 4 ranks and adds sub-group tests."""

    @property
    def world_size(self):
        return 4

    @requires_nccl()
    @with_comms()
    def test_permute_tensor_with_sub_group(self):
        exit_if_lt_x_gpu(self.world_size)

        device = "cuda"
        mesh_dim_names = ["dp", "tp"]

        mesh_2d = dt.init_device_mesh(
            device, (2, self.world_size // 2), mesh_dim_names=mesh_dim_names
        )

        # Both mesh dimensions have size 2, so every 1-D slice is a 2-rank group.
        for mesh_name in mesh_dim_names:
            mesh = mesh_2d[mesh_name]
            rank = mesh.get_local_rank()

            # local rank0: [0., 1.], local rank1: [2., 3.]
            send_tensor = torch.arange(2, dtype=torch.float32, device=device) + 2 * rank
            recvd_tensor = ft_c.permute_tensor(send_tensor, [1, 0], group=mesh)

            # The [1, 0] permutation swaps the two payloads, so each rank gets
            # what its peer sent: local rank0: [2., 3.], local rank1: [0., 1.]
            peer_rank = (rank + 1) % 2
            expected = torch.arange(2, dtype=torch.float32, device=device) + 2 * peer_rank
            self.assertEqual(
                recvd_tensor,
                expected,
                msg=f"Expected {expected} on {self.rank=} (local_rank={rank}), "
                f"but received {recvd_tensor} instead.",
            )
651

652

653
@instantiate_parametrized_tests
class TestFunctionalAutograd(MultiThreadedTestCase):
    """Gradients through the ``*_autograd`` functional collectives (2 threads).

    Parametrized tests run both eagerly and under ``torch.compile`` with the
    ``aot_eager`` backend.
    """

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @property
    def world_size(self):
        return 2

    @parametrize("compile", [True, False])
    def test_all_to_all_single(self, compile: bool = True) -> None:
        group = dist.group.WORLD.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            # One row goes to (and comes from) every rank.
            sizes = [1] * world_size
            t = t * 2
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        out = compiled(t, self.world_size)
        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        # d(sum(2 * t exchanged))/dt == 2 for every element.
        self.assertEqual(t.grad, torch.full_like(t, 2.0))

    def test_all_to_all_single_inductor(self) -> None:
        group = dist.group.WORLD.group_name

        t = torch.rand((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 10
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 2
            return out.sum()

        compiled = torch.compile(my_func, fullgraph=True)

        def run_with_backward():
            out = compiled(t, self.world_size)
            out.backward()

        _, codes = run_and_get_code(run_with_backward)
        # Each generated module (forward and backward) must contain exactly
        # one all_to_all_single and one matching wait.
        for code in codes:
            FileCheck().check_count(
                "_c10d_functional.all_to_all_single.default", 1, exactly=True
            ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
                code
            )

        self.assertIsNotNone(t.grad)

    @parametrize("compile", [True, False])
    def test_all_gather_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            out = ft_c.all_gather_tensor_autograd(
                t * 1.0,
                gather_dim=dim,
                group=group,
            )
            out = out * 1.0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        for dim in [0, 1, 2]:
            output_size = [3, 3, 3]
            output_size[dim] *= self.world_size
            # Each rank has its own tensor; all_gather gives a bigger tensor.
            local_tensor = torch.ones([3, 3, 3], requires_grad=True)
            gathered_tensor = compiled(local_tensor, dim)
            self.assertEqual(gathered_tensor, torch.ones(output_size))

            gathered_tensor.sum().backward()
            # Every rank's loss depends on this rank's tensor, so the gradient
            # accumulates world_size contributions.
            self.assertEqual(
                local_tensor.grad,
                torch.full((3, 3, 3), fill_value=float(self.world_size)),
            )

    @parametrize("compile", [True, False])
    def test_reduce_scatter_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            # Operate on the argument, not on the enclosing loop's input_tensor.
            rs_tensor = (
                ft_c.reduce_scatter_tensor_autograd(
                    t * 1.0, "sum", scatter_dim=dim, group=group
                )
                * 1.0
            )
            return rs_tensor

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        for dim in [0, 1]:
            group_size = self.world_size
            # The input is group_size times larger than the 3x3 result along `dim`.
            full_size = [3, 3]
            full_size[dim] *= group_size
            input_tensor = torch.ones(full_size, requires_grad=True)
            rs_tensor = compiled(input_tensor, dim)
            # Summing ones over group_size ranks gives group_size everywhere.
            self.assertEqual(rs_tensor, torch.full([3, 3], float(group_size)))
            rs_tensor.sum().backward()
            # Every input element receives a gradient of exactly 1.
            self.assertEqual(input_tensor.grad, torch.full(full_size, fill_value=1.0))
785

786

787
class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
    """Autograd functional collectives across 2 NCCL worker processes."""

    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        # Read by `with_comms`, which updates the module-level BACKEND from it.
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        # The CUDA device whose index equals this process's rank.
        return torch.device(self.rank)

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        """Initialize the default PG for this rank through the shared file store."""
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_to_all_single(self) -> None:
        group = self.process_group.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)

        # One row goes to (and comes from) every rank.
        sizes = [1] * self.world_size
        assert t.requires_grad
        out = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group) + 0

        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))
841

842

843
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
845

# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.