# Owner(s): ["oncall: distributed"]

import os
import sys
import unittest
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed._tensor as dt
import torch.distributed.distributed_c10d as c10d
from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    requires_nccl,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)

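# Helper: split the world into consecutive subgroups of `group_size` ranks, optionally
# tagging each one, and return (the subgroup containing this rank, all subgroups).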
def new_subgroups(group_size: int, pg_tag=None):
    world_size = dist.get_world_size()
    subgroups = []
    cur_subgroup = None

    for subgroup_id in range(world_size // group_size):
        start_rank = subgroup_id * group_size
        end_rank = start_rank + group_size
        ranks_in_subgroup = list(range(start_rank, end_rank))
        subgroup = c10d._new_group_with_tag(
            ranks=ranks_in_subgroup,
            pg_tag=pg_tag,
        )
        subgroups.append(subgroup)

        rank = dist.get_rank()
        if rank in ranks_in_subgroup:
            cur_subgroup = subgroup

    return cur_subgroup, subgroups

class TestExpand(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)

    def test_expand_2d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])

    def test_expand_process_group(self):
        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        my_pg, others = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)

    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        with self.assertRaisesRegex(AssertionError, "Only 1D mesh"):
            tag, rankset, group_size = ft_c._expand_group(mesh)

        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        expected_rankset = [0, 2] if dist.get_rank() in [0, 2] else [1, 3]
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        expected_rankset = [0, 1] if dist.get_rank() in [0, 1] else [2, 3]
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=1)), tag)
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)

class TestPgTag(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

158"""159The behavior we want is as follow:
160
161- rankset+tag will always result in the same PG.
162Do we enforce this by failing creation of new PGs or returning existing ones?
163Return existing one.
164
165- default tag gives existing behavior.
166This means we should create duplicates.
167- _expand_group on _default-tagged pg should always resolve to it
168This mean we can't depend on empty tag + rankset.
169"""
170
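    # For example, creating a group twice with the same ranks and the same tag hands back
    # the existing PG (see test_pg_creation_with_tag below), while untagged creation
    # always makes a fresh group.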
    def test_pg_creation_with_tag(self):
        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(my_group, my_group2)

        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(my_group, my_group3)

        my_group4, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group, my_group4)

        my_group5, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group4, my_group5)

    def test_pg_lookup_roundtrip(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
        pg_notag0, _ = new_subgroups(group_size=2)
        pg_notag1, _ = new_subgroups(group_size=2)

        def roundtrip(pg):
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
        self.assertEqual(pg_notag1, roundtrip(pg_notag1))

    def test_pg_lookup_with_tag(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
        pg_notag0, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))

    def test_find_or_create_pg(self):
        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(pg, pg_tag0)

    def test_find_root_pg(self):
        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, pg)

@instantiate_parametrized_tests
class TestTraceableCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @parametrize("device", ["cpu", "cuda"])
    def test_broadcast(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        if dist.get_rank() == 0:
            tensor = torch.ones([4], device=device)
        else:
            tensor = torch.zeros([4], device=device)

        mesh = dt.DeviceMesh(device, torch.arange(4))
        res = ft_c.broadcast(tensor, 0, mesh)
        self.assertEqual(res, torch.ones([4], device=device))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensor = torch.ones([4], device=device)
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce(tensor, "sum", mesh)
        self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))

        mesh = dt.DeviceMesh(device, torch.arange(4).view(2, 2))
        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
        self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_coalesced_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        t0 = torch.ones([4], device=device)
        t1 = torch.ones([6], device=device) + 2
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce_coalesced([t0, t1], "sum", mesh)
        self.assertEqual(res[0], t0 * 4)
        self.assertEqual(res[1], t1 * 4)

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_gather = [0, 1, 2]
            for dim in dims_to_gather:
                output_size = [3, 3, 3]
                output_size[dim] *= mesh.size(0)
                # each rank has its own tensor, all_gather gives a bigger tensor
                local_tensor = torch.ones([3, 3, 3], device=device)
                gathered_tensor = ft_c.all_gather_tensor(
                    local_tensor, gather_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(gathered_tensor, torch.ones(output_size))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()], device=device), res[0])
        self.assertEqual(
            torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1]
        )

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_scatter = [0, 1]
            for dim in dims_to_scatter:
                group_size = mesh.size(0)
                input_size = [3, 3]
                output_size = [3, 3]
                output_size[dim] *= group_size
                input_tensor = torch.ones(output_size, device=device)
                res_num = 1 * group_size
                rs_tensor = ft_c.reduce_scatter_tensor(
                    input_tensor, "sum", scatter_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())
        tensors = [
            torch.ones([4], dtype=torch.int64, device=device),
            torch.ones([4], dtype=torch.int64, device=device) + 1,
        ]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.reduce_scatter_tensor_coalesced(tensors, "sum", [0, 0], mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.tensor([4], device=device), res[0])
        self.assertEqual(torch.tensor([8], device=device), res[1])

class TestMetaCollectives(TestCase):
    def test_all_reduce(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = ft_c.all_reduce(x, "sum", "0")
        self.assertEqual(x.size(), out.size())

class TestGradCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce(self):
        x = torch.rand([4], requires_grad=True)
        y = torch.rand([4], requires_grad=True)
        out = ft_c.all_reduce(x, "sum", dist.group.WORLD)
        (out + y).sum().backward()
        self.assertIsNone(x.grad)

class TestMakeFx(TestCase):
    def setUp(self):
        # make_fx is not thread-safe due to patching and mutating global state,
        # so create a fake_pg.
        self.rank = 0
        self.world_size = 2
        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        super().tearDown()

        self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing())

    def test_all_reduce_tracing(self):
        def allred(input):
            return ft_c.all_reduce(input, "sum", group=dist.group.WORLD) + 1

        graph = make_fx(allred)(torch.rand(4))
        FileCheck().check("all_reduce").check("wait_tensor").run(str(graph.graph))

        mesh = dt.DeviceMesh("cpu", torch.arange(self.world_size))

        def allred_mesh(input):
            return ft_c.all_reduce(input, "sum", mesh) + 1

        mesh_graph = make_fx(allred_mesh)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_graph.graph)
        )

        def allred_mesh_dim(input):
            return ft_c.all_reduce(input, "sum", (mesh, 0)) + 1

        mesh_dim_graph = make_fx(allred_mesh_dim)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_dim_graph.graph)
        )

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = 2

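# Exit with the standard multi-GPU skip code when fewer than `x` CUDA devices are available.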
def exit_if_lt_x_gpu(x):
    if torch.cuda.device_count() < x:
        sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

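# Decorator: initialize the default process group (honoring the BACKEND env var) before
# the test body runs, skip when NCCL is requested but there are not enough GPUs, and tear
# the process group down afterwards.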
def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        global BACKEND

        if "BACKEND" in os.environ:
            BACKEND = os.environ["BACKEND"]
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()

    return wrapper

class TestCollectivesWithNCCL(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        BACKEND = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_gather_into_tensor_coalesced(self):
        exit_if_lt_x_gpu(self.world_size)

        tensors = [
            torch.ones([4], device=f"cuda:{self.rank}"),
            torch.ones([4], device=f"cuda:{self.rank}") + 1,
        ]
        mesh = dt.DeviceMesh(f"cuda:{self.rank}", torch.arange(self.world_size))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()]), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])

    @with_comms()
    def test_all_to_all_single(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

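        # Rank r sends (i + 1) * (r + 1) rows to peer i, so this rank's total row count is
        # (rank + 1) * (1 + 2 + ... + world_size) = (rank + 1) * world_size * (world_size + 1) / 2.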
        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), 5, device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_1d_input(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_split_sizes_none(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        x = torch.ones(self.world_size, self.world_size, device=device) * (rank + 1)
        y = ft_c.all_to_all_single(
            x, output_split_sizes=None, input_split_sizes=None, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.chunk(x, self.world_size)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing(self):
        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        compiled_allreduce(torch.randn(8, device=self.device), self.process_group)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_tracing_with_fakepg(self):
        exit_if_lt_x_gpu(self.world_size)

        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        dist.init_process_group(
            backend="fake",
            rank=0,
            world_size=8,
            store=FakeStore(),
        )
        allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing_with_dce_code(self):
        if self.world_size > 2:
            return

        def func(batch, group, rank):
            ret = ft_c.permute_tensor(batch, [1, 0], group)
            if hasattr(ret, "wait"):
                ret = ret.wait()
            if rank == 0:
                return ret
            else:
                return batch * 5

        compiled_func = torch.compile(func)
        ret = compiled_func(
            torch.ones((100,), device="cuda"), self.process_group, self.rank
        )
        dist.barrier()

class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
    @property
    def world_size(self):
        return 4

    @requires_nccl()
    @with_comms()
    def test_permute_tensor_with_sub_group(self):
        exit_if_lt_x_gpu(self.world_size)

        device = "cuda"
        mesh_dim_names = ["dp", "tp"]

        mesh_2d = dt.init_device_mesh(
            device, (2, self.world_size // 2), mesh_dim_names=mesh_dim_names
        )

        for mesh_name in mesh_dim_names:
            mesh = mesh_2d[mesh_name]
            rank = mesh.get_local_rank()

            # rank0: [0., 1.], rank1: [2., 3.]
            send_tensor = torch.arange(2, dtype=torch.float32, device=device) + 2 * rank
            recvd_tensor = ft_c.permute_tensor(send_tensor, [1, 0], group=mesh)

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device=device) + 2 * (
                (rank - 1 + 2) % 2
            )
            self.assertEqual(
                recvd_tensor,
                expected,
                msg=f"Expected {expected} on {self.rank=} (local_rank={rank}), "
                f"but received {recvd_tensor} instead.",
            )

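# Tests for the *_autograd variants of the functional collectives, run eagerly and under
# torch.compile on a multi-threaded process group.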
@instantiate_parametrized_tests
class TestFunctionalAutograd(MultiThreadedTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @property
    def world_size(self):
        return 2

    @parametrize("compile", [True, False])
    def test_all_to_all_single(self, compile: bool = True) -> None:
        group = dist.group.WORLD.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 2
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        out = compiled(t, self.world_size)
        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))

    def test_all_to_all_single_inductor(self) -> None:
        group = dist.group.WORLD.group_name

        t = torch.rand((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 10
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 2
            return out.sum()

        compiled = torch.compile(my_func, fullgraph=True)

        def run_with_backward():
            out = compiled(t, self.world_size)
            out.backward()

        res, codes = run_and_get_code(run_with_backward)
        for code in codes:
            FileCheck().check_count(
                "_c10d_functional.all_to_all_single.default", 1, exactly=True
            ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
                code
            )

        self.assertIsNotNone(t.grad)

    @parametrize("compile", [True, False])
    def test_all_gather_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            out = ft_c.all_gather_tensor_autograd(
                t * 1.0,
                gather_dim=dim,
                group=group,
            )
            out = out * 1.0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_gather = [0, 1, 2]
        for dim in dims_to_gather:
            output_size = [3, 3, 3]
            output_size[dim] *= self.world_size
            # each rank has its own tensor, all_gather gives a bigger tensor
            local_tensor = torch.ones([3, 3, 3], requires_grad=True)
            gathered_tensor = compiled(local_tensor, dim)
            self.assertEqual(gathered_tensor, torch.ones(output_size))

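            # Every rank's gathered output contains this rank's shard, so summing the
            # output sends a gradient of 1 from each of the world_size ranks back to the
            # local tensor, hence the expected grad of world_size below.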
            gathered_tensor.sum().backward()
            self.assertEqual(
                local_tensor.grad,
                torch.full((3, 3, 3), fill_value=float(self.world_size)),
            )

    @parametrize("compile", [True, False])
    def test_reduce_scatter_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            rs_tensor = (
                ft_c.reduce_scatter_tensor_autograd(
                    input_tensor * 1.0, "sum", scatter_dim=dim, group=group
                )
                * 1.0
            )
            return rs_tensor

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

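        # Sum-reduce-scatter of all-ones inputs yields group_size in every output element;
        # summing that output sends a gradient of 1 back to every input element, which is
        # what the grad check below asserts.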
        dims_to_scatter = [0, 1]
        for dim in dims_to_scatter:
            group_size = self.world_size
            input_size = [3, 3]
            output_size = [3, 3]
            output_size[dim] *= group_size
            input_tensor = torch.ones(output_size, requires_grad=True)
            rs_tensor = compiled(input_tensor, dim)
            res_num = 1 * group_size
            self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
            rs_tensor.sum().backward()
            self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))

class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_to_all_single(self) -> None:
        group = self.process_group.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)

        sizes = [1] * self.world_size
        assert t.requires_grad
        out = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group) + 0

        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))

if __name__ == "__main__":
    run_tests()