pytorch

Форк
0
/
test_c10d_ops_nccl.py 
996 строк · 35.4 Кб
1
# Owner(s): ["oncall: distributed"]
2
# This test file contains positive tests for c10d with NCCL backend.
3
# During the test, the ProcessGroup is expected not to be aborted, destroyed, or to incur a fatal error.
4
# Please be mindful of this when adding tests here.
5
# If you need to add tests for group creation, abort or destroy, please add tests in test_c10d_nccl.py.
6

7
# There are two ways to launch tests in this file:
8
# 1. Run this file directly with `python test_c10d_ops_nccl.py`
9
# 2. Use multi-process launcher, e.g. `torchrun --standalone --nproc-per-node 2 test_c10d_ops_nccl.py`
10

11
import math
12
import os
13
import sys
14
import tempfile
15

16
import torch
17
import torch.distributed as c10d
18

19

20
# Exit successfully (skipping the whole file) when c10d or its NCCL backend is missing.
if not (c10d.is_available() and c10d.is_nccl_available()):
    print("c10d NCCL not available, skipping tests", file=sys.stderr)
    sys.exit(0)
23

24

25
import torch.distributed as dist
26
from torch.testing._internal.common_cuda import TEST_MULTIGPU
27
from torch.testing._internal.common_distributed import (
28
    init_multigpu_helper,
29
    MultiProcContinousTest,
30
    requires_nccl,
31
)
32
from torch.testing._internal.common_utils import (
33
    skip_but_pass_in_sandcastle_if,
34
    skipIfRocm,
35
    TEST_WITH_DEV_DBG_ASAN,
36
)
37

38

39
# Dev-debug ASAN builds are skipped: torch + multiprocessing spawn misbehave there.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip ASAN as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
44

45

46
class ProcessGroupNCCLOpTest(MultiProcContinousTest):
47
    @classmethod
    def backend_str(cls) -> str:
        """Return the name of the c10d backend this suite exercises."""
        backend_name = "nccl"
        return backend_name
50

51
    @classmethod
    def opts(cls, high_priority_stream=False):
        """Build the ``ProcessGroupNCCL.Options`` used for the test process group.

        Args:
            high_priority_stream: value assigned to ``is_high_priority_stream``.

        Returns:
            A freshly constructed ``c10d.ProcessGroupNCCL.Options`` instance.
        """
        nccl_options = c10d.ProcessGroupNCCL.Options()
        nccl_options.is_high_priority_stream = high_priority_stream
        return nccl_options
56

57
    @property
    def rank_to_GPU(self):
        """Rank -> local CUDA device indices, recomputed on every access.

        Callers in this file index it as ``self.rank_to_GPU[rank][k]``.
        """
        mapping = init_multigpu_helper(self.world_size, "nccl")
        return mapping
61

62
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_empty_tensors(self):
        """Collectives on zero-element tensors complete and leave them empty."""
        pg = self.pg
        device_idx = self.rank_to_GPU[self.rank][0]

        def empty_on_device():
            return torch.FloatTensor([]).cuda(device_idx)

        # In-place single-input collectives, run in order on the same tensor list.
        inputs = [empty_on_device()]
        for collective in (pg.broadcast, pg.allreduce, pg.reduce):
            collective(inputs).wait()
            self.assertEqual(0, inputs[0].numel())

        # allgather: one empty output slot per rank.
        gathered = [[empty_on_device() for _ in range(self.world_size)]]
        pg.allgather(gathered, inputs).wait()
        for out in gathered[0]:
            self.assertEqual(0, out.numel())

        # reduce_scatter: one empty input per rank reduced into an empty output.
        scattered_out = [empty_on_device()]
        scatter_in = [[empty_on_device() for _ in range(self.world_size)]]
        pg.reduce_scatter(scattered_out, scatter_in).wait()
        self.assertEqual(0, scattered_out[0].numel())
97

