# Owner(s): ["module: c10d"]
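# Tests for CUDA SymmetricMemory: P2P buffer allocation and rendezvous, fused
# all-gather/reduce-scatter matmuls, low-contention collectives, and multimem
# all-reduce variants.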
import torch
import torch.distributed as dist
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed._symmetric_memory import (
    _fused_all_gather_matmul_fallback,
    _fused_all_gather_scaled_matmul_fallback,
    _fused_matmul_reduce_scatter_fallback,
    _fused_scaled_matmul_reduce_scatter_fallback,
    enable_symm_mem_for_group,
    restride_A_for_fused_matmul_reduce_scatter,
    restride_A_shard_for_fused_all_gather_matmul,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)

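# Decorator factory: skip (but report as passing in Sandcastle) unless CUDA is
# available on SM 8.0+, at least two GPUs are present, and every GPU pair has
# peer-to-peer access.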
def requires_cuda_p2p_access():
    cuda_p2p_access_available = (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (8, 0)
        and torch.cuda.device_count() >= 2
    )
    num_devices = torch.cuda.device_count()
    for i in range(num_devices - 1):
        for j in range(i + 1, num_devices):
            if not torch.cuda.can_device_access_peer(i, j):
                cuda_p2p_access_available = False
                break
        if not cuda_p2p_access_available:
            break

    return skip_but_pass_in_sandcastle_if(
        not cuda_p2p_access_available,
        "cuda p2p access is not available",
    )

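# Decorator factory: skip unless the device reports multicast support, which
# the multimem all-reduce tests below rely on.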
def requires_multicast_support():
    has_multicast_support = (
        torch.cuda.is_available()
        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA)
    )
    return skip_but_pass_in_sandcastle_if(
        not has_multicast_support,
        "multicast support is not available",
    )

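# Each test runs in self.world_size spawned processes (one per rank); most
# tests call _init_process() to set up an NCCL process group and enable
# symmetric memory for the default group before exercising the op under test.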
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        enable_symm_mem_for_group(dist.group.WORLD.group_name)

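    # Smoke test for a SymmetricMemory handle: rank 1 fills the shared buffer
    # with 42 and signals rank 0, which waits on the signal and checks the
    # value; a second, barrier-synchronized round then updates it to 43.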
    def _verify_symmetric_memory(self, symm_mem):
        self.assertEqual(symm_mem.world_size, 2)

        buf = symm_mem.get_buffer(0, (64, 64), torch.float32)
        if symm_mem.rank == 0:
            symm_mem.wait_signal(src_rank=1)
            self.assertTrue(buf.eq(42).all())
        else:
            buf.fill_(42)
            symm_mem.put_signal(dst_rank=0)

        symm_mem.barrier()

        if symm_mem.rank == 0:
            symm_mem.barrier()
            self.assertTrue(buf.eq(43).all())
        else:
            buf.fill_(43)
            symm_mem.barrier()

        symm_mem.barrier()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_cuda_nvlink_connectivity_detection(self) -> None:
        from torch._C._distributed_c10d import _detect_dma_connectivity

        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
        self.assertEqual(connectivity.connection_type, "nvlink")
        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
        for row in connectivity.matrix:
            self.assertEqual(len(row), torch.cuda.device_count())

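    # rendezvous() returns None for a regular CUDA tensor; only tensors
    # allocated via empty_strided_p2p produce a SymmetricMemory handle.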
    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name)

        t = torch.empty(shape, dtype=dtype, device=device)
        self.assertIsNone(_SymmetricMemory.rendezvous(t))

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        symm_mem = _SymmetricMemory.rendezvous(t)

        del t
        self._verify_symmetric_memory(symm_mem)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p_persistent(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        alloc_id = 42  # Persistent allocation
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name, alloc_id)

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        data_ptr = t.data_ptr()

        # Verify that persistent allocation would fail if there's an active
        # allocation with the same alloc_id.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.empty_strided_p2p(*alloc_args)

        # Verify that persistent allocation succeeds once there is no active
        # allocation with the same alloc_id, and that the returned tensor
        # reuses the same data pointer.
        del t
        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        self.assertEqual(t.data_ptr(), data_ptr)

        # Verify that get_symmetric_memory would fail if called before
        # rendezvous.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.get_symmetric_memory(t)

        symm_mem_0 = _SymmetricMemory.rendezvous(t)
        symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
        self.assertEqual(id(symm_mem_0), id(symm_mem_1))

        self._verify_symmetric_memory(symm_mem_0)
        dist.destroy_process_group()

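    # The fused all-gather + matmul tests compare the symm_mem fused op against
    # the eager fallback on per-rank sharded inputs: the gathered output and
    # every matmul output must match in both value and stride.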
    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
        Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]

        ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )

        assert torch.allclose(ag_output_0, ag_output_1)
        assert ag_output_0.stride() == ag_output_1.stride()
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            assert torch.allclose(mm_output_0, mm_output_1)
            assert mm_output_0.stride() == mm_output_1.stride()

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_scaled_matmul(self, gather_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda").to(
            torch.float8_e4m3fn
        )
        A_scale = torch.tensor(0.1, device="cuda")
        Bs = [
            torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T for _ in range(3)
        ]
        B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)]
        out_dtypes = [None, torch.bfloat16, torch.float32]

        ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )

        self.assertTrue(
            torch.allclose(
                ag_output_0.to(torch.float32),
                ag_output_1.to(torch.float32),
            )
        )
        self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            self.assertTrue(
                torch.allclose(
                    mm_output_0.to(torch.float32), mm_output_1.to(torch.float32)
                )
            )
            self.assertEqual(mm_output_0.stride(), mm_output_1.stride())
            self.assertEqual(mm_output_0.dtype, mm_output_1.dtype)

        dist.destroy_process_group()

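    # The matmul + reduce-scatter tests mirror the all-gather ones: the fused
    # symm_mem op must match the fallback in value and stride, for both the
    # plain and the float8-scaled variants.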
    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device="cuda")
        B = torch.rand(K, N, device="cuda")

        output_0 = _fused_matmul_reduce_scatter_fallback(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )
        output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )

        assert torch.allclose(output_0, output_1)
        assert output_0.stride() == output_1.stride()

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_scaled_matmul_reduce_scatter(self, scatter_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn)
        A_scale = torch.tensor(0.1, device="cuda")
        B = torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T
        B_scale = torch.tensor(0.1, device="cuda")

        output_0 = _fused_scaled_matmul_reduce_scatter_fallback(
            A,
            B,
            A_scale,
            B_scale,
            "avg",
            scatter_dim,
            group.group_name,
            out_dtype=torch.bfloat16,
        )
        output_1 = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
            A,
            B,
            A_scale,
            B_scale,
            "avg",
            scatter_dim,
            group.group_name,
            out_dtype=torch.bfloat16,
        )

        assert torch.allclose(output_0, output_1)
        assert output_0.stride() == output_1.stride()

        dist.destroy_process_group()

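    # The restriding helpers must return a tensor with the same values whose
    # layout is contiguous once the gather/scatter dim is moved to the front.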
    @skipIfRocm
    @parametrize("dim", [0, 1, 2])
    def test_optimal_layout(self, dim: int) -> None:
        t = torch.rand(8, 64, 32, 16)

        x = restride_A_shard_for_fused_all_gather_matmul(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

        x = restride_A_for_fused_matmul_reduce_scatter(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

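    # The low-contention collectives accept both symmetric-memory and regular
    # tensors as input; symm_mem_input parametrizes over the two paths.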
    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
        self._init_process()

        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name="0",
            ).fill_(self.rank)
        else:
            t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)

        res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 * self.world_size, 64))

        chunks = res.chunk(self.world_size)
        for r in range(self.world_size):
            self.assertTrue(chunks[r].eq(r).all())

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("reduce_op", ["sum", "avg"])
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_reduce_scatter(
        self, reduce_op: str, symm_mem_input: bool
    ) -> None:
        self._init_process()

        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name="0",
            )
        else:
            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)

        chunks = t.chunk(self.world_size)
        for r in range(self.world_size):
            chunks[r].fill_(r)

        res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 // self.world_size, 64))

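        # Every rank fills chunk r with the value r, so after reduce-scatter
        # rank r holds r reduced across world_size ranks: r * world_size for
        # "sum", and just r for "avg".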
        if reduce_op == "sum":
            expect = self.rank * self.world_size
        elif reduce_op == "avg":
            expect = self.rank
        else:
            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
        self.assertTrue(res.eq(expect).all())

        dist.destroy_process_group()

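    # The multimem all-reduce tests slice a window out of a larger symmetric
    # buffer so that different offsets (align_bytes) and lengths (size_bytes)
    # are exercised; for the in-place variant, data outside the window must
    # remain untouched.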
    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        t = _SymmetricMemory.empty_strided_p2p(
            size=(16384,),
            stride=(1,),
            dtype=dtype,
            device=self.device,
            group_name=group_name,
        ).fill_(1)

        self.assertTrue(t.data_ptr() % 16 == 0)
        self.assertTrue(align_bytes % t.element_size() == 0)
        self.assertTrue(size_bytes % t.element_size() == 0)

        shift = align_bytes // t.element_size()
        numel = size_bytes // t.element_size()
        x = t[shift : shift + numel]

        torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name)
        self.assertTrue(x.eq(self.world_size).all().item())

        # Head and tail should not be written
        self.assertTrue(t[:shift].eq(1).all().item())
        self.assertTrue(t[shift + numel :].eq(1).all().item())
        dist.destroy_process_group()

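    # The one-shot variant returns the reduction as a separate output tensor,
    # which is checked against world_size.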
    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_one_shot_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        t = _SymmetricMemory.empty_strided_p2p(
            size=(16384,),
            stride=(1,),
            dtype=dtype,
            device=self.device,
            group_name=group_name,
        ).fill_(0)

        self.assertTrue(t.data_ptr() % 16 == 0)
        self.assertTrue(align_bytes % t.element_size() == 0)
        self.assertTrue(size_bytes % t.element_size() == 0)

        shift = align_bytes // t.element_size()
        numel = size_bytes // t.element_size()
        x = t[shift : shift + numel]
        x.fill_(1)

        res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name)
        self.assertTrue(res.eq(self.world_size).all().item())
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()