pytorch

Форк
0
/
test_symmetric_memory.py 
507 строк · 16.2 Кб
1
# Owner(s): ["module: c10d"]
2

3
import torch
4
import torch.distributed as dist
5
from torch._C._autograd import DeviceType
6
from torch._C._distributed_c10d import _SymmetricMemory
7
from torch.distributed._symmetric_memory import (
8
    _fused_all_gather_matmul_fallback,
9
    _fused_all_gather_scaled_matmul_fallback,
10
    _fused_matmul_reduce_scatter_fallback,
11
    _fused_scaled_matmul_reduce_scatter_fallback,
12
    enable_symm_mem_for_group,
13
    restride_A_for_fused_matmul_reduce_scatter,
14
    restride_A_shard_for_fused_all_gather_matmul,
15
)
16
from torch.testing._internal.common_distributed import (
17
    MultiProcessTestCase,
18
    skip_if_lt_x_gpu,
19
)
20
from torch.testing._internal.common_utils import (
21
    instantiate_parametrized_tests,
22
    parametrize,
23
    run_tests,
24
    skip_but_pass_in_sandcastle_if,
25
    skipIfRocm,
26
)
27

28

29
def requires_cuda_p2p_access():
    """Return a decorator that skips a test unless CUDA peer-to-peer access works.

    The decorated test is skipped (via ``skip_but_pass_in_sandcastle_if``)
    unless all of the following hold:

    * CUDA is available;
    * the current device has compute capability >= (8, 0);
    * at least two CUDA devices are visible;
    * every ordered pair of distinct devices reports peer access through
      ``torch.cuda.can_device_access_peer``.
    """
    num_devices = torch.cuda.device_count()
    cuda_p2p_access_available = (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (8, 0)
        and num_devices >= 2
        # Peer access is queried per direction ("can device i access device
        # j's memory?"). The collectives under test may touch peer memory in
        # either direction, so check every ordered pair. ``all`` stops at the
        # first failing pair, and the `and` chain skips the probing entirely
        # when the basic requirements above are not met.
        and all(
            torch.cuda.can_device_access_peer(i, j)
            for i in range(num_devices)
            for j in range(num_devices)
            if i != j
        )
    )

    return skip_but_pass_in_sandcastle_if(
        not cuda_p2p_access_available,
        "cuda p2p access is not available",
    )
48

49

50
def requires_multicast_support():
    """Return a decorator gating a test on CUDA multicast support.

    The decorator comes from ``skip_but_pass_in_sandcastle_if``. It skips the
    test when CUDA is unavailable, or when ``_SymmetricMemory`` reports that
    CUDA devices lack multicast support.
    """
    if torch.cuda.is_available():
        multicast_ok = _SymmetricMemory.has_multicast_support(DeviceType.CUDA)
    else:
        # Without CUDA there is nothing to query; treat it as unsupported.
        multicast_ok = False

    return skip_but_pass_in_sandcastle_if(
        not multicast_ok,
        "multicast support is not available",
    )
59

60

