# Owner(s): ["oncall: distributed"]

import sys
import torch
import torch.cuda.nccl as nccl
import torch.cuda
import torch.distributed as c10d

from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    IS_WINDOWS,
    load_tests,
    TEST_WITH_ROCM,
    skip_but_pass_in_sandcastle_if,
    NoTest,
)
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
)
import re

HIP_VERSION = (
    0.0
    if torch.version.hip is None
    else float(re.search(r"^\d+\.\d+", torch.version.hip)[0])
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

nGPUs = torch.cuda.device_count()
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


datatypes = [torch.float]
if (
    TEST_CUDA and c10d.is_nccl_available() and nccl.version() >= (2, 10)
) or TEST_WITH_ROCM:
    datatypes.append(torch.bfloat16)


class TestNCCL(TestCase):
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_unique_id(self, device):
        uid = nccl.unique_id()
        self.assertIsInstance(uid, bytes)
        self.assertGreater(len(uid), 1)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_broadcast(self, device, dtype):
        expected = torch.zeros(128).uniform_().to(dtype=dtype)
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tensors)
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)

        # Test with tuple
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tuple(tensors))
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce(self, device, dtype):
        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tensors)

        self.assertEqual(tensors[0], expected)

        # Test with tuple
        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tuple(tensors))

        self.assertEqual(tensors[0], expected)

    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16,  # noqa: F821
        "Skip bfloat16 test for ROCm < 3.5",
    )
    @dtypes(*datatypes)
    def test_all_reduce(self, device, dtype):
        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        tensors = tuple(cpu_tensors[i].cuda(i) for i in range(nGPUs))
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with set.
        tensors = {cpu_tensors[i].cuda(i) for i in range(nGPUs)}
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_collective_errors(self, device):
        t = torch.rand(10).cuda(0)
        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.broadcast(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_gather(t, t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce_scatter(t, t)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_all_gather(self, device, dtype):
        cpu_inputs = [torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)]
        expected = torch.cat(cpu_inputs, 0)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(inputs, outputs)

        for tensor in outputs:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(tuple(inputs), tuple(outputs))

        for tensor in outputs:
            self.assertEqual(tensor, expected)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce_scatter(self, device, dtype):
        in_size = 32 * nGPUs
        out_size = 32

        cpu_inputs = [
            torch.zeros(in_size).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(in_size, dtype=dtype)
        for t in cpu_inputs:
            expected.add_(t)
        expected = expected.view(nGPUs, 32)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(inputs, outputs)

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])

        # Test with tuple
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(tuple(inputs), tuple(outputs))

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])


instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()
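
# Usage sketch (assumes a PyTorch source checkout and a machine with two or
# more CUDA GPUs; the exact invocation is not dictated by this file):
#
#     python test/test_nccl.py
#
# run_tests() forwards the usual unittest-style arguments, so a name filter
# such as `-k all_reduce` may also work. instantiate_device_type_tests() above
# generates the device-specific test class (e.g. TestNCCLCUDA), and the
# @dtypes decorator expands each test per dtype, so instantiated test names
# carry device/dtype suffixes.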