pytorch

Форк
0
/
test_shard_utils.py 
76 строк · 2.3 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import torch
4

5
from torch.distributed.distributed_c10d import _get_default_group
6
from torch.distributed.fsdp._shard_utils import (
7
    _create_chunk_dtensor,
8
    _create_chunk_sharded_tensor,
9
)
10
from torch.testing._internal.common_fsdp import FSDPTest
11
from torch.testing._internal.common_utils import run_tests
12
from torch.testing._internal.distributed._tensor.common_dtensor import (
13
    DTensorTestBase,
14
    skip_if_lt_x_gpu,
15
    with_comms,
16
)
17

18

19
class TestShardUtilsDistributed(FSDPTest):
20
    @property
21
    def world_size(self):
22
        return 2
23

24
    def _create_tensor(self, *size):
25
        # Keep everything deterministic.
26
        torch.manual_seed(0)
27
        return torch.rand(*size).cuda()
28

29
    @skip_if_lt_x_gpu(2)
30
    def test_create_chunk_sharded_tensor(self):
31
        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
32
            tensor = self._create_tensor(*size)
33

34
            sharded_tensor = _create_chunk_sharded_tensor(
35
                tensor,
36
                self.rank,
37
                self.world_size,
38
                torch.cuda.device_count(),
39
                _get_default_group(),
40
            )
41
            output = torch.empty(*size).cuda() if self.rank == 0 else None
42
            sharded_tensor.gather(0, output)
43
            if self.rank == 0:
44
                self.assertEqual(tensor, output)
45

46

47
class TestShardUtilsDistributedDTensor(DTensorTestBase):
48
    @property
49
    def world_size(self):
50
        return 2
51

52
    def _create_tensor(self, *size):
53
        # Keep everything deterministic.
54
        torch.manual_seed(0)
55
        return torch.rand(*size).cuda()
56

57
    @with_comms
58
    @skip_if_lt_x_gpu(2)
59
    def test_create_chunk_dtensor(self):
60
        device_mesh = self.build_device_mesh()
61

62
        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
63
            tensor = self._create_tensor(*size)
64
            tensor_chunks = torch.chunk(tensor, self.world_size, dim=0)
65

66
            dtensor = _create_chunk_dtensor(tensor, self.rank, device_mesh)
67
            local_tensor = dtensor.to_local()
68

69
            if local_tensor.numel() != 0:
70
                self.assertEqual(local_tensor, tensor_chunks[self.rank])
71
            else:
72
                self.assertEqual(self.rank >= len(tensor_chunks), True)
73

74

75
if __name__ == "__main__":
76
    run_tests()
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.