61
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):
    """Two-rank tests for ``_SymmetricMemory`` and the ``torch.ops.symm_mem`` ops.

    ``setUp`` spawns one process per rank. Tests that communicate call
    ``_init_process``, which creates an NCCL default process group in which
    rank ``r`` uses ``cuda:r``.
    """

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def device(self) -> torch.device:
        """The CUDA device owned by this rank (``cuda:<rank>``)."""
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        """Set up distributed state for this rank.

        Binds this process to its GPU, initializes the default NCCL process
        group from a ``FileStore``, and enables symmetric memory for WORLD.
        """
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        enable_symm_mem_for_group(dist.group.WORLD.group_name)

    def _verify_symmetric_memory(self, symm_mem):
        """Exercise signal- and barrier-based synchronization on ``symm_mem``.

        Rank 1 writes into a shared buffer, and rank 0 checks that the write
        is visible. This is done first after ``put_signal``/``wait_signal``
        and then after a ``barrier``. The sender/receiver roles are
        hard-coded for exactly two ranks.
        """
        self.assertEqual(symm_mem.world_size, 2)

        # Presumably rank 0's buffer as seen from every rank: rank 1 writes
        # through it and rank 0 reads the result.
        buf = symm_mem.get_buffer(0, (64, 64), torch.float32)
        if symm_mem.rank == 0:
            symm_mem.wait_signal(src_rank=1)
            self.assertTrue(buf.eq(42).all())
        else:
            buf.fill_(42)
            symm_mem.put_signal(dst_rank=0)

        symm_mem.barrier()

        if symm_mem.rank == 0:
            symm_mem.barrier()
            self.assertTrue(buf.eq(43).all())
        else:
            buf.fill_(43)
            symm_mem.barrier()

        symm_mem.barrier()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_cuda_nvlink_connectivity_detection(self) -> None:
        """NVLink detection returns a device_count x device_count matrix."""
        from torch._C._distributed_c10d import _detect_dma_connectivity

        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
        self.assertEqual(connectivity.connection_type, "nvlink")
        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
        for row in connectivity.matrix:
            self.assertEqual(len(row), torch.cuda.device_count())

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p(self) -> None:
        """Only ``empty_strided_p2p`` tensors can be rendezvoused.

        The resulting handle is exercised after the tensor has been deleted.
        """
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        group_name = dist.group.WORLD.group_name
        alloc_args = (shape, stride, dtype, device, group_name)

        # A regular CUDA tensor is not symmetric memory: rendezvous yields None.
        t = torch.empty(shape, dtype=dtype, device=device)
        self.assertIsNone(_SymmetricMemory.rendezvous(t))

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        symm_mem = _SymmetricMemory.rendezvous(t)

        # Drop the tensor before using the handle.
        del t
        self._verify_symmetric_memory(symm_mem)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p_persistent(self) -> None:
        """Persistent allocations (explicit ``alloc_id``) are unique while live.

        Reallocating after release returns the same address, and
        ``get_symmetric_memory`` works only after rendezvous.
        """
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        alloc_id = 42  # Persistent allocation
        group_name = dist.group.WORLD.group_name
        alloc_args = (shape, stride, dtype, device, group_name, alloc_id)

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        data_ptr = t.data_ptr()

        # Verify that persistent allocation would fail if there's an active
        # allocation with the same alloc_id.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.empty_strided_p2p(*alloc_args)

        # Verify that persistent allocation succeeds once there is no active
        # allocation with the same alloc_id, and that the returned tensor has
        # the same data pointer.
        del t
        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        self.assertEqual(t.data_ptr(), data_ptr)

        # Verify that get_symmetric_memory would fail if called before
        # rendezvous.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.get_symmetric_memory(t)

        symm_mem_0 = _SymmetricMemory.rendezvous(t)
        symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
        self.assertEqual(id(symm_mem_0), id(symm_mem_1))

        self._verify_symmetric_memory(symm_mem_0)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
        """``fused_all_gather_matmul`` matches its fallback in values and strides."""
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank

        # Different seed per rank so the gathered shards are distinguishable.
        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device=self.device)
        Bs = [torch.rand(K, N, device=self.device) for _ in range(3)]

        ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )

        self.assertTrue(torch.allclose(ag_output_0, ag_output_1))
        self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
        # zip() would silently drop extra outputs; compare the counts first.
        self.assertEqual(len(mm_outputs_0), len(mm_outputs_1))
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            self.assertTrue(torch.allclose(mm_output_0, mm_output_1))
            self.assertEqual(mm_output_0.stride(), mm_output_1.stride())

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_scaled_matmul(self, gather_dim: int) -> None:
        """``fused_all_gather_scaled_matmul`` (float8 inputs) matches its fallback.

        Values, strides, and per-output dtypes must all match.
        """
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank

        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device=self.device).to(
            torch.float8_e4m3fn
        )
        A_scale = torch.tensor(0.1, device=self.device)
        # B is built as (N, K) and transposed, giving a column-major (K, N) view.
        Bs = [
            torch.rand(N, K, device=self.device).to(torch.float8_e4m3fn).T
            for _ in range(3)
        ]
        B_scales = [torch.tensor(0.1, device=self.device) for _ in range(3)]
        out_dtypes = [None, torch.bfloat16, torch.float32]

        ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )

        # Compare in float32 so that float8/bfloat16 outputs can go through allclose.
        self.assertTrue(
            torch.allclose(
                ag_output_0.to(torch.float32),
                ag_output_1.to(torch.float32),
            )
        )
        self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
        self.assertEqual(len(mm_outputs_0), len(mm_outputs_1))
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            self.assertTrue(
                torch.allclose(
                    mm_output_0.to(torch.float32), mm_output_1.to(torch.float32)
                )
            )
            self.assertEqual(mm_output_0.stride(), mm_output_1.stride())
            self.assertEqual(mm_output_0.dtype, mm_output_1.dtype)

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
        """``fused_matmul_reduce_scatter`` ("avg") matches its fallback.

        Both values and strides must match.
        """
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device=self.device)
        B = torch.rand(K, N, device=self.device)

        output_0 = _fused_matmul_reduce_scatter_fallback(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )
        output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )

        self.assertTrue(torch.allclose(output_0, output_1))
        self.assertEqual(output_0.stride(), output_1.stride())

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_scaled_matmul_reduce_scatter(self, scatter_dim: int) -> None:
        """``fused_scaled_matmul_reduce_scatter`` matches its fallback.

        Uses float8 inputs and a bfloat16 output; values and strides must match.
        """
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device=self.device).to(torch.float8_e4m3fn)
        A_scale = torch.tensor(0.1, device=self.device)
        # B is built as (N, K) and transposed, giving a column-major (K, N) view.
        B = torch.rand(N, K, device=self.device).to(torch.float8_e4m3fn).T
        B_scale = torch.tensor(0.1, device=self.device)

        output_0 = _fused_scaled_matmul_reduce_scatter_fallback(
            A,
            B,
            A_scale,
            B_scale,
            "avg",
            scatter_dim,
            group.group_name,
            out_dtype=torch.bfloat16,
        )
        output_1 = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
            A,
            B,
            A_scale,
            B_scale,
            "avg",
            scatter_dim,
            group.group_name,
            out_dtype=torch.bfloat16,
        )

        self.assertTrue(torch.allclose(output_0, output_1))
        self.assertEqual(output_0.stride(), output_1.stride())

        dist.destroy_process_group()

    @skipIfRocm
    @parametrize("dim", [0, 1, 2])
    def test_optimal_layout(self, dim: int) -> None:
        """The restride helpers make ``dim`` outermost without changing values."""
        t = torch.rand(8, 64, 32, 16)

        x = restride_A_shard_for_fused_all_gather_matmul(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

        x = restride_A_for_fused_matmul_reduce_scatter(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
        """``_low_contention_all_gather`` concatenates each rank's tensor.

        Each rank contributes a tensor filled with its own rank, so chunk ``r``
        of the result must equal ``r``. Works for both symmetric-memory and
        regular inputs.
        """
        self._init_process()
        group_name = dist.group.WORLD.group_name

        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name=group_name,
            ).fill_(self.rank)
        else:
            t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)

        res = torch.ops.symm_mem._low_contention_all_gather(t, group_name)
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 * self.world_size, 64))

        chunks = res.chunk(self.world_size)
        for r in range(self.world_size):
            self.assertTrue(chunks[r].eq(r).all())

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("reduce_op", ["sum", "avg"])
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_reduce_scatter(
        self, reduce_op: str, symm_mem_input: bool
    ) -> None:
        """``_low_contention_reduce_scatter`` gives each rank its reduced chunk.

        Every rank fills chunk ``r`` of its input with ``r``. Rank ``r``
        therefore expects ``r * world_size`` for "sum" and ``r`` for "avg".
        """
        self._init_process()
        group_name = dist.group.WORLD.group_name

        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name=group_name,
            )
        else:
            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)

        chunks = t.chunk(self.world_size)
        for r in range(self.world_size):
            chunks[r].fill_(r)

        res = torch.ops.symm_mem._low_contention_reduce_scatter(
            t, reduce_op, group_name
        )
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 // self.world_size, 64))

        if reduce_op == "sum":
            expect = self.rank * self.world_size
        elif reduce_op == "avg":
            expect = self.rank
        else:
            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
        self.assertTrue(res.eq(expect).all())

        dist.destroy_process_group()

    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        """In-place ``multimem_all_reduce_`` on an offset sub-view of a buffer.

        The target is a sub-view at byte offset ``align_bytes`` of length
        ``size_bytes``. It must sum correctly and leave the surrounding
        elements untouched.
        """
        self._init_process()
        group_name = dist.group.WORLD.group_name

        t = _SymmetricMemory.empty_strided_p2p(
            size=(16384,),
            stride=(1,),
            dtype=dtype,
            device=self.device,
            group_name=group_name,
        ).fill_(1)

        self.assertEqual(t.data_ptr() % 16, 0)
        self.assertEqual(align_bytes % t.element_size(), 0)
        self.assertEqual(size_bytes % t.element_size(), 0)

        shift = align_bytes // t.element_size()
        numel = size_bytes // t.element_size()
        x = t[shift : shift + numel]

        torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name)
        self.assertTrue(x.eq(self.world_size).all().item())

        # Head and tail should not be written
        self.assertTrue(t[:shift].eq(1).all().item())
        self.assertTrue(t[shift + numel :].eq(1).all().item())
        dist.destroy_process_group()

    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_one_shot_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        """``multimem_one_shot_all_reduce`` correctly sums an offset sub-view.

        The sub-view starts at byte offset ``align_bytes`` and has length
        ``size_bytes``.
        """
        self._init_process()
        group_name = dist.group.WORLD.group_name

        t = _SymmetricMemory.empty_strided_p2p(
            size=(16384,),
            stride=(1,),
            dtype=dtype,
            device=self.device,
            group_name=group_name,
        ).fill_(0)

        self.assertEqual(t.data_ptr() % 16, 0)
        self.assertEqual(align_bytes % t.element_size(), 0)
        self.assertEqual(size_bytes % t.element_size(), 0)

        shift = align_bytes // t.element_size()
        numel = size_bytes // t.element_size()
        x = t[shift : shift + numel]
        x.fill_(1)

        res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name)
        self.assertTrue(res.eq(self.world_size).all().item())
        dist.destroy_process_group()
504

505

506
if __name__ == "__main__":
    # Run as a script: hand control to PyTorch's shared test entry point
    # (run_tests from torch.testing._internal.common_utils).
    run_tests()
508

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.