# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._collective_utils import (
    mesh_all_to_all,
    mesh_broadcast,
    mesh_scatter,
)
from torch.distributed._tensor.placement_types import _Partial, Shard
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh

from torch.distributed.distributed_c10d import (
    get_global_rank,
    get_world_size,
    init_process_group,
    is_initialized,
    is_nccl_available,
    ProcessGroup,
)
from torch.testing._internal.common_distributed import run_with_both_funcol_impls
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeStore


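# Test helper: use "cuda" only when NCCL is available and at least `world_size`
# GPUs are visible; otherwise fall back to "cpu".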
def _get_device_type(world_size):
    if (
        torch.cuda.is_available()
        and torch.cuda.device_count() >= world_size
        and is_nccl_available()
    ):
        device_type = "cuda"
    else:
        device_type = "cpu"
    return device_type


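# Test helper: populate the environment variables that init_process_group()
# reads for its default env:// rendezvous.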
def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0):
    os.environ["MASTER_ADDR"] = addr
    os.environ["MASTER_PORT"] = port
    os.environ["WORLD_SIZE"] = f"{world_size}"
    os.environ["RANK"] = f"{rank}"


@instantiate_parametrized_tests
class DeviceMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 4

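    # Constructing a DeviceMesh without an existing process group should
    # implicitly initialize the default (world) process group.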
    @run_with_both_funcol_impls
    def test_init_process_group(self):
        device_type = _get_device_type(self.world_size)
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertTrue(not is_initialized())
        _set_env_var(world_size=self.world_size, rank=self.rank)
        DeviceMesh(device_type, mesh_tensor)
        self.assertTrue(is_initialized())
        self.destroy_pg()

    @with_comms
    @skip_unless_torch_gpu
    def test_assert_invalid_mesh_tensor(self):
        mesh = torch.arange(self.world_size).to(self.rank)
        with self.assertRaises(ValueError):
            device_mesh = DeviceMesh(self.device_type, mesh)

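    # get_group() with no argument returns one ProcessGroup per mesh dimension;
    # lookups by index and by mesh_dim_name must agree, as must groups obtained
    # from sliced sub-meshes.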
    @with_comms
    @run_with_both_funcol_impls
    def test_get_group(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]

        self.assertEqual(len(mesh_2d.get_group()), 2)
        self.assertEqual(mesh_2d.get_group()[0], mesh_2d.get_group("dp"))
        self.assertEqual(mesh_2d.get_group()[1], mesh_2d.get_group("tp"))

        self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp"))
        self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp"))

        self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group())
        self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group())

    @with_comms
    @run_with_both_funcol_impls
    def test_get_local_rank_raises_exception(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
        ):
            local_rank = mesh_2d.get_local_rank()

    @with_comms
    @run_with_both_funcol_impls
    def test_get_local_rank(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )
        self.assertEqual(mesh_2d.get_local_rank("dp"), mesh_2d.get_local_rank(0))
        self.assertEqual(mesh_2d.get_local_rank("tp"), mesh_2d.get_local_rank(1))

        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp"))
        self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp"))

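    # For the 2x2 mesh [[0, 1], [2, 3]], dim-0 groups are the columns
    # ([0, 2], [1, 3]) and dim-1 groups are the rows ([0, 1], [2, 3]).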
    @with_comms
    @run_with_both_funcol_impls
    def test_device_mesh_2d(self):
        mesh_tensor = torch.arange(4).reshape(2, 2)
        # construct a cuda device mesh
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_group()

        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < 2)
            dim_ranks = expected_ranks_by_dim[dim]

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            current_rank_expected_group_ranks = (
                dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
            )
            self.assertEqual(global_ranks, current_rank_expected_group_ranks)

    @run_with_both_funcol_impls
    def test_fake_pg_device_mesh(self):
        fake_store = FakeStore()
        init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        mesh = DeviceMesh(device_type, torch.arange(self.world_size))

        local_tensor = torch.randn(2, 8)
        global_tensor = funcol.all_gather_tensor(
            local_tensor, gather_dim=0, group=(mesh, 0)
        )
        self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))


class DeviceMeshTestNDim(DTensorTestBase):
    @property
    def world_size(self):
        return 8

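    # Each dim group of the 2x2x2 mesh should contain exactly the ranks obtained
    # by slicing the mesh tensor along that dimension.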
    @with_comms
    @run_with_both_funcol_impls
    def test_device_mesh_nd(self):
        # construct a cuda device mesh
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_group()

        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < mesh_tensor.ndim)
            dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2)

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            for ranks in dim_ranks:
                if self.rank in ranks:
                    self.assertEqual(global_ranks, ranks.tolist())

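    # Distinct DeviceMesh objects must hash differently, even when built from an
    # identical mesh tensor.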
    @with_comms
    @run_with_both_funcol_impls
    def test_device_mesh_hash(self):
        mesh_tensor_2d = torch.arange(8).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor_2d)
        mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d)
        self.assertNotEqual(hash(mesh), hash(mesh2))
        mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2)
        mesh3 = DeviceMesh(self.device_type, mesh_tensor_3d)
        self.assertNotEqual(hash(mesh), hash(mesh3))
        self.assertNotEqual(hash(mesh2), hash(mesh3))


class InitDeviceMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

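    # init_device_mesh() should produce a mesh equal to a directly constructed
    # DeviceMesh, both with and without mesh_dim_names.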
    @with_comms
    @run_with_both_funcol_impls
    def test_init_device_mesh(self):
        mesh_shape = (2, 4)
        ref_mesh = DeviceMesh(self.device_type, torch.arange(8).view(mesh_shape))

        # test init_device_mesh with mesh_dim_names
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        self.assertEqual(mesh_2d, ref_mesh)
        self.assertEqual(mesh_2d.mesh_dim_names, mesh_dim_names)

        # test init_device_mesh without mesh_dim_names
        mesh_2d = init_device_mesh(self.device_type, mesh_shape)
        self.assertEqual(mesh_2d, ref_mesh)

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_duplicate_mesh_dim_names(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Each mesh_dim_name must be unique.",
        ):
            mesh = init_device_mesh(
                self.device_type,
                (2, 4),
                mesh_dim_names=["dp", "dp"],
            )

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_mesh_shape_mesh_dim_names_mismatch(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "mesh_shape and mesh_dim_names should have same length!",
        ):
            mesh = init_device_mesh(
                self.device_type,
                (8,),
                mesh_dim_names=["dp", "tp"],
            )


@instantiate_parametrized_tests
class TestDeviceMeshGetItem(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_no_mesh_dim_found(self):
        with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."):
            mesh = init_device_mesh(self.device_type, (2, 4))
            child_mesh = mesh["DP"]

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_invalid_mesh_dim_name(self):
        child_mesh_dim_name = "PP"
        with self.assertRaisesRegex(
            KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist."
        ):
            mesh_dim_names = ("DP", "TP")
            mesh = init_device_mesh(
                self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
            )
            child_mesh = mesh[child_mesh_dim_name]

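    # Slicing mesh_2d["TP"] / mesh_2d["DP"] should return the 1D sub-mesh holding
    # exactly this rank's TP / DP group ranks.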
    @with_comms
    @run_with_both_funcol_impls
    def test_get_item(self):
        mesh_shape = (2, 4)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        pg_ranks_by_dim_name = {}
        for mesh_dim_name in mesh_dim_names:
            mesh_dim = mesh_dim_names.index(mesh_dim_name)
            pg_ranks_by_dim_name[mesh_dim_name] = mesh_2d.mesh.swapdims(
                -1, mesh_dim
            ).reshape(-1, mesh_2d.mesh.size(mesh_dim))

        tp_mesh = mesh_2d["TP"]
        tp_group_idx = self.rank // 4
        self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])

        dp_mesh = mesh_2d["DP"]
        dp_group_idx = self.rank % 4
        self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx])

    @with_comms
    @run_with_both_funcol_impls
    def test_get_item_1d(self):
        mesh = init_device_mesh(self.device_type, (8,), mesh_dim_names=("dp",))
        # Make sure slicing a 1D mesh out of a 1D mesh works; the mesh itself is
        # returned and no parent mesh is recorded.
        dp_mesh = mesh["dp"]
        self.assertEqual(dp_mesh, mesh)

        with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"):
            dp_mesh = mesh["dim0"]


@instantiate_parametrized_tests
class TestMeshEnv(DTensorTestBase):
    @with_comms
    @run_with_both_funcol_impls
    def test_get_parent_mesh(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["DP"]), mesh_2d)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["TP"]), mesh_2d)

        mesh_0_2 = DeviceMesh(self.device_type, [0, 2])
        mesh_1_3 = DeviceMesh(self.device_type, [1, 3])

        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["DP"]), mesh_2d)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["TP"]), mesh_2d)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_0_2), None)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_1_3), None)

    @with_comms
    @run_with_both_funcol_impls
    def test_get_parent_mesh_dim_exist(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh_2d["DP"]), 0)
        self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh_2d["TP"]), 1)

    @with_comms
    @run_with_both_funcol_impls
    def test_get_parent_mesh_dim_not_exist(self):
        mesh_shape = (self.world_size,)
        mesh = init_device_mesh(self.device_type, mesh_shape)

        self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh), None)

    @with_comms
    @run_with_both_funcol_impls
    def test_get_mesh_dim_by_name(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "DP"), 0)
        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "TP"), 1)


@instantiate_parametrized_tests
class DeviceMeshCollectiveTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

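    # After broadcasting over the 1D mesh, every rank should hold rank 0's tensor
    # (all zeros here, since each local tensor is ones * rank).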
    @with_comms
    @run_with_both_funcol_impls
    def test_broadcast_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
        mesh_broadcast(local_tensor, mesh, mesh_dim=0)
        self.assertEqual(local_tensor, torch.zeros(3, 3))

    @with_comms
    @run_with_both_funcol_impls
    def test_scatter_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        scatter_tensor_shape = [3, 3, 3]
        for scatter_dim in range(len(scatter_tensor_shape)):
            shard_placement = Shard(scatter_dim)
            scatter_tensor_shape[scatter_dim] *= self.world_size
            # make the random seed the same across ranks
            torch.manual_seed(0)
            global_tensor = torch.randn(scatter_tensor_shape, device=self.device_type)
            splitted_list, _ = shard_placement._split_tensor(
                global_tensor, mesh.size(), with_padding=True, contiguous=True
            )
            recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()])
            # scatter on dim > 0 would generate non-contiguous tensor, verify that works
            mesh_scatter(recv_tensor, splitted_list, mesh, mesh_dim=0)
            self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()])

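    # Scatter a tensor whose shard dim is not evenly divisible by world_size:
    # shards are padded for the collective, unpadded afterwards, and must match
    # torch.chunk's uneven split.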
    @with_comms
    @run_with_both_funcol_impls
    def test_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = torch.randn(
            device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)

            tensor_to_scatter = tensor_to_split.clone()
            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            scattered_tensor = torch.empty_like(padded_tensor_list[my_rank])
            mesh_scatter(scattered_tensor, padded_tensor_list, device_mesh, mesh_dim=0)

            if pad_sizes[my_rank] != 0:
                scattered_tensor = shard_placement._unpad_tensor(
                    scattered_tensor, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # We need to check numel() instead of size if a tensor is ([]) after unpadding,
                # since the size could be ([0, 8]) after unpadding.
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank])

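    # All-gather the padded shards, strip the per-shard padding, and re-concatenate;
    # the result must reproduce the original tensor.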
    @with_comms
    @run_with_both_funcol_impls
    def test_all_gather_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = torch.ones(
            device_mesh.size() + 3,
            device_mesh.size() + 1,
            device=self.device_type,
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_padded_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_split,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )
            local_tensor = tensor_padded_list[my_rank]
            big_tensor = funcol.all_gather_tensor(
                local_tensor, gather_dim=shard_dim, group=(device_mesh, 0)
            )
            big_tensor_chunks = list(
                torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim)
            )
            unpadded_list = [
                shard_placement._unpad_tensor(big_tensor_chunks[i], pad_sizes[i])
                if pad_sizes[i] > 0
                else big_tensor_chunks[i]
                for i, big_tensor in enumerate(big_tensor_chunks)
            ]
            all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim)

            self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size())
            self.assertEqual(all_gathered_tensor, tensor_to_split)

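    # Redistributing a DTensor from _Partial to Shard(0) triggers a reduce_scatter;
    # a non-contiguous (strided-slice) input must produce the same result as its
    # contiguous clone.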
    @with_comms
    @run_with_both_funcol_impls
    def test_reduce_scatter_contiguous(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()

        # Init the tensor
        step = self.world_size * 2
        total_elem = step**2
        tensor = torch.arange(0, total_elem).view(step, -1).to(device=self.device_type)
        tensor = tensor * (my_rank + 1)

        # Get non-contiguous tensor by slicing
        tensor_to_reduce = tensor[::2, :2]
        tensor_contiguous = tensor_to_reduce.clone().contiguous()

        # Partial to Shard to trigger reduce_scatter
        tensor_to_reduce = DTensor.from_local(
            tensor_to_reduce, device_mesh, [_Partial()]
        )
        tensor_contiguous = DTensor.from_local(
            tensor_contiguous, device_mesh, [_Partial()]
        )
        new_tensor = tensor_to_reduce.redistribute(device_mesh, [Shard(0)])
        new_tensor_contiguous = tensor_contiguous.redistribute(device_mesh, [Shard(0)])

        # Contiguous and non-contiguous inputs holding the same values should
        # produce the same reduce_scatter output.
        new_tensor_local = new_tensor._local_tensor
        new_tensor_contiguous_local = new_tensor_contiguous._local_tensor
        self.assertEqual(new_tensor_local, new_tensor_contiguous_local)
        self.assertEqual(list(new_tensor_local.size()), [1, 2])

        # Check the reduce numerical value
        sum_base = (1 + self.world_size) * self.world_size / 2
        first_elem = my_rank * sum_base * step * 2
        expected_tensor = torch.tensor(
            [[first_elem, first_elem + sum_base]],
            dtype=new_tensor_local.dtype,
            device=self.device_type,
        )
        self.assertEqual(new_tensor_local, expected_tensor)

    @with_comms
    @run_with_both_funcol_impls
    def test_reduce_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = (
            torch.ones(
                device_mesh.size() + 3,
                device_mesh.size() + 1,
                device=self.device_type,
            )
            * self.rank
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_to_scatter = tensor_to_split.clone()

            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            tensor_to_reduce = torch.cat(padded_tensor_list, shard_dim)

            res_num = ((0 + self.world_size - 1) * self.world_size) / 2

            scattered_tensor = funcol.reduce_scatter_tensor(
                tensor_to_reduce,
                reduceOp="sum",
                scatter_dim=shard_dim,
                group=(device_mesh, 0),
            )

            # unpad scattered_tensor
            if pad_sizes[my_rank] > 0:
                scattered_tensor = shard_placement._unpad_tensor(
                    scattered_tensor, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # We need to check numel() instead of size if a tensor is ([]) after unpadding,
                # since the size could be ([0, 8]) after unpadding.
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(
                    scattered_tensor,
                    torch.ones_like(tensor_splitted_list[my_rank]) * res_num,
                )

    @with_comms
    @run_with_both_funcol_impls
    def test_broadcast_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank

        # check all dim groups
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            cloned_local_tensor = local_tensor.clone()
            mesh_broadcast(cloned_local_tensor, mesh, mesh_dim=dim)
            res_num = global_ranks[0]
            self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)

    @with_comms
    @run_with_both_funcol_impls
    def test_scatter_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            scattered_tensors = [
                torch.ones(3, 3, device=self.device_type) * global_rank
                for global_rank in global_ranks
            ]
            received_tensor = torch.empty_like(
                scattered_tensors[mesh.get_coordinate()[dim]]
            )
            mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim)
            self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)

    @with_comms
    @run_with_both_funcol_impls
    def test_all_to_all_1d(self):
        # transpose on a 2D tensor distributed over N nodes:
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        tensor_shape = [3, 3]
        input_tensor_list = [
            torch.ones(*tensor_shape, device=self.device_type)
            * (rank + self.rank * self.world_size)
            for rank in range(self.world_size)
        ]
        expected_tensor_list = [
            torch.ones(tensor_shape, device=self.device_type)
            * (self.rank + rank * self.world_size)  # i.e. transpose
            for rank in range(self.world_size)
        ]
        for scatter_dim in range(len(tensor_shape)):
            output_tensor_list = [
                torch.empty_like(input_tensor_list[idx])
                for idx in range(len(input_tensor_list))
            ]
            # scatter on dim > 0 would generate non-contiguous tensor, verify that works
            mesh_all_to_all(output_tensor_list, input_tensor_list, mesh, mesh_dim=0)
            output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
            expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)

            self.assertEqual(output_tensor, expected_tensor)

    @with_comms
    @run_with_both_funcol_impls
    def test_all_to_all_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        tensor_shape = [3, 3, 3]
        # check all dim groups
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            my_coordinate = mesh.get_coordinate()[dim]
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            input_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (i + self.rank * dim_group_size)
                for i in range(dim_group_size)
            ]
            expected_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (my_coordinate + global_rank * dim_group_size)  # i.e. transpose
                for global_rank in global_ranks
            ]
            for scatter_dim in range(len(tensor_shape)):
                # input_tensor = torch.cat(input_tensor_list, dim=scatter_dim)
                output_tensor_list = [
                    torch.empty_like(input_tensor_list[idx])
                    for idx in range(len(input_tensor_list))
                ]
                # scatter on dim > 0 would generate non-contiguous tensor, verify that works
                mesh_all_to_all(
                    output_tensor_list, input_tensor_list, mesh, mesh_dim=dim
                )
                output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
                expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
                self.assertEqual(output_tensor, expected_tensor)


if __name__ == "__main__":
    run_tests()