# Owner(s): ["oncall: distributed"]

import re
import sys

import torch
import torch.cuda
import torch.cuda.nccl as nccl
import torch.distributed as c10d

from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    IS_WINDOWS,
    load_tests,
    TEST_WITH_ROCM,
    skip_but_pass_in_sandcastle_if,
    NoTest,
)
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
)

HIP_VERSION = (
    0.0
    if torch.version.hip is None
    else float(re.search(r"^\d+\.\d+", torch.version.hip)[0])
)
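# HIP_VERSION is ROCm's major.minor version as a float (0.0 on CUDA builds);
# the ROCm-specific skip decorators below compare against it.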

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

nGPUs = torch.cuda.device_count()
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811

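# bfloat16 collectives arrived with NCCL 2.10, hence the version gate below;
# ROCm builds are assumed to support them throughout.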
datatypes = [torch.float]
if (
    TEST_CUDA and c10d.is_nccl_available() and nccl.version() >= (2, 10)
) or TEST_WITH_ROCM:
    datatypes.append(torch.bfloat16)


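# These tests exercise the low-level torch.cuda.nccl wrappers directly: a
# single process drives each collective across all visible GPUs, with no
# torch.distributed process groups involved.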
class TestNCCL(TestCase):
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_unique_id(self, device):
        uid = nccl.unique_id()
        self.assertIsInstance(uid, bytes)
        self.assertGreater(len(uid), 1)

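    # broadcast: the root tensor on device 0 is replicated to every other GPU;
    # both list and tuple inputs are accepted.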
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_broadcast(self, device, dtype):
        expected = torch.zeros(128).uniform_().to(dtype=dtype)
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tensors)
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)

        # Test with tuple
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tuple(tensors))
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)

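    # reduce: element-wise sum of all GPUs' tensors, deposited only on the
    # root (device 0 by default).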
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce(self, device, dtype):
        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tensors)

        self.assertEqual(tensors[0], expected)

        # Test with tuple
        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tuple(tensors))

        self.assertEqual(tensors[0], expected)

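    # all_reduce: element-wise sum of all GPUs' tensors, with the result left
    # on every device; list, tuple, and set inputs are all accepted.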
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_all_reduce(self, device, dtype):
        # `dtype` is injected per-test by @dtypes and does not exist at
        # class-body scope, so the bfloat16-on-old-ROCm skip must live inside
        # the test rather than in a class-level decorator (where referencing
        # `dtype` raises NameError).
        if TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16:
            self.skipTest("Skip bfloat16 test for ROCm < 3.5")

        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        tensors = tuple(cpu_tensors[i].cuda(i) for i in range(nGPUs))
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with set.
        tensors = {cpu_tensors[i].cuda(i) for i in range(nGPUs)}
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

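    # Each collective must reject a bare tensor: the Python wrappers require
    # a collection of tensors and raise TypeError otherwise.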
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_collective_errors(self, device):
        t = torch.rand(10).cuda(0)
        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.broadcast(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_gather(t, t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce_scatter(t, t)

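    # all_gather: every GPU ends up with the concatenation of all inputs.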
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_all_gather(self, device, dtype):
        cpu_inputs = [torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)]
        expected = torch.cat(cpu_inputs, 0)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(inputs, outputs)

        for tensor in outputs:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(tuple(inputs), tuple(outputs))

        for tensor in outputs:
            self.assertEqual(tensor, expected)

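    # reduce_scatter: sum across GPUs, then scatter the result so that GPU i
    # receives the i-th 32-element chunk of the sum.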
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce_scatter(self, device, dtype):
        in_size = 32 * nGPUs
        out_size = 32

        cpu_inputs = [
            torch.zeros(in_size).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(in_size, dtype=dtype)
        for t in cpu_inputs:
            expected.add_(t)
        expected = expected.view(nGPUs, 32)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(inputs, outputs)

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])

        # Test with tuple
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(tuple(inputs), tuple(outputs))

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])

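
# instantiate_device_type_tests generates concrete per-device-type test
# classes from the TestNCCL template; only_for="cuda" restricts them to CUDA.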
instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()