98
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_broadcast_ops(self):
        """Broadcast with each rank, and each local tensor, taking a turn as root."""
        pg = self.pg
        local_devices = self.rank_to_GPU[self.rank]

        def broadcast(xs, rootRank, rootTensor):
            options = c10d.BroadcastOptions()
            options.rootRank = rootRank
            options.rootTensor = rootTensor
            pg.broadcast(xs, options).wait()
            return xs

        for root in range(self.world_size):
            # Single input tensor: every rank ends up holding the root's rank id.
            single = torch.tensor([self.rank]).cuda(local_devices[0])
            result = broadcast([single], root, 0)
            self.assertEqual(torch.tensor([root]), result[0])

            side = root + 1
            expected_tensor = torch.empty([side, side]).fill_(side)
            xs = [
                torch.empty([side, side]).fill_(-1).cuda(device=dev)
                for dev in local_devices
            ]

            # test with multiple input tensors (multiple gpu in one rank):
            # each local tensor index takes a turn as rootTensor.
            for j, dev in enumerate(local_devices):
                if self.rank == root:
                    xs[j] = expected_tensor.cuda(device=dev)

                broadcast(xs, root, j)

                for tensor in xs:
                    self.assertEqual(tensor, expected_tensor)
133

134
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_sparse_allreduce_ops(self):
        """Allreduce of a sparse COO tensor yields the dense elementwise sum.

        A RuntimeError whose message says NCCL does not support sparse
        all_reduce is tolerated (the test returns early). Any other
        RuntimeError propagates.
        """
        pg = self.pg
        # Use the rank -> device mapping like the other tests in this class,
        # rather than assuming device index == rank.
        local_device_id = self.rank_to_GPU[self.rank][0]

        indices = torch.tensor([[0, 1]])
        values = torch.tensor([[1, 2, 0], [4, 0, 6]])
        sparse_tensor = torch.sparse_coo_tensor(indices, values, size=(2, 3)).to(
            local_device_id
        )

        tensor_list = [sparse_tensor]
        # Only the collective is guarded: the sparse c10d allreduce API is only
        # available in the nccl experimental branch.
        try:
            pg.allreduce(tensor_list).wait()
        except RuntimeError as e:
            if "NCCL does not support all_reduce with sparse tensors" not in str(e):
                # Rethrow the exception if it's a different error
                raise
            return

        # tensor_list is a list of size 1, with the allreduce output as a dense
        # tensor. Every rank contributes the same values (whose dense form is
        # ``values``), so the sum is ``values * world_size``.
        expected = (values * self.world_size).to(local_device_id)
        self.assertEqual(tensor_list[0], expected)
160

161
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allreduce_ops(self):
        """Check allreduce for every ReduceOp supported by NCCL.

        Rank ``r`` contributes ``r + 1``. With ``n`` ranks the expected results are:
        n(n+1)/2 (SUM), (n+1)/2 (AVG), n! (PRODUCT), 1 (MIN) and n (MAX).
        The bitwise ops BAND/BOR/BXOR must raise ValueError.
        """
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]
        ndev = self.world_size
        nccl_version = torch.cuda.nccl.version()

        def allreduce(tensors, op):
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            work = pg.allreduce(tensors, opts)
            work.wait()

        # Sum
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]
        allreduce(tensors, c10d.ReduceOp.SUM)
        self.assertEqual(
            torch.tensor([ndev * (ndev + 1) // 2]),
            tensors[0],
        )

        # Avg (only available for NCCL 2.10+)
        if nccl_version >= (2, 10, 0):
            tensors = [torch.tensor([self.rank + 1.0]).cuda(local_device_id)]
            allreduce(tensors, c10d.ReduceOp.AVG)
            self.assertEqual(
                torch.tensor([ndev * (ndev + 1.0) / (2.0 * ndev)]),
                tensors[0],
            )

        # Premul Sum (only available for NCCL 2.11.1+). The factor is given
        # either as a Python scalar or as a one-element device tensor.
        if nccl_version >= (2, 11, 1):
            for dtype in torch.half, torch.float, torch.double:
                for factor in (
                    3.0,
                    torch.tensor([5.0], device=local_device_id, dtype=dtype),
                ):
                    tensors = [
                        torch.tensor([self.rank + 1])
                        .cuda(local_device_id)
                        .to(dtype=dtype)
                    ]

                    allreduce(tensors, c10d._make_nccl_premul_sum(factor))

                    self.assertEqual(
                        factor
                        * torch.tensor(
                            [ndev * (ndev + 1) / 2],
                            dtype=dtype,
                            device=local_device_id,
                        ),
                        tensors[0],
                    )

        # Product
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]
        allreduce(tensors, c10d.ReduceOp.PRODUCT)
        self.assertEqual(torch.tensor([math.factorial(ndev)]), tensors[0])

        # Min
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]
        allreduce(tensors, c10d.ReduceOp.MIN)
        self.assertEqual(torch.tensor([1]), tensors[0])

        # Max
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]
        allreduce(tensors, c10d.ReduceOp.MAX)
        self.assertEqual(torch.tensor([ndev]), tensors[0])

        # Bitwise reductions are rejected for NCCL.
        for op, err in zip(
            (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR),
            ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"),
        ):
            with self.assertRaisesRegex(ValueError, "Cannot use " + err + " with NCCL"):
                allreduce(tensors, op)
245

246
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_alltoall_ops_with_cudafree_race(self):
        """Race ``alltoall_base`` against ``cudaFree`` triggered by ``empty_cache``.

        For each of 10 rounds, a batch of tensors is dropped, the collective is
        issued, and the caching allocator is emptied before ``wait()``.
        """
        pg = self.pg
        opts = c10d.AllToAllOptions()
        local_device = f"cuda:{self.rank_to_GPU[self.rank][0]}"
        torch.cuda.set_device(local_device)
        src = torch.rand(1000, 1000, device=local_device)
        dst = torch.rand(1000, 1000, device=local_device)

        # create some tensors to race with alltoall collective:
        # 10 batches, each holding tensors of 10**3 .. 10**7 elements.
        race_tensors = [
            [torch.rand(10 ** (3 + k), device=local_device) for k in range(5)]
            for _ in range(10)
        ]

        for _ in range(10):
            race_tensors.pop()
            work = pg.alltoall_base(dst, src, [], [], opts)
            # this triggers cudaFree
            torch.cuda.empty_cache()
            work.wait()
        torch.cuda.synchronize(device=local_device)
270

271
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allreduce_in_cudagraph(self):
        """An allreduce captured in a CUDA graph re-executes on every replay.

        Every rank starts from 1, so the eager default (sum) allreduce yields
        ``world_size``. Capturing must not change the value. Each replay
        multiplies it by ``world_size`` again. The expectations are derived
        from ``world_size`` instead of hard-coding the 2-rank values (2 and 8).
        """
        pg = self.pg
        local_device_idx = self.rank_to_GPU[self.rank][0]
        ws = self.world_size
        with torch.cuda.device(local_device_idx):
            xs = [torch.FloatTensor([1]).cuda(local_device_idx)]

            # single warmup
            pg.allreduce(xs).wait()
            self.assertEqual(xs[0].item(), ws)

            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                pg.allreduce(xs).wait()
            # Capture alone must leave the value unchanged.
            self.assertEqual(xs[0].item(), ws)

            graph.replay()
            graph.replay()
            # ws -> ws**2 -> ws**3 (== 8 for two ranks).
            self.assertEqual(xs[0].item(), ws**3)
291

292
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @skipIfRocm()
    def test_nccl_watchdog_cudagraph(self):
        """The NCCL watchdog must not crash CUDA graph capture/replay.

        test that the watchdog does not crash graphs with disallowed event query.
        Values are not checked; the test passes if capture and replay complete.
        """
        pg = self.pg
        # This is a CUDA device index (from rank_to_GPU), not a rank.
        device_idx = self.rank_to_GPU[self.rank][0]
        with torch.cuda.device(device_idx):
            for _ in range(10):
                xs = [torch.FloatTensor([1]).cuda(device_idx)]
                # Eager collectives issued before capture begins.
                for _ in range(30):
                    pg.allreduce(xs[0]).wait()

                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph):
                    xs[0] += 0.0
                    pg.allreduce(xs[0]).wait()
                    pg.allreduce(xs[0]).wait()
                    pg.allreduce(xs[0]).wait()
                    xs[0] += 0.0

                for _ in range(100):
                    graph.replay()
316

317
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_ops(self):
        """Reduce to each root rank in turn.

        For each root, checks the sum result, rejection of bitwise ops, and
        that premul-sum matches a reduce of pre-scaled inputs.
        """
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def reduce(xs, rootRank, rootTensor, op=None):
            opts = c10d.ReduceOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            # Compare with None explicitly so no ReduceOp value is dropped
            # because of its truthiness.
            if op is not None:
                opts.reduceOp = op
            work = pg.reduce(xs, opts)
            work.wait()

        # Each rank takes a turn as the reduce root.
        for rt in range(self.world_size):
            tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

            reduce(tensors, rt, 0)

            if self.rank == rt:
                self.assertEqual(
                    torch.tensor([self.world_size * (self.world_size + 1) // 2]),
                    tensors[0],
                )
            else:
                # Non-root inputs are left untouched.
                self.assertEqual(
                    torch.tensor([self.rank + 1]),
                    tensors[0],
                )

            for op, err in zip(
                (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR),
                ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"),
            ):
                with self.assertRaisesRegex(
                    ValueError, "Cannot use " + err + " with NCCL"
                ):
                    # Same (rootRank, rootTensor) convention as the other calls
                    # here. Previously these two arguments were swapped.
                    reduce(tensors, rt, 0, op)

            # Premul sum (NCCL 2.11.1+): compare against the default reduce of
            # inputs already multiplied by the factor.
            if torch.cuda.nccl.version() >= (2, 11, 1):
                for factor in (3.0, torch.tensor([5.0], device=local_device_id)):
                    if isinstance(factor, torch.Tensor):
                        factor_ref = factor.cpu().item()
                    else:
                        factor_ref = factor
                    float_tensors = [
                        torch.tensor(
                            [self.rank + 1.0], device=f"cuda:{local_device_id}"
                        )
                    ]
                    float_tensors_ref = [
                        torch.tensor(
                            [(self.rank + 1.0) * factor_ref],
                            device=f"cuda:{local_device_id}",
                        )
                    ]

                    reduce(float_tensors_ref, rt, 0)
                    reduce(float_tensors, rt, 0, c10d._make_nccl_premul_sum(factor))
                    if self.rank == rt:
                        self.assertEqual(float_tensors_ref[0], float_tensors[0])
381

382
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allgather_ops(self):
        """Allgather 2x2 tensors of 2s from every local GPU of every rank."""
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        def allgather(output_ts, input_ts):
            work = pg.allgather(output_ts, input_ts)
            return work.wait()

        tensors = [torch.empty(2, 2).fill_(2).cuda(device=i) for i in local_device_ids]
        output_tensors = []
        expected_output = []

        # One slot per (rank, local GPU) pair. The CPU templates are aliased in
        # the list, but ``.cuda()`` below copies each one to a distinct tensor.
        slots = len(local_device_ids) * self.world_size
        output_per_gpu = [torch.empty(2, 2).fill_(-1)] * slots
        expected_per_gpu = [torch.empty(2, 2).fill_(2)] * slots

        for gpu in local_device_ids:
            output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
            expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu])

        allgather(output_tensors, tensors)

        # Verification
        self.assertEqual(output_tensors, expected_output)
411

412
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allgather_base_ops(self):
        """``_allgather_base`` places rank r's scalar at index r of a flat output."""
        pg = self.pg
        device = self.rank_to_GPU[self.rank][0]

        # allgather_base is GPU number agnostic.
        # Each rank contributes one tensor regardless of GPU counts.
        contribution = torch.tensor([self.rank]).cuda(device)
        gathered = torch.empty((self.world_size), dtype=contribution.dtype).cuda(
            device
        )

        pg._allgather_base(gathered, contribution).wait()

        # Verification
        self.assertEqual(torch.arange(self.world_size), gathered)
433

434
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allgather_base_basics(self):
        """``_allgather_base`` rejects a mis-sized output and mismatched dtypes."""
        pg = self.pg
        device = self.rank_to_GPU[self.rank][0]
        bad_numel = self.world_size + 1

        def allgather_base(output_t, input_t):
            pg._allgather_base(output_t, input_t).wait()

        # Output has world_size + 1 elements instead of world_size * input size.
        with self.assertRaisesRegex(
            ValueError,
            "output tensor size must be equal to world_size times input tensor size",
        ):
            src = torch.tensor([self.rank]).cuda(device)
            dst = torch.empty((bad_numel), dtype=src.dtype).cuda(device)
            allgather_base(dst, src)

        # float input vs. long output: the dtype check must fire.
        with self.assertRaisesRegex(
            TypeError, "output tensor must have the same type as input tensor"
        ):
            src = torch.tensor([self.rank], dtype=torch.float).cuda(device)
            dst = torch.empty((bad_numel), dtype=torch.long).cuda(device)
            allgather_base(dst, src)
466

