# Captured from a web view of a pytorch fork: test_device_mesh.py (711 lines, 26.8 KB).
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._collective_utils import (
    mesh_all_to_all,
    mesh_broadcast,
    mesh_scatter,
)
from torch.distributed._tensor.placement_types import _Partial, Shard
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import (
    get_global_rank,
    get_world_size,
    init_process_group,
    is_initialized,
    is_nccl_available,
    ProcessGroup,
)
from torch.testing._internal.common_distributed import run_with_both_funcol_impls
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeStore


def _get_device_type(world_size):
38
    if (
39
        torch.cuda.is_available()
40
        and torch.cuda.device_count() >= world_size
41
        and is_nccl_available()
42
    ):
43
        device_type = "cuda"
44
    else:
45
        device_type = "cpu"
46
    return device_type
47

48

49
def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0):
50
    os.environ["MASTER_ADDR"] = addr
51
    os.environ["MASTER_PORT"] = port
52
    os.environ["WORLD_SIZE"] = f"{world_size}"
53
    os.environ["RANK"] = f"{rank}"
54

55

56
@instantiate_parametrized_tests
class DeviceMeshTest(DTensorTestBase):
    """DeviceMesh construction, process-group lookup and local-rank tests (4 ranks)."""

    @property
    def world_size(self):
        return 4

    @run_with_both_funcol_impls
    def test_init_process_group(self):
        # No @with_comms: this test checks that constructing a DeviceMesh
        # initializes the default process group when none exists yet.
        device_type = _get_device_type(self.world_size)
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertFalse(is_initialized())
        _set_env_var(world_size=self.world_size, rank=self.rank)
        DeviceMesh(device_type, mesh_tensor)
        self.assertTrue(is_initialized())
        self.destroy_pg()

    @with_comms
    @skip_unless_torch_gpu
    def test_assert_invalid_mesh_tensor(self):
        # The mesh tensor is moved onto this rank's device; DeviceMesh is
        # expected to reject it with a ValueError.
        mesh = torch.arange(self.world_size).to(self.rank)
        with self.assertRaises(ValueError):
            DeviceMesh(self.device_type, mesh)

    @with_comms
    @run_with_both_funcol_impls
    def test_get_group(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]

        # Without an argument, get_group() returns one group per mesh dim.
        self.assertEqual(len(mesh_2d.get_group()), 2)
        self.assertEqual(mesh_2d.get_group()[0], mesh_2d.get_group("dp"))
        self.assertEqual(mesh_2d.get_group()[1], mesh_2d.get_group("tp"))

        # Lookup by dim index and by dim name must agree.
        self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp"))
        self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp"))

        # A sliced 1D child mesh reuses the parent's group for that dim.
        self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group())
        self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group())

    @with_comms
    @run_with_both_funcol_impls
    def test_get_local_rank_raises_exception(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        # On a mesh with ndim > 1 the dimension must be given explicitly.
        with self.assertRaisesRegex(
            RuntimeError,
            "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
        ):
            mesh_2d.get_local_rank()

    @with_comms
    @run_with_both_funcol_impls
    def test_get_local_rank(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )
        self.assertEqual(mesh_2d.get_local_rank("dp"), mesh_2d.get_local_rank(0))
        self.assertEqual(mesh_2d.get_local_rank("tp"), mesh_2d.get_local_rank(1))

        # The local rank in a sliced child mesh matches the parent's
        # local rank along the corresponding dim.
        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp"))
        self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp"))

    @with_comms
    @run_with_both_funcol_impls
    def test_device_mesh_2d(self):
        mesh_tensor = torch.arange(4).reshape(2, 2)
        # construct a 2D device mesh on the test's device type
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_group()

        # For mesh [[0, 1], [2, 3]]: dim 0 groups are the columns and
        # dim 1 groups are the rows.
        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < 2)
            dim_ranks = expected_ranks_by_dim[dim]

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            current_rank_expected_group_ranks = (
                dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
            )
            self.assertEqual(global_ranks, current_rank_expected_group_ranks)

    @run_with_both_funcol_impls
    def test_fake_pg_device_mesh(self):
        # Use the "fake" backend so collectives run without real communication.
        fake_store = FakeStore()
        init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        mesh = DeviceMesh(device_type, torch.arange(self.world_size))

        local_tensor = torch.randn(2, 8)
        global_tensor = funcol.all_gather_tensor(
            local_tensor, gather_dim=0, group=(mesh, 0)
        )
        self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))


class DeviceMeshTestNDim(DTensorTestBase):
    """Tests for a 3D (2 x 2 x 2) DeviceMesh on 8 ranks."""

    @property
    def world_size(self):
        return 8

    @with_comms
    @run_with_both_funcol_impls
    def test_device_mesh_nd(self):
        # construct a 3D (2 x 2 x 2) device mesh on the test's device type
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_group()

        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < mesh_tensor.ndim)
            # Move `dim` to the last axis so that each row of the reshaped
            # tensor lists the ranks forming one group along that dim.
            dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2)

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            for ranks in dim_ranks:
                if self.rank in ranks:
                    self.assertEqual(global_ranks, ranks.tolist())

    @with_comms
    @run_with_both_funcol_impls
    def test_device_mesh_hash(self):
        # Distinct DeviceMesh objects are expected to hash differently, even
        # when they are built from the same mesh tensor.
        mesh_tensor_2d = torch.arange(8).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor_2d)
        mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d)
        self.assertNotEqual(hash(mesh), hash(mesh2))
        mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2)
        mesh3 = DeviceMesh(self.device_type, mesh_tensor_3d)
        self.assertNotEqual(hash(mesh), hash(mesh3))
        self.assertNotEqual(hash(mesh2), hash(mesh3))


class InitDeviceMeshTest(DTensorTestBase):
    """Tests for the ``init_device_mesh`` factory (8 ranks)."""

    @property
    def world_size(self):
        return 8

    @with_comms
    @run_with_both_funcol_impls
    def test_init_device_mesh(self):
        mesh_shape = (2, 4)
        ref_mesh = DeviceMesh(self.device_type, torch.arange(8).view(mesh_shape))

        # test init_device_mesh with mesh_dim_names
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        self.assertEqual(mesh_2d, ref_mesh)
        self.assertEqual(mesh_2d.mesh_dim_names, mesh_dim_names)

        # test init_device_mesh without mesh_dim_names
        mesh_2d = init_device_mesh(self.device_type, mesh_shape)
        self.assertEqual(mesh_2d, ref_mesh)

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_duplicate_mesh_dim_names(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Each mesh_dim_name must be unique.",
        ):
            init_device_mesh(
                self.device_type,
                (2, 4),
                mesh_dim_names=["dp", "dp"],
            )

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_mesh_shape_mesh_dim_names_mismatch(self):
        # A 1D shape paired with two dim names must be rejected.
        with self.assertRaisesRegex(
            RuntimeError,
            "mesh_shape and mesh_dim_names should have same length!",
        ):
            init_device_mesh(
                self.device_type,
                (8,),
                mesh_dim_names=["dp", "tp"],
            )


@instantiate_parametrized_tests
class TestDeviceMeshGetItem(DTensorTestBase):
    """Tests for slicing a child mesh out of a DeviceMesh by dim name (8 ranks)."""

    @property
    def world_size(self):
        return 8

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_no_mesh_dim_found(self):
        # Build the mesh outside the assertRaises context so that only the
        # indexing operation can satisfy the expected KeyError.
        mesh = init_device_mesh(self.device_type, (2, 4))
        with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."):
            mesh["DP"]

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_invalid_mesh_dim_name(self):
        child_mesh_dim_name = "PP"
        mesh_dim_names = ("DP", "TP")
        # Setup happens outside the assertRaises context (see above).
        mesh = init_device_mesh(
            self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
        )
        with self.assertRaisesRegex(
            KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist."
        ):
            mesh[child_mesh_dim_name]

    @with_comms
    @run_with_both_funcol_impls
    def test_get_item(self):
        mesh_shape = (2, 4)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        # For each named dim, move it to the last axis so every row of the
        # reshaped tensor lists the ranks of one group along that dim.
        pg_ranks_by_dim_name = {}
        for mesh_dim, mesh_dim_name in enumerate(mesh_dim_names):
            pg_ranks_by_dim_name[mesh_dim_name] = mesh_2d.mesh.swapdims(
                -1, mesh_dim
            ).reshape(-1, mesh_2d.mesh.size(mesh_dim))

        # With a row-major (DP, TP) mesh, this rank's TP group is its row
        # and its DP group is its column.
        tp_size = mesh_shape[1]

        tp_mesh = mesh_2d["TP"]
        tp_group_idx = self.rank // tp_size
        self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])

        dp_mesh = mesh_2d["DP"]
        dp_group_idx = self.rank % tp_size
        self.assertEqual(dp_mesh.mesh, pg_ranks_by_dim_name["DP"][dp_group_idx])

    @with_comms
    @run_with_both_funcol_impls
    def test_get_item_1d(self):
        mesh = init_device_mesh(self.device_type, (8,), mesh_dim_names=("dp",))
        # Make sure slicing a 1D mesh out of a 1D mesh works.
        # The mesh itself is returned; no parent mesh is recorded here.
        dp_mesh = mesh["dp"]
        self.assertEqual(dp_mesh, mesh)

        with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"):
            mesh["dim0"]


@instantiate_parametrized_tests
class TestMeshEnv(DTensorTestBase):
    """Tests for the ``_mesh_resources`` parent/child mesh bookkeeping.

    Uses the DTensorTestBase default world size (this class does not override it).
    """

    @with_comms
    @run_with_both_funcol_impls
    def test_get_parent_mesh(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["DP"]), mesh_2d)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["TP"]), mesh_2d)

        # Meshes built directly (not sliced from a parent) have no parent.
        mesh_0_2 = DeviceMesh(self.device_type, [0, 2])
        mesh_1_3 = DeviceMesh(self.device_type, [1, 3])

        # Re-check the sliced meshes after creating unrelated meshes.
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["DP"]), mesh_2d)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["TP"]), mesh_2d)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_0_2), None)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_1_3), None)

    @with_comms
    @run_with_both_funcol_impls
    def test_get_parent_mesh_dim_exist(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh_2d["DP"]), 0)
        self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh_2d["TP"]), 1)

    @with_comms
    @run_with_both_funcol_impls
    def test_get_parent_mesh_dim_not_exist(self):
        mesh_shape = (self.world_size,)
        mesh = init_device_mesh(self.device_type, mesh_shape)

        self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh), None)

    @with_comms
    @run_with_both_funcol_impls
    def test_get_mesh_dim_by_name(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "DP"), 0)
        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "TP"), 1)


