# Owner(s): ["oncall: distributed"]

import torch

from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import (
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
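

# Both test classes below exercise private helpers from
# torch.distributed.fsdp._shard_utils. Judging by the assertions:
# _create_chunk_sharded_tensor wraps a full tensor as a ShardedTensor
# whose shards gather back to the original, and _create_chunk_dtensor
# wraps it as a DTensor whose local shard is the rank's torch.chunk()
# slice along dim 0.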
class TestShardUtilsDistributed(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _create_tensor(self, *size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).cuda()

    @skip_if_lt_x_gpu(2)
    def test_create_chunk_sharded_tensor(self):
        # Sizes cover chunks that divide evenly across ranks, leave a
        # remainder, and leave some ranks with no data (e.g. (1,) on 2 ranks).
        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)

            sharded_tensor = _create_chunk_sharded_tensor(
                tensor,
                self.rank,
                self.world_size,
                torch.cuda.device_count(),
                _get_default_group(),
            )
            # Gather every shard onto rank 0 and check that the result
            # round-trips back to the original tensor.
            output = torch.empty(*size).cuda() if self.rank == 0 else None
            sharded_tensor.gather(0, output)
            if self.rank == 0:
                self.assertEqual(tensor, output)
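

# The DTensor variant below checks the same chunking contract through
# DTensor.to_local(): each rank should see its torch.chunk() slice along
# dim 0, and ranks past the last chunk should see an empty local tensor.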
class TestShardUtilsDistributedDTensor(DTensorTestBase):
    @property
    def world_size(self):
        return 2

    def _create_tensor(self, *size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).cuda()

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_create_chunk_dtensor(self):
        device_mesh = self.build_device_mesh()

        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)
            tensor_chunks = torch.chunk(tensor, self.world_size, dim=0)

            dtensor = _create_chunk_dtensor(tensor, self.rank, device_mesh)
            local_tensor = dtensor.to_local()

            if local_tensor.numel() != 0:
                # This rank owns a chunk: it must match torch.chunk's split.
                self.assertEqual(local_tensor, tensor_chunks[self.rank])
            else:
                # An empty local shard is only legal for ranks past the
                # last chunk.
                self.assertEqual(self.rank >= len(tensor_chunks), True)
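

# A note on invocation (an assumption based on PyTorch's usual
# MultiProcessTestCase conventions, not stated in this file): running the
# file directly with `python` on a host with at least 2 GPUs is enough;
# FSDPTest and DTensorTestBase spawn one process per rank themselves.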


if __name__ == "__main__":
    run_tests()