467
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_gather_ops(self):
        """Gather every rank's id to each root in turn; only the root checks."""
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        def gather(output_t, input_t, rootRank):
            opts = c10d.GatherOptions()
            opts.rootRank = rootRank
            # Only the root passes output buffers; everyone else passes [].
            outputs = output_t if rootRank == self.rank else []
            pg.gather(outputs, input_t, opts).wait()

        # init input: this rank's id on every local GPU
        tensors = [torch.tensor([self.rank]).cuda(dev) for dev in local_device_ids]

        # init output: per local GPU, one -1 placeholder per rank
        output_ts = [
            [torch.tensor([-1]).cuda(dev) for _ in range(self.world_size)]
            for dev in local_device_ids
        ]

        expected = [[torch.tensor([r]) for r in range(self.world_size)]]
        for root in range(self.world_size):
            gather(output_ts, tensors, root)
            if root == self.rank:
                self.assertEqual(expected, output_ts)
501

502
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_gather_stress(self):
        """Repeat the gather-to-every-root check over 1000 independent buffer sets."""
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        def gather(output_t, input_t, rootRank):
            opts = c10d.GatherOptions()
            opts.rootRank = rootRank
            # Only the root passes output buffers; everyone else passes [].
            outputs = output_t if rootRank == self.rank else []
            pg.gather(outputs, input_t, opts).wait()

        stress_length = 1000

        # init input: one per-GPU input list per iteration
        tensors = [
            [torch.tensor([self.rank]).cuda(dev) for dev in local_device_ids]
            for _ in range(stress_length)
        ]

        # init output: per iteration, per local GPU, one -1 slot per rank
        output_ts = [
            [
                [torch.tensor([-1]).cuda(dev) for _ in range(self.world_size)]
                for dev in local_device_ids
            ]
            for _ in range(stress_length)
        ]

        expected = [[torch.tensor([r]) for r in range(self.world_size)]]
        for outs, ins in zip(output_ts, tensors):
            for root in range(self.world_size):
                gather(outs, ins, root)
                # Verification
                if root == self.rank:
                    self.assertEqual(outs, expected)
543

544
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_gather_checks(self):
        """Invalid ``gather`` arguments raise the expected exceptions."""
        pg = self.pg
        device_id = self.rank_to_GPU[self.rank][0]

        # init input
        tensor = torch.tensor([self.rank]).cuda(device_id)

        # init output
        output_ts = [torch.tensor([-1]).cuda(device_id) for _ in range(self.world_size)]

        def gather_opts(root):
            opts = c10d.GatherOptions()
            opts.rootRank = root
            return opts

        # Negative root rank.
        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            pg.gather([output_ts], [tensor], gather_opts(-1))

        # A bare int is not accepted in place of GatherOptions.
        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
            pg.gather([output_ts], [tensor], 0)

        # Root rank one past the last valid rank.
        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            pg.gather([output_ts], [tensor], gather_opts(self.world_size))

        with self.assertRaisesRegex(
            # throws error message from dispatcher
            RuntimeError,
            "There were no tensor arguments to this function",
        ):
            pg.gather([output_ts], [], gather_opts(0))
579

580
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_scatter_ops(self):
        """Scatter [0, ..., world_size-1] from each root; rank r receives r."""
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        def scatter(output_t, input_t, rootRank):
            opts = c10d.ScatterOptions()
            opts.rootRank = rootRank
            # Only the root passes the scatter list; everyone else passes [].
            inputs = input_t if rootRank == self.rank else []
            pg.scatter(output_t, inputs, opts).wait()

        # init output: a -1 placeholder per local GPU
        tensors = [torch.tensor([-1]).cuda(dev) for dev in local_device_ids]

        # init input: per local GPU, one tensor per destination rank
        scatter_list = [
            [torch.tensor([r]).cuda(dev) for r in range(self.world_size)]
            for dev in local_device_ids
        ]

        # test each rank to scatter
        expected = [torch.tensor([self.rank])]
        for root in range(self.world_size):
            scatter(tensors, scatter_list, root)
            self.assertEqual(expected, tensors)
614

615
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_scatter_stress(self):
        """Repeat the scatter-from-every-root check over 1000 buffer sets."""
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        def scatter(output_t, input_t, rootRank):
            opts = c10d.ScatterOptions()
            opts.rootRank = rootRank
            # Only the root passes the scatter list; everyone else passes [].
            inputs = input_t if rootRank == self.rank else []
            pg.scatter(output_t, inputs, opts).wait()

        stress_length = 1000

        # init output: one per-GPU output list per iteration
        tensors = [
            [torch.tensor([-1]).cuda(dev) for dev in local_device_ids]
            for _ in range(stress_length)
        ]

        # init input: per iteration, per local GPU, one tensor per rank
        scatter_list = [
            [
                [torch.tensor([r]).cuda(dev) for r in range(self.world_size)]
                for dev in local_device_ids
            ]
            for _ in range(stress_length)
        ]

        # test each rank to scatter
        expected = [torch.tensor([self.rank])]
        for outs, ins in zip(tensors, scatter_list):
            for root in range(self.world_size):
                scatter(outs, ins, root)
                # Verification
                self.assertEqual(outs, expected)
656

657
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_scatter_checks(self):
        """Invalid ``scatter`` arguments raise the expected exceptions."""
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        # init output
        tensors = [torch.tensor([-1]).cuda(dev) for dev in local_device_ids]

        # init input
        scatter_list = [
            [torch.tensor([r]).cuda(dev) for r in range(self.world_size)]
            for dev in local_device_ids
        ]

        def scatter_opts(root):
            opts = c10d.ScatterOptions()
            opts.rootRank = root
            return opts

        # Negative root rank.
        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            pg.scatter(tensors, scatter_list, scatter_opts(-1))

        # A bare int is not accepted in place of ScatterOptions.
        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
            pg.scatter(tensors, scatter_list, 0)

        # Root rank one past the last valid rank.
        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            pg.scatter(tensors, scatter_list, scatter_opts(self.world_size))

        with self.assertRaisesRegex(
            # throws error message from dispatcher
            RuntimeError,
            "There were no tensor arguments to this function",
        ):
            pg.scatter([], scatter_list, scatter_opts(0))
698

699
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_scatter_base_basics(self):
        """``_reduce_scatter_base`` rejects a mis-sized output and mismatched dtypes."""
        pg = self.pg
        device = self.rank_to_GPU[self.rank][0]
        bad_numel = self.world_size + 1

        def reduce_scatter_base(output_t, input_t):
            pg._reduce_scatter_base(output_t, input_t).wait()

        # One-element input vs. a (world_size + 1)-element output.
        with self.assertRaisesRegex(
            ValueError,
            "input tensor must be the same size as output size times world size",
        ):
            src = torch.tensor([self.rank]).cuda(device)
            dst = torch.empty((bad_numel), dtype=src.dtype).cuda(device)
            reduce_scatter_base(dst, src)

        # float input vs. long output: the dtype check must fire.
        with self.assertRaisesRegex(
            TypeError, "input tensor must be the same type as the output tensor."
        ):
            src = torch.tensor([self.rank], dtype=torch.float).cuda(device)
            dst = torch.empty((bad_numel), dtype=torch.long).cuda(device)
            reduce_scatter_base(dst, src)
731

732
    @requires_nccl()
733
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
734
    def test_reduce_scatter_ops(self):
735
        pg = self.pg
736
        local_device_ids = self.rank_to_GPU[self.rank]
737
        num_gpus = len(local_device_ids)
738

739
        def reduce_scatter(outputs, input_lists, op):
740
            opts = c10d.ReduceScatterOptions()
741
            opts.reduceOp = op
742
            work = pg.reduce_scatter(outputs, input_lists, opts)
743
            work.wait()
744

745
        output = [torch.tensor([0]).cuda(i) for i in local_device_ids]
746

747
        #  GPU/rank
748
        #   0         [1], [2], [3], [4]
749
        #   1         [2], [3], [4], [5]
750
        #   2         [3], [4], [5], [6]
751
        #   3         [4], [5], [6], [7]
752

753
        # Sum
754
        tensor_lists = []
755
        input_per_gpu = []
756

757
        for i in range(self.world_size):
758
            input_per_gpu.append(torch.tensor([self.rank + i + 1]))
759

760
        for gpu in local_device_ids:
761
            tensor_lists.append([t.cuda(device=gpu) for t in input_per_gpu])
762

763
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.SUM)
764

765
        for i in range(num_gpus):
766
            expected = torch.tensor(
767
                [
768
                    (1 + self.world_size) * self.world_size // 2
769
                    + self.world_size * self.rank
770
                ]
771
            )
772

773
            self.assertEqual(expected, output[i])
774

775
        # Min
776
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.MIN)
777

778
        for i in range(num_gpus):
779
            expected = torch.tensor([self.rank + 1 + i])
780
            self.assertEqual(expected, output[i])
781

782
        # Max
783
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.MAX)
784

785
        for i in range(num_gpus):
786
            expected = torch.tensor([self.rank + self.world_size + i])
787
            self.assertEqual(expected, output[i])
788

789
        # Product
790
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.PRODUCT)
791

792
        # math package don't have math.perm until python 3.8, so
793
        # we implement a naive version here.
794
        def perm(n, k):
795
            prod_val = n
796
            for val in range(n - k + 1, n):
797
                prod_val *= val
798
            return prod_val
799

800
        for i in range(num_gpus):
801
            prod_val = perm(self.rank + self.world_size, self.world_size)
802

803
            expected = torch.tensor([prod_val])
804
            self.assertEqual(expected, output[i])
805

806
        # Test the input params overridden scenarios, aka, when the input is
807
        # a list and output is just one tensor.
808
        # Sum
809
        output_tensor = torch.empty_like(input_per_gpu[0][0]).cuda(self.rank)
810
        input_list = [tensor[0].cuda(self.rank) for tensor in input_per_gpu]
811
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait()
812
        expected = torch.tensor(
813
            (1 + self.world_size) * self.world_size // 2 + self.world_size * self.rank
814
        )
815
        self.assertEqual(expected, output_tensor)
816

817
        # Min
818
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait()
819
        expected = torch.tensor(self.rank + 1)
820
        self.assertEqual(expected, output_tensor)
821

822
        # Max
823
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait()
824
        expected = torch.tensor(self.rank + self.world_size)
825
        self.assertEqual(expected, output_tensor)
826

827
        # Product
828
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.PRODUCT).wait()
829
        prod_val = self.rank + 1
830
        for k in range(1, self.world_size):
831
            prod_val = prod_val * (self.rank + 1 + k)
832
        expected = torch.tensor(prod_val)
833
        self.assertEqual(expected, output_tensor)
834

835
        if torch.cuda.nccl.version() >= (2, 11, 1):
836
            for factor in (3.0, torch.tensor([5.0], device=self.rank)):
837
                if isinstance(factor, torch.Tensor):
838
                    factor_ref = factor.cpu().item()
839
                else:
840
                    factor_ref = factor
841
                output = [t.float() for t in output]
842
                tensor_lists = [[t.float() for t in tl] for tl in tensor_lists]
843
                output_ref = [t.float() for t in output]
844
                tensor_lists_ref = [
845
                    [t.float() * factor_ref for t in tl] for tl in tensor_lists
846
                ]
847
                reduce_scatter(output, tensor_lists, c10d._make_nccl_premul_sum(factor))
848
                reduce_scatter(output_ref, tensor_lists_ref, c10d.ReduceOp.SUM)
849
                self.assertEqual(output_ref, output)
850

851
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_scatter_base_ops(self):
        """Check ``_reduce_scatter_base`` with one flat input tensor per rank.

        Every rank contributes ``arange(world_size)``. Rank ``r`` is expected
        to end up with ``r * world_size`` in its single-element output, i.e.
        element ``r`` summed over all ranks.
        """
        pg = self.pg
        device = self.rank_to_GPU[self.rank][0]

        # The flat ("base") variant does not depend on how many GPUs a rank
        # owns: each rank supplies exactly one input tensor.
        out = torch.empty([1]).cuda(device)
        inp = torch.arange(self.world_size, dtype=out.dtype).cuda(device)

        # Issue the collective and block until it completes.
        pg._reduce_scatter_base(out, inp).wait()

        self.assertEqual(out[0], self.rank * self.world_size)
872

873
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_barrier(self):
        """Check that ``barrier`` covers allreduces issued before it.

        The test launches one allreduce per prefix of the local devices
        (1 GPU, 2 GPUs, ..., all local GPUs) without waiting on any of them.
        After the barrier, every tensor must already hold its summed value.
        """
        pg = self.pg
        device_ids = self.rank_to_GPU[self.rank]
        n_devices = len(device_ids)

        # Group k spans device_ids[0..k]. The tensor placed on device_ids[j]
        # starts out holding j + 1.
        groups = [
            [torch.tensor([j + 1]).cuda(device_ids[j]) for j in range(count)]
            for count in range(1, n_devices + 1)
        ]

        # Launch every allreduce up front. The handles are kept but
        # intentionally not waited on; the barrier below is what is tested.
        works = [pg.allreduce(group, c10d.AllreduceOptions()) for group in groups]

        pg.barrier().wait()

        # After a SUM allreduce across world_size ranks, value v becomes
        # v * world_size.
        for group in groups:
            for j, tensor in enumerate(group):
                self.assertEqual(torch.tensor([(j + 1) * self.world_size]), tensor)
907

908
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_send_recv(self):
        """Send a random float tensor from rank 0 to rank 1 and verify it.

        Communication goes through ``dist.send``/``dist.recv`` with no explicit
        group argument. Ranks other than 0 and 1 do nothing.
        """
        device = self.rank_to_GPU[self.rank][0]

        # Seeding identically on every rank makes ``send_tensor`` the same on
        # rank 0 and rank 1, so rank 1 can compare against its local copy.
        torch.manual_seed(0)
        send_tensor = torch.rand(10, 10, device=device)
        if self.rank == 0:
            dist.send(send_tensor, 1)
        if self.rank == 1:
            # The RNG has advanced, so this buffer starts out different from
            # send_tensor. A recv that wrote nothing would fail the check.
            recv_tensor = torch.rand(10, 10, device=device)
            dist.recv(recv_tensor, 0)
            self.assertEqual(send_tensor, recv_tensor)
923

924
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_send_recv_complex(self):
        """Send a random ``cfloat`` tensor from rank 0 to rank 1 and verify it.

        This is the complex-dtype counterpart of the float send/recv check.
        Ranks other than 0 and 1 do nothing.
        """
        device = self.rank_to_GPU[self.rank][0]

        # Identical seeding on every rank yields the same ``send_tensor`` on
        # both peers, so rank 1 can compare against its local copy.
        torch.manual_seed(0)
        send_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device)
        if self.rank == 0:
            dist.send(send_tensor, 1)
        if self.rank == 1:
            # The RNG has advanced, so this buffer starts out different from
            # send_tensor. A recv that wrote nothing would fail the check.
            recv_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device)
            dist.recv(recv_tensor, 0)
            self.assertEqual(send_tensor, recv_tensor)
939

940
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_send_recv_object_list(self):
        """Send a list of Python objects from rank 0 to rank 1 over NCCL."""
        device = self.rank_to_GPU[self.rank][0]

        # Only the sender fills the list with 99. Every other rank holds a
        # same-length list of None, which rank 1 uses as its receive buffer.
        object_list = [99 if self.rank == 0 else None] * self.world_size

        if self.rank == 0:
            dist.send_object_list(object_list, 1, device=device)
        elif self.rank == 1:
            dist.recv_object_list(object_list, 0, device=device)
            self.assertEqual(object_list[0], 99)
952

953
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_tensor_register_hook(self):
        """Run ``_allgather_base`` with the tensor-register allocator hook env var set.

        The environment variable is scoped to this test. Its previous value
        (or its absence) is restored in ``finally``, so a failing assertion
        cannot leak the setting into tests that run afterwards. It also cannot
        clobber a value the caller had set.
        """
        env_key = "TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"
        prev_value = os.environ.get(env_key)
        os.environ[env_key] = "1"
        try:
            pg = self.pg
            local_device_id = self.rank_to_GPU[self.rank][0]

            # allgather_base does not depend on the GPU count: each rank
            # contributes a single tensor holding its own rank.
            tensor = torch.tensor([self.rank]).cuda(local_device_id)
            output_t = torch.empty((self.world_size,), dtype=tensor.dtype).cuda(
                local_device_id
            )

            pg._allgather_base(output_t, tensor).wait()

            # Gathering each rank's id in rank order yields 0..world_size-1.
            self.assertEqual(torch.arange(self.world_size), output_t)
        finally:
            # Restore the environment exactly as it was before the test.
            if prev_value is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = prev_value
979

980

981
if __name__ == "__main__":
    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", 2))

    if rank != -1:
        # Launched with torchrun or another multi-process launcher: this
        # process already is one rank, so run the tests directly.
        ProcessGroupNCCLOpTest.run_rank(rank, world_size)
    else:
        # Launched as a single process: spawn one subprocess per rank. The
        # ranks rendezvous through a shared file for `init_process_group`.
        # Only the path is needed, so close the handle immediately.
        with tempfile.NamedTemporaryFile(delete=False) as rdvz:
            rdvz_file = rdvz.name
        try:
            torch.multiprocessing.spawn(
                ProcessGroupNCCLOpTest.run_rank,
                nprocs=world_size,
                args=(world_size, rdvz_file),
            )
        finally:
            # delete=False means the file is not removed automatically, so
            # clean it up here. It may already be gone (e.g. removed by the
            # rendezvous store), which is fine.
            try:
                os.remove(rdvz_file)
            except FileNotFoundError:
                pass
997

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.