@instantiate_parametrized_tests
class DeviceMeshCollectiveTest(DTensorTestBase):
    """Collective ops (broadcast/scatter/all-gather/reduce-scatter/all-to-all) over a DeviceMesh (8 ranks)."""

    @property
    def world_size(self):
        return 8

    @with_comms
    @run_with_both_funcol_impls
    def test_broadcast_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        # Every rank fills its tensor with its own rank; after the broadcast
        # all ranks are expected to hold rank 0's tensor (all zeros).
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
        mesh_broadcast(local_tensor, mesh, mesh_dim=0)
        self.assertEqual(local_tensor, torch.zeros(3, 3))

    @with_comms
    @run_with_both_funcol_impls
    def test_scatter_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        scatter_tensor_shape = [3, 3, 3]
        for scatter_dim in range(len(scatter_tensor_shape)):
            shard_placement = Shard(scatter_dim)
            scatter_tensor_shape[scatter_dim] *= self.world_size
            # make the random seed same across rank
            torch.manual_seed(0)
            global_tensor = torch.randn(scatter_tensor_shape, device=self.device_type)
            splitted_list, _ = shard_placement._split_tensor(
                global_tensor, mesh.size(), with_padding=True, contiguous=True
            )
            recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()])
            # scatter on dim > 0 would generate non-contiguous tensor, verify that works
            mesh_scatter(recv_tensor, splitted_list, mesh, mesh_dim=0)
            self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()])

    @with_comms
    @run_with_both_funcol_impls
    def test_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        # Sizes are not multiples of the world size, so shards are uneven
        # and some need padding.
        tensor_to_split = torch.randn(
            device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)

            tensor_to_scatter = tensor_to_split.clone()
            # Reference (unpadded) shards; torch.chunk may return fewer than
            # world_size chunks, so pad the list with empty tensors.
            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            scattered_tensor = torch.empty_like(padded_tensor_list[my_rank])
            mesh_scatter(scattered_tensor, padded_tensor_list, device_mesh, mesh_dim=0)

            if pad_sizes[my_rank] != 0:
                scattered_tensor = shard_placement._unpad_tensor(
                    scattered_tensor, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # We need to check numel() instead of size if a tensor is ([]) after unpadding,
                # since the size could be ([0, 8]) after unpadding.
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank])

    @with_comms
    @run_with_both_funcol_impls
    def test_all_gather_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = torch.ones(
            device_mesh.size() + 3,
            device_mesh.size() + 1,
            device=self.device_type,
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_padded_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_split,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )
            local_tensor = tensor_padded_list[my_rank]
            big_tensor = funcol.all_gather_tensor(
                local_tensor, gather_dim=shard_dim, group=(device_mesh, 0)
            )
            # Split the gathered (padded) result back per rank and strip
            # each rank's padding before reassembling.
            big_tensor_chunks = list(
                torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim)
            )
            unpadded_list = [
                shard_placement._unpad_tensor(chunk, pad_sizes[i])
                if pad_sizes[i] > 0
                else chunk
                for i, chunk in enumerate(big_tensor_chunks)
            ]
            all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim)

            self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size())
            self.assertEqual(all_gathered_tensor, tensor_to_split)

    @with_comms
    @run_with_both_funcol_impls
    def test_reduce_scatter_contiguous(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()

        # Init the tensor: a (step x step) arange scaled by (rank + 1)
        step = self.world_size * 2
        total_elem = step**2
        tensor = torch.arange(0, total_elem).view(step, -1).to(device=self.device_type)
        tensor = tensor * (my_rank + 1)

        # Get non-contiguous tensor by slicing
        tensor_to_reduce = tensor[::2, :2]
        tensor_contiguous = tensor_to_reduce.clone().contiguous()

        # Partial to Shard to trigger reduce_scatter
        tensor_to_reduce = DTensor.from_local(
            tensor_to_reduce, device_mesh, [_Partial()]
        )
        tensor_contiguous = DTensor.from_local(
            tensor_contiguous, device_mesh, [_Partial()]
        )
        new_tensor = tensor_to_reduce.redistribute(device_mesh, [Shard(0)])
        new_tensor_contiguous = tensor_contiguous.redistribute(device_mesh, [Shard(0)])

        # The output for contiguous and non-contiguous tensors of the same value
        # should return the same reducescatter value.
        new_tensor_local = new_tensor._local_tensor
        new_tensor_contiguous_local = new_tensor_contiguous._local_tensor
        self.assertEqual(new_tensor_local, new_tensor_contiguous_local)
        self.assertEqual(list(new_tensor_local.size()), [1, 2])

        # Check the reduce numerical value. Rank r receives row 2*r of the
        # arange summed over all ranks' (rank + 1) scale factors, i.e.
        # multiplied by 1 + 2 + ... + world_size.
        sum_base = (1 + self.world_size) * self.world_size / 2
        first_elem = my_rank * sum_base * step * 2
        expected_tensor = torch.tensor(
            [[first_elem, first_elem + sum_base]],
            dtype=new_tensor_local.dtype,
            device=self.device_type,
        )
        self.assertEqual(new_tensor_local, expected_tensor)

    @with_comms
    @run_with_both_funcol_impls
    def test_reduce_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = (
            torch.ones(
                device_mesh.size() + 3,
                device_mesh.size() + 1,
                device=self.device_type,
            )
            * self.rank
        )
        # Summing every rank's tensor filled with its own rank gives
        # 0 + 1 + ... + (world_size - 1) in every element.
        res_num = ((0 + self.world_size - 1) * self.world_size) / 2

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_to_scatter = tensor_to_split.clone()

            # Reference (unpadded) shards, padded out to world_size entries.
            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            tensor_to_reduce = torch.cat(padded_tensor_list, shard_dim)

            scattered_tensor = funcol.reduce_scatter_tensor(
                tensor_to_reduce,
                reduceOp="sum",
                scatter_dim=shard_dim,
                group=(device_mesh, 0),
            )

            # unpad scattered_tensor
            if pad_sizes[my_rank] > 0:
                scattered_tensor = shard_placement._unpad_tensor(
                    scattered_tensor, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # We need to check numel() instead of size if a tensor is ([]) after unpadding,
                # since the size could be ([0, 8]) after unpadding.
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(
                    scattered_tensor,
                    torch.ones_like(tensor_splitted_list[my_rank]) * res_num,
                )

    @with_comms
    @run_with_both_funcol_impls
    def test_broadcast_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank

        # check all dim groups: after broadcasting along a dim, every rank
        # holds the value of the first (group-rank 0) member of its group.
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            cloned_local_tensor = local_tensor.clone()
            mesh_broadcast(cloned_local_tensor, mesh, mesh_dim=dim)
            res_num = global_ranks[0]
            self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)

    @with_comms
    @run_with_both_funcol_impls
    def test_scatter_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups: each group member scatters tensors filled
        # with the receivers' global ranks, so every rank gets its own rank.
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            scattered_tensors = [
                torch.ones(3, 3, device=self.device_type) * global_rank
                for global_rank in global_ranks
            ]
            received_tensor = torch.empty_like(
                scattered_tensors[mesh.get_coordinate()[dim]]
            )
            mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim)
            self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)

    @with_comms
    @run_with_both_funcol_impls
    def test_all_to_all_1d(self):
        # transpose on a 2D tensor distributed over N nodes:
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        tensor_shape = [3, 3]
        input_tensor_list = [
            torch.ones(*tensor_shape, device=self.device_type)
            * (rank + self.rank * self.world_size)
            for rank in range(self.world_size)
        ]
        expected_tensor_list = [
            torch.ones(tensor_shape, device=self.device_type)
            * (self.rank + rank * self.world_size)  # i.e. transpose
            for rank in range(self.world_size)
        ]
        for scatter_dim in range(len(tensor_shape)):
            output_tensor_list = [torch.empty_like(t) for t in input_tensor_list]
            # scatter on dim > 0 would generate non-contiguous tensor, verify that works
            mesh_all_to_all(output_tensor_list, input_tensor_list, mesh, mesh_dim=0)
            output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
            expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)

            self.assertEqual(output_tensor, expected_tensor)

    @with_comms
    @run_with_both_funcol_impls
    def test_all_to_all_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        tensor_shape = [3, 3, 3]
        # check all dim groups
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            my_coordinate = mesh.get_coordinate()[dim]
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            input_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (i + self.rank * dim_group_size)
                for i in range(dim_group_size)
            ]
            expected_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (my_coordinate + global_rank * dim_group_size)  # i.e. transpose
                for global_rank in global_ranks
            ]
            for scatter_dim in range(len(tensor_shape)):
                output_tensor_list = [torch.empty_like(t) for t in input_tensor_list]
                # scatter on dim > 0 would generate non-contiguous tensor, verify that works
                mesh_all_to_all(
                    output_tensor_list, input_tensor_list, mesh, mesh_dim=dim
                )
                output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
                expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
                self.assertEqual(output_tensor, expected_tensor)


if __name__ == "__main__":
    # Delegate to PyTorch's internal test runner.
    run_tests()

# Cookie usage notice (captured from the hosting web page, translated from Russian):
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
# By clicking "Accept", you consent to SberTech JSC processing your personal data
# in order to improve our website and the GitVerse service and to make them more convenient to use.
# You can disable cookies yourself in your browser settings.