# Owner(s): ["oncall: distributed"]
# This test file contains positive tests for c10d with the NCCL backend.
# During the tests, it is expected that the ProcessGroup will not be aborted, destroyed, or incur a fatal error.
# Please be mindful of this when adding tests here.
# If you need to add tests for group creation, abort, or destroy, please add them to test_c10d_nccl.py.

# There are two ways to launch the tests in this file:
# 1. Run this file directly with `python test_c10d_ops_nccl.py`
# 2. Use a multi-process launcher, e.g. `torchrun --standalone --nproc-per-node 2 test_c10d_ops_nccl.py`
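#
# Under option 2, the launcher sets the RANK and WORLD_SIZE environment
# variables; the `__main__` block at the bottom of this file reads them to
# decide between running in-process and spawning its own workers.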

import math
import os
import sys
import tempfile

import torch
import torch.distributed as c10d


if not c10d.is_available() or not c10d.is_nccl_available():
    print("c10d NCCL not available, skipping tests", file=sys.stderr)
    sys.exit(0)


import torch.distributed as dist
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    init_multigpu_helper,
    MultiProcContinousTest,
    requires_nccl,
)
from torch.testing._internal.common_utils import (
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
    TEST_WITH_DEV_DBG_ASAN,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr
    )
    sys.exit(0)

class ProcessGroupNCCLOpTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        return "nccl"

    @classmethod
    def opts(cls, high_priority_stream=False):
        opts = c10d.ProcessGroupNCCL.Options()
        opts.is_high_priority_stream = high_priority_stream
        return opts

    @property
    def rank_to_GPU(self):
        # return a map from rank to the list of GPU indices it can use
        return init_multigpu_helper(self.world_size, "nccl")

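    # MultiProcContinousTest keeps a single process group alive across all the
    # tests in this class, which is why none of them may abort or destroy it
    # (see the note at the top of this file).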
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_empty_tensors(self):
        pg = self.pg
        local_device_idx = self.rank_to_GPU[self.rank][0]

        xs = [torch.FloatTensor([]).cuda(local_device_idx)]
        pg.broadcast(xs).wait()
        self.assertEqual(0, xs[0].numel())

        pg.allreduce(xs).wait()
        self.assertEqual(0, xs[0].numel())

        pg.reduce(xs).wait()
        self.assertEqual(0, xs[0].numel())

        ys = [
            [
                torch.FloatTensor([]).cuda(local_device_idx)
                for _ in range(self.world_size)
            ]
        ]
        pg.allgather(ys, xs).wait()
        for y in ys[0]:
            self.assertEqual(0, y.numel())

        ys = [torch.FloatTensor([]).cuda(local_device_idx)]
        xs = [
            [
                torch.FloatTensor([]).cuda(local_device_idx)
                for _ in range(self.world_size)
            ]
        ]
        pg.reduce_scatter(ys, xs).wait()
        self.assertEqual(0, ys[0].numel())

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_broadcast_ops(self):
        pg = self.pg

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            work = pg.broadcast(xs, opts)
            work.wait()
            return xs

        # Every rank is root once
        for i in range(self.world_size):
            # Run with 1 input tensor
            x = torch.tensor([self.rank]).cuda(self.rank_to_GPU[self.rank][0])
            output = broadcast([x], i, 0)
            self.assertEqual(torch.tensor([i]), output[0])

            expected_tensor = torch.empty([i + 1, i + 1]).fill_(i + 1)
            xs = [
                torch.empty([i + 1, i + 1]).fill_(-1).cuda(device=device_idx)
                for device_idx in self.rank_to_GPU[self.rank]
            ]

            # test with multiple input tensors (multiple GPUs in one rank)
            for j in range(len(xs)):
                if self.rank == i:
                    xs[j] = expected_tensor.cuda(device=self.rank_to_GPU[self.rank][j])

                broadcast(xs, i, j)

                for tensor in xs:
                    self.assertEqual(tensor, expected_tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_sparse_allreduce_ops(self):
        pg = self.pg

        indices = torch.tensor([[0, 1]])
        values = torch.tensor([[1, 2, 0], [4, 0, 6]])
        sparse_tensor = torch.sparse_coo_tensor(indices, values, size=(2, 3)).to(
            self.rank
        )

        # the sparse allreduce call is wrapped in a try/except since the c10d
        # API is only available in the NCCL experimental branch
        try:
            tensor_list = [sparse_tensor]
            work = pg.allreduce(tensor_list)
            work.wait()

            # tensor_list is a list of size 1, with the allreduce output as a dense tensor
            a = torch.tensor([[2, 4, 0], [8, 0, 12]]).to(self.rank)
            self.assertEqual(tensor_list[0], a)
        except RuntimeError as e:
            if "NCCL does not support all_reduce with sparse tensors" in str(e):
                pass
            else:
                # Rethrow the exception if it's a different error
                raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allreduce_ops(self):
        device_count = torch.cuda.device_count()
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def allreduce(tensors, op):
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            work = pg.allreduce(tensors, opts)
            work.wait()

        # Sum
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

        allreduce(tensors, c10d.ReduceOp.SUM)

        # every rank contributes rank + 1, so the result is 1 + 2 + ... + world_size
        ndev = self.world_size
        self.assertEqual(
            torch.tensor([ndev * (ndev + 1) // 2]),
            tensors[0],
        )

        # Avg (only available for NCCL 2.10+)
        if torch.cuda.nccl.version() >= (2, 10, 0):
            tensors = [torch.tensor([self.rank + 1.0]).cuda(local_device_id)]

            allreduce(tensors, c10d.ReduceOp.AVG)
            ndev = self.world_size
            self.assertEqual(
                torch.tensor([ndev * (ndev + 1.0) / (2.0 * ndev)]),
                tensors[0],
            )

        # Premul Sum
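        # _make_nccl_premul_sum builds a ReduceOp that multiplies each rank's
        # contribution by `factor` before summing (NCCL's PreMulSum), so the
        # expected value is factor * (1 + 2 + ... + world_size).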
        if torch.cuda.nccl.version() >= (2, 11, 1):
            for dtype in torch.half, torch.float, torch.double:
                for factor in (
                    3.0,
                    torch.tensor([5.0], device=local_device_id, dtype=dtype),
                ):
                    tensors = [
                        torch.tensor([self.rank + 1])
                        .cuda(local_device_id)
                        .to(dtype=dtype)
                    ]

                    allreduce(tensors, c10d._make_nccl_premul_sum(factor))

                    self.assertEqual(
                        factor
                        * torch.tensor(
                            [self.world_size * (self.world_size + 1) / 2],
                            dtype=dtype,
                            device=local_device_id,
                        ),
                        tensors[0],
                    )

        # Product
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

        allreduce(tensors, c10d.ReduceOp.PRODUCT)
        self.assertEqual(torch.tensor([math.factorial(self.world_size)]), tensors[0])

        # Min
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

        allreduce(tensors, c10d.ReduceOp.MIN)
        self.assertEqual(torch.tensor([1]), tensors[0])

        # Max
        tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

        allreduce(tensors, c10d.ReduceOp.MAX)
        self.assertEqual(torch.tensor([self.world_size]), tensors[0])

        for op, err in zip(
            (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR),
            ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"),
        ):
            with self.assertRaisesRegex(ValueError, "Cannot use " + err + " with NCCL"):
                allreduce(tensors, op)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_alltoall_ops_with_cudafree_race(self):
        pg = self.pg
        opts = c10d.AllToAllOptions()
        local_device = f"cuda:{self.rank_to_GPU[self.rank][0]}"
        torch.cuda.set_device(local_device)
        input = torch.rand(1000, 1000, device=local_device)
        output = torch.rand(1000, 1000, device=local_device)
        race_tensors = []
        # create some tensors to race with the alltoall collective
        for _ in range(10):
            tmp = []
            for i in range(5):
                tmp.append(torch.rand(10 ** (3 + i), device=local_device))
            race_tensors.append(tmp)

        for i in range(10):
            race_tensors.pop()
            work = pg.alltoall_base(output, input, [], [], opts)
            # this triggers cudaFree
            torch.cuda.empty_cache()
            work.wait()
        torch.cuda.synchronize(device=local_device)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allreduce_in_cudagraph(self):
        pg = self.pg
        local_device_idx = self.rank_to_GPU[self.rank][0]
        with torch.cuda.device(local_device_idx):
            xs = [torch.FloatTensor([1]).cuda(local_device_idx)]

            # single warmup
            pg.allreduce(xs).wait()
            self.assertEqual(xs[0].item(), 2)

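            # Kernel launches during graph capture are recorded rather than
            # executed, so xs still holds the warmed-up value right after
            # capture; each replay then performs one more allreduce.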
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                pg.allreduce(xs).wait()
            self.assertEqual(xs[0].item(), 2)

            graph.replay()
            graph.replay()
            self.assertEqual(xs[0].item(), 8)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @skipIfRocm()
    def test_nccl_watchdog_cudagraph(self):
        # test that the watchdog does not crash graphs with disallowed event query
        pg = self.pg
        rank = self.rank_to_GPU[self.rank][0]
        with torch.cuda.device(rank):
            for i in range(10):
                xs = [torch.FloatTensor([1]).cuda(rank)]
                ys = [torch.FloatTensor([4]).cuda(rank)]
                for _ in range(30):
                    pg.allreduce(xs[0]).wait()

                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph):
                    xs[0] += 0.0
                    pg.allreduce(xs[0]).wait()
                    pg.allreduce(xs[0]).wait()
                    pg.allreduce(xs[0]).wait()
                    xs[0] += 0.0

                for _ in range(100):
                    graph.replay()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_ops(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def reduce(xs, rootRank, rootTensor, op=None):
            opts = c10d.ReduceOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            if op:
                opts.reduceOp = op
            work = pg.reduce(xs, opts)
            work.wait()

        # for every root tensor
        for rt in range(self.world_size):
            tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)]

            reduce(tensors, rt, 0)

            if self.rank == rt:
                self.assertEqual(
                    torch.tensor([self.world_size * (self.world_size + 1) // 2]),
                    tensors[0],
                )
            else:
                self.assertEqual(
                    torch.tensor([self.rank + 1]),
                    tensors[0],
                )

            for op, err in zip(
                (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR),
                ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"),
            ):
                with self.assertRaisesRegex(
                    ValueError, "Cannot use " + err + " with NCCL"
                ):
                    reduce(tensors, self.rank, rt, op)

            # Premul sum
            if torch.cuda.nccl.version() >= (2, 11, 1):
                for factor in (3.0, torch.tensor([5.0], device=local_device_id)):
                    if isinstance(factor, torch.Tensor):
                        factor_ref = factor.cpu().item()
                    else:
                        factor_ref = factor
                    float_tensors = [
                        torch.tensor(
                            [self.rank + 1.0], device=f"cuda:{local_device_id}"
                        )
                    ]
                    float_tensors_ref = [
                        torch.tensor(
                            [(self.rank + 1.0) * factor_ref],
                            device=f"cuda:{local_device_id}",
                        )
                    ]

                    reduce(float_tensors_ref, rt, 0)
                    reduce(float_tensors, rt, 0, c10d._make_nccl_premul_sum(factor))
                    if self.rank == rt:
                        self.assertEqual(float_tensors_ref[0], float_tensors[0])

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allgather_ops(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        def allgather(output_ts, input_ts):
            work = pg.allgather(output_ts, input_ts)
            return work.wait()

        tensors = [torch.empty(2, 2).fill_(2).cuda(device=i) for i in local_device_ids]
        output_tensors = []
        expected_output = []

        output_per_gpu = (
            [torch.empty(2, 2).fill_(-1)] * len(local_device_ids) * self.world_size
        )
        expected_per_gpu = (
            [torch.empty(2, 2).fill_(2)] * len(local_device_ids) * self.world_size
        )

        for gpu in local_device_ids:
            output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
            expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu])

        result = allgather(output_tensors, tensors)

        # Verification
        self.assertEqual(output_tensors, expected_output)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allgather_base_ops(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def allgather_base(output_t, input_t):
            work = pg._allgather_base(output_t, input_t)
            work.wait()

        # allgather_base is GPU count agnostic.
        # Each rank contributes one tensor regardless of GPU counts
        tensor = torch.tensor([self.rank]).cuda(local_device_id)
        output_t = torch.empty((self.world_size), dtype=tensor.dtype).cuda(
            local_device_id
        )

        allgather_base(output_t, tensor)

        # Verification
        self.assertEqual(torch.arange(self.world_size), output_t)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_allgather_base_basics(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def allgather_base(output_t, input_t):
            work = pg._allgather_base(output_t, input_t)
            work.wait()

        # anticipate an error
        with self.assertRaisesRegex(
            ValueError,
            "output tensor size must be equal to world_size times input tensor size",
        ):
            tensor = torch.tensor([self.rank]).cuda(local_device_id)
            output_t = torch.empty((self.world_size + 1), dtype=tensor.dtype).cuda(
                local_device_id
            )
            # fails the check because output_t is not correctly sized
            allgather_base(output_t, tensor)

        # anticipate an error
        with self.assertRaisesRegex(
            TypeError, "output tensor must have the same type as input tensor"
        ):
            tensor = torch.tensor([self.rank], dtype=torch.float).cuda(local_device_id)
            output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda(
                local_device_id
            )
            # fails the check because the dtype is different
            allgather_base(output_t, tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_gather_ops(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def gather(output_t, input_t, rootRank):
            opts = c10d.GatherOptions()
            opts.rootRank = rootRank
            if rootRank == self.rank:
                work = pg.gather(output_t, input_t, opts)
            else:
                work = pg.gather([], input_t, opts)
            work.wait()

        # init input
        tensors = []
        for device_id in local_device_ids:
            tensors.append(torch.tensor([self.rank]).cuda(device_id))

        # init output
        output_ts = []
        for idx in range(num_gpus):
            gpu_idx = local_device_ids[idx]
            output_ts.append([])
            for rank in range(self.world_size):
                output_ts[idx].append(torch.tensor([-1]).cuda(gpu_idx))

        expected = [[torch.tensor([rank]) for rank in range(self.world_size)]]
        for rank in range(self.world_size):
            gather(output_ts, tensors, rank)
            if rank == self.rank:
                self.assertEqual(expected, output_ts)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_gather_stress(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def gather(output_t, input_t, rootRank):
            opts = c10d.GatherOptions()
            opts.rootRank = rootRank
            if rootRank == self.rank:
                work = pg.gather(output_t, input_t, opts)
            else:
                work = pg.gather([], input_t, opts)
            work.wait()

        stress_length = 1000

        # init input
        tensors = []
        for i in range(stress_length):
            tensors.append([])
            for device_id in local_device_ids:
                tensors[i].append(torch.tensor([self.rank]).cuda(device_id))

        # init output
        output_ts = []
        for i in range(stress_length):
            output_ts.append([[] for _ in range(num_gpus)])
            for idx, ls in enumerate(output_ts[i]):
                gpu_idx = local_device_ids[idx]
                for _ in range(self.world_size):
                    ls.append(torch.tensor([-1]).cuda(gpu_idx))

        expected = [[torch.tensor([rank]) for rank in range(self.world_size)]]
        for i in range(stress_length):
            for rank in range(self.world_size):
                gather(output_ts[i], tensors[i], rank)
                # Verification
                if rank == self.rank:
                    self.assertEqual(output_ts[i], expected)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_gather_checks(self):
        pg = self.pg
        device_id = self.rank_to_GPU[self.rank][0]

        # init input
        tensor = torch.tensor([self.rank]).cuda(device_id)

        # init output
        output_ts = []
        for rank in range(self.world_size):
            output_ts.append(torch.tensor([-1]).cuda(device_id))

        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            opts = c10d.GatherOptions()
            opts.rootRank = -1
            pg.gather([output_ts], [tensor], opts)

        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
            pg.gather([output_ts], [tensor], 0)

        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            opts = c10d.GatherOptions()
            opts.rootRank = self.world_size
            pg.gather([output_ts], [tensor], opts)

        with self.assertRaisesRegex(
            # throws error message from dispatcher
            RuntimeError,
            "There were no tensor arguments to this function",
        ):
            opts = c10d.GatherOptions()
            opts.rootRank = 0
            pg.gather([output_ts], [], opts)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_scatter_ops(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def scatter(output_t, input_t, rootRank):
            opts = c10d.ScatterOptions()
            opts.rootRank = rootRank
            if rootRank == self.rank:
                work = pg.scatter(output_t, input_t, opts)
            else:
                work = pg.scatter(output_t, [], opts)
            work.wait()

        # init output
        tensors = []
        for device_id in local_device_ids:
            tensors.append(torch.tensor([-1]).cuda(device_id))

        # init input
        scatter_list = []
        for idx in range(num_gpus):
            gpu_idx = local_device_ids[idx]
            scatter_list.append([])
            for rank in range(self.world_size):
                scatter_list[idx].append(torch.tensor([rank]).cuda(gpu_idx))

        # test each rank to scatter
        expected = [torch.tensor([self.rank])]
        for rank in range(self.world_size):
            scatter(tensors, scatter_list, rank)
            self.assertEqual(expected, tensors)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_scatter_stress(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def scatter(output_t, input_t, rootRank):
            opts = c10d.ScatterOptions()
            opts.rootRank = rootRank
            if rootRank == self.rank:
                work = pg.scatter(output_t, input_t, opts)
            else:
                work = pg.scatter(output_t, [], opts)
            work.wait()

        stress_length = 1000

        # init output
        tensors = []
        for i in range(stress_length):
            tensors.append([])
            for device_id in local_device_ids:
                tensors[i].append(torch.tensor([-1]).cuda(device_id))

        # init input
        scatter_list = []
        for i in range(stress_length):
            scatter_list.append([[] for _ in range(num_gpus)])
            for idx, ls in enumerate(scatter_list[i]):
                gpu_idx = local_device_ids[idx]
                for rank in range(self.world_size):
                    ls.append(torch.tensor([rank]).cuda(gpu_idx))

        # test each rank to scatter
        expected = [torch.tensor([self.rank])]
        for i in range(stress_length):
            for rank in range(self.world_size):
                scatter(tensors[i], scatter_list[i], rank)
                # Verification
                self.assertEqual(tensors[i], expected)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_scatter_checks(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        # init output
        tensors = []
        for device_id in local_device_ids:
            tensors.append(torch.tensor([-1]).cuda(device_id))

        # init input
        scatter_list = []
        for idx in range(num_gpus):
            gpu_idx = local_device_ids[idx]
            scatter_list.append([])
            for rank in range(self.world_size):
                scatter_list[idx].append(torch.tensor([rank]).cuda(gpu_idx))

        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            opts = c10d.ScatterOptions()
            opts.rootRank = -1
            pg.scatter(tensors, scatter_list, opts)

        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
            pg.scatter(tensors, scatter_list, 0)

        with self.assertRaisesRegex(ValueError, "invalid root rank"):
            opts = c10d.ScatterOptions()
            opts.rootRank = self.world_size
            pg.scatter(tensors, scatter_list, opts)

        with self.assertRaisesRegex(
            # throws error message from dispatcher
            RuntimeError,
            "There were no tensor arguments to this function",
        ):
            opts = c10d.ScatterOptions()
            opts.rootRank = 0
            pg.scatter([], scatter_list, opts)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_scatter_base_basics(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def reduce_scatter_base(output_t, input_t):
            work = pg._reduce_scatter_base(output_t, input_t)
            work.wait()

        # anticipate an error
        with self.assertRaisesRegex(
            ValueError,
            "input tensor must be the same size as output size times world size",
        ):
            input_t = torch.tensor([self.rank]).cuda(local_device_id)
            output_t = torch.empty((self.world_size + 1), dtype=input_t.dtype).cuda(
                local_device_id
            )
            # fails the check because output_t is not correctly sized
            reduce_scatter_base(output_t, input_t)

        # anticipate an error
        with self.assertRaisesRegex(
            TypeError, "input tensor must be the same type as the output tensor."
        ):
            tensor = torch.tensor([self.rank], dtype=torch.float).cuda(local_device_id)
            output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda(
                local_device_id
            )
            # fails the check because the dtype is different
            reduce_scatter_base(output_t, tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_scatter_ops(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]
        num_gpus = len(local_device_ids)

        def reduce_scatter(outputs, input_lists, op):
            opts = c10d.ReduceScatterOptions()
            opts.reduceOp = op
            work = pg.reduce_scatter(outputs, input_lists, opts)
            work.wait()

        output = [torch.tensor([0]).cuda(i) for i in local_device_ids]

        # Inputs per GPU/rank (example with world_size = 4):
        #   0    [1], [2], [3], [4]
        #   1    [2], [3], [4], [5]
        #   2    [3], [4], [5], [6]
        #   3    [4], [5], [6], [7]

        # Sum
        tensor_lists = []
        input_per_gpu = []

        for i in range(self.world_size):
            input_per_gpu.append(torch.tensor([self.rank + i + 1]))

        for gpu in local_device_ids:
            tensor_lists.append([t.cuda(device=gpu) for t in input_per_gpu])

        reduce_scatter(output, tensor_lists, c10d.ReduceOp.SUM)

        for i in range(num_gpus):
            expected = torch.tensor(
                [
                    (1 + self.world_size) * self.world_size // 2
                    + self.world_size * self.rank
                ]
            )

            self.assertEqual(expected, output[i])

        # Min
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.MIN)

        for i in range(num_gpus):
            expected = torch.tensor([self.rank + 1 + i])
            self.assertEqual(expected, output[i])

        # Max
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.MAX)

        for i in range(num_gpus):
            expected = torch.tensor([self.rank + self.world_size + i])
            self.assertEqual(expected, output[i])

        # Product
        reduce_scatter(output, tensor_lists, c10d.ReduceOp.PRODUCT)

        # The math package doesn't have math.perm until Python 3.8, so
        # we implement a naive version here: perm(n, k) = n! / (n - k)!
        def perm(n, k):
            prod_val = n
            for val in range(n - k + 1, n):
                prod_val *= val
            return prod_val

        for i in range(num_gpus):
            prod_val = perm(self.rank + self.world_size, self.world_size)

            expected = torch.tensor([prod_val])
            self.assertEqual(expected, output[i])

        # Test the overridden-input scenario, i.e., when the input is
        # a list and the output is a single tensor.
        # Sum
        output_tensor = torch.empty_like(input_per_gpu[0][0]).cuda(self.rank)
        input_list = [tensor[0].cuda(self.rank) for tensor in input_per_gpu]
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait()
        expected = torch.tensor(
            (1 + self.world_size) * self.world_size // 2 + self.world_size * self.rank
        )
        self.assertEqual(expected, output_tensor)

        # Min
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait()
        expected = torch.tensor(self.rank + 1)
        self.assertEqual(expected, output_tensor)

        # Max
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait()
        expected = torch.tensor(self.rank + self.world_size)
        self.assertEqual(expected, output_tensor)

        # Product
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.PRODUCT).wait()
        prod_val = self.rank + 1
        for k in range(1, self.world_size):
            prod_val = prod_val * (self.rank + 1 + k)
        expected = torch.tensor(prod_val)
        self.assertEqual(expected, output_tensor)

        if torch.cuda.nccl.version() >= (2, 11, 1):
            for factor in (3.0, torch.tensor([5.0], device=self.rank)):
                if isinstance(factor, torch.Tensor):
                    factor_ref = factor.cpu().item()
                else:
                    factor_ref = factor
                output = [t.float() for t in output]
                tensor_lists = [[t.float() for t in tl] for tl in tensor_lists]
                output_ref = [t.float() for t in output]
                tensor_lists_ref = [
                    [t.float() * factor_ref for t in tl] for tl in tensor_lists
                ]
                reduce_scatter(output, tensor_lists, c10d._make_nccl_premul_sum(factor))
                reduce_scatter(output_ref, tensor_lists_ref, c10d.ReduceOp.SUM)
                self.assertEqual(output_ref, output)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_reduce_scatter_base_ops(self):
        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def reduce_scatter_base(output_t, input_t):
            work = pg._reduce_scatter_base(output_t, input_t)
            work.wait()

        # reduce_scatter_base is GPU count agnostic.
        # Each rank contributes one tensor regardless of GPU counts
        output_t = torch.empty([1]).cuda(local_device_id)
        tensor = torch.arange(self.world_size, dtype=output_t.dtype).cuda(
            local_device_id
        )

        reduce_scatter_base(output_t, tensor)

        # Verification: every rank contributed arange(world_size), so rank r
        # receives the sum of element r across ranks, i.e. r * world_size.
        self.assertEqual(output_t[0], self.rank * self.world_size)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_barrier(self):
        pg = self.pg
        local_device_ids = self.rank_to_GPU[self.rank]

        def allreduce(tensors):
            opts = c10d.AllreduceOptions()
            work = pg.allreduce(tensors, opts)
            return work

        # Make the collectives operate on 1, 2, ..., len(local_device_ids) GPUs
        tensors_list = [[] for _ in range(len(local_device_ids))]

        for i in range(1, len(local_device_ids) + 1):
            for j in range(i):
                tensors_list[i - 1].append(
                    torch.tensor([j + 1]).cuda(local_device_ids[j])
                )

        works = []
        for tensors in tensors_list:
            work = allreduce(tensors)
            works.append(work)

        # Barrier will ensure that all previous work is completed
        pg.barrier().wait()

        for i in range(1, len(local_device_ids) + 1):
            for j in range(i):
                self.assertEqual(
                    torch.tensor([(j + 1) * self.world_size]), tensors_list[i - 1][j]
                )

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_send_recv(self):
        pg = self.pg
        device = self.rank_to_GPU[self.rank][0]

        # Generate the same random tensor
        torch.manual_seed(0)
        send_tensor = torch.rand(10, 10, device=device)
        if self.rank == 0:
            dist.send(send_tensor, 1)
        if self.rank == 1:
            recv_tensor = torch.rand(10, 10, device=device)
            dist.recv(recv_tensor, 0)
            self.assertEqual(send_tensor, recv_tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_send_recv_complex(self):
        pg = self.pg
        device = self.rank_to_GPU[self.rank][0]

        # Generate the same random tensor
        torch.manual_seed(0)
        send_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device)
        if self.rank == 0:
            dist.send(send_tensor, 1)
        if self.rank == 1:
            recv_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device)
            dist.recv(recv_tensor, 0)
            self.assertEqual(send_tensor, recv_tensor)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_send_recv_object_list(self):
        device = self.rank_to_GPU[self.rank][0]

        val = 99 if self.rank == 0 else None
        object_list = [val] * self.world_size
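        # send_object_list pickles each object into a byte tensor; `device`
        # selects where that staging tensor is placed so it can travel over
        # the NCCL transport.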
        if self.rank == 0:
            dist.send_object_list(object_list, 1, device=device)
        if self.rank == 1:
            dist.recv_object_list(object_list, 0, device=device)
            self.assertEqual(object_list[0], 99)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_tensor_register_hook(self):
        os.environ["TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "1"
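        # With this flag set, ProcessGroupNCCL installs an allocator hook that
        # registers CUDA caching-allocator segments with the NCCL communicator
        # (user buffer registration).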

        pg = self.pg
        local_device_id = self.rank_to_GPU[self.rank][0]

        def allgather_base(output_t, input_t):
            work = pg._allgather_base(output_t, input_t)
            work.wait()

        # allgather_base is GPU count agnostic.
        # Each rank contributes one tensor regardless of GPU counts
        tensor = torch.tensor([self.rank]).cuda(local_device_id)
        output_t = torch.empty((self.world_size), dtype=tensor.dtype).cuda(
            local_device_id
        )

        allgather_base(output_t, tensor)

        # Verification
        self.assertEqual(torch.arange(self.world_size), output_t)

        # Unset env
        del os.environ["TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"]

if __name__ == "__main__":
    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", 2))

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        ProcessGroupNCCLOpTest.run_rank(rank, world_size)
    else:
        # Launched as a single process. Spawn subprocesses to run the tests.
        # Also need a rendezvous file for `init_process_group` purposes.
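        # In the torchrun branch above, rendezvous happens through the RANK and
        # WORLD_SIZE environment variables; here the temp file passed to
        # `run_rank` serves as the file-based rendezvous point instead.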
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            ProcessGroupNCCLOpTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )