pytorch

Форк
0
127 строк · 4.3 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import sys
4

5
import torch
6

7
from torch.distributed._shard.sharded_tensor import (
8
    Shard,
9
    ShardedTensor,
10
    ShardedTensorMetadata,
11
    ShardMetadata,
12
)
13
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
14
from torch.distributed.checkpoint.metadata import MetadataIndex
15
from torch.distributed.checkpoint.utils import find_state_dict_object
16

17
from torch.testing._internal.common_utils import (
18
    run_tests,
19
    TEST_WITH_DEV_DBG_ASAN,
20
    TestCase,
21
)
22
from torch.testing._internal.distributed.distributed_utils import with_fake_comms
23

24
# Skip this whole module on dev-asan builds; the message printed below gives
# the reason (known issues with torch + multiprocessing spawn).
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    # Exit with status 0 before any test in this module is defined or run.
    sys.exit(0)
30

31

32
def create_sharded_tensor(rank, world_size, shards_per_rank):
33
    shards_metadata = []
34
    local_shards = []
35
    for idx in range(0, world_size * shards_per_rank):
36
        shard_rank = idx // shards_per_rank
37
        shard_md = ShardMetadata(
38
            shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu"
39
        )
40
        shards_metadata.append(shard_md)
41
        if shard_rank == rank:
42
            shard = Shard.from_tensor_and_offsets(
43
                torch.rand(*shard_md.shard_sizes),
44
                shard_offsets=shard_md.shard_offsets,
45
                rank=rank,
46
            )
47
            local_shards.append(shard)
48

49
    sharded_tensor_md = ShardedTensorMetadata(
50
        shards_metadata=shards_metadata,
51
        size=torch.Size([8 * len(shards_metadata)]),
52
        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
53
    )
54

55
    return ShardedTensor._init_from_local_shards_and_global_metadata(
56
        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
57
    )
58

59

60
class TestMedatadaIndex(TestCase):
    """Tests for ``MetadataIndex`` and ``find_state_dict_object`` lookups."""

    def test_init_convert_offset(self):
        """An offset given as a list equals the same offset as ``torch.Size``."""
        a = MetadataIndex("foo", [1, 2])
        b = MetadataIndex("foo", torch.Size([1, 2]))
        self.assertEqual(a, b)

    def test_index_hint_ignored_on_equals(self):
        """The ``index`` hint does not take part in equality."""
        a = MetadataIndex("foo")
        b = MetadataIndex("foo", index=99)
        self.assertEqual(a, b)

    def test_index_hint_ignored_on_hash(self):
        """The ``index`` hint does not take part in hashing."""
        a = MetadataIndex("foo")
        b = MetadataIndex("foo", index=99)
        self.assertEqual(hash(a), hash(b))

    def test_flat_data(self):
        """Look up plain (non-sharded) entries of a state dict."""
        state_dict = {
            "a": torch.rand(10),
            "b": [1, 2, 3],
        }

        # A regular tensor is returned whole, with or without an offset or an
        # index hint.
        a = find_state_dict_object(state_dict, MetadataIndex("a"))
        self.assertEqual(a, state_dict["a"])
        a = find_state_dict_object(state_dict, MetadataIndex("a", [0]))
        self.assertEqual(a, state_dict["a"])
        a = find_state_dict_object(state_dict, MetadataIndex("a", index=99))
        self.assertEqual(a, state_dict["a"])

        # A non-tensor value is returned whole, with or without an index hint.
        b = find_state_dict_object(state_dict, MetadataIndex("b"))
        self.assertEqual(b, state_dict["b"])
        b = find_state_dict_object(state_dict, MetadataIndex("b", index=1))
        self.assertEqual(b, state_dict["b"])

        # An unknown key is rejected.
        with self.assertRaisesRegex(ValueError, "FQN"):
            find_state_dict_object(state_dict, MetadataIndex("c"))
        # An offset on a non-tensor value is rejected.
        with self.assertRaisesRegex(ValueError, "ShardedTensor"):
            find_state_dict_object(state_dict, MetadataIndex("b", [1]))

    @with_fake_comms(rank=0, world_size=2)
    def test_sharded_tensor_lookup(self):
        """Resolve a local shard of a ``ShardedTensor`` by its offset."""
        st = create_sharded_tensor(rank=0, world_size=2, shards_per_rank=3)
        state_dict = {"st": st}

        # Offset [8] resolves to the local shard at position 1.
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8]))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # good hint
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=1))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # bad hint
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=2))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # broken hint
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=99))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # A ShardedTensor lookup without an offset is rejected.
        with self.assertRaisesRegex(ValueError, "no offset was provided"):
            find_state_dict_object(state_dict, MetadataIndex("st"))

        # An offset that matches no local shard is rejected.
        with self.assertRaisesRegex(ValueError, "Could not find shard"):
            find_state_dict_object(state_dict, MetadataIndex("st", [1]))
124

125

126
# Invoke run_tests() only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
128

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.