# Owner(s): ["oncall: distributed"]

import sys

import torch

from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    ShardMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import (
    TensorProperties as TensorProperties_Shard,
)
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans

from torch.distributed.checkpoint.default_planner import (
    _create_default_local_metadata,
    create_default_global_save_plan,
    create_default_local_load_plan,
    create_default_local_save_plan,
)
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType

from torch.distributed.checkpoint.planner_helpers import (
    create_read_items_for_chunk_list,
)

from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)

from torch.testing._internal.distributed.distributed_utils import (
    with_dist,
    with_fake_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
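    # Builds a 1-D ShardedTensor made of world_size * shards_per_rank
    # contiguous shards of `shard_size` elements each; only the shards
    # owned by `rank` are materialized locally.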
    shards_metadata = []
    local_shards = []
    for idx in range(0, world_size * shards_per_rank):
        shard_rank = idx // shards_per_rank
        shard_md = ShardMetadata(
            shard_offsets=[idx * shard_size],
            shard_sizes=[shard_size],
            placement=f"rank:{shard_rank}/cpu",
        )
        shards_metadata.append(shard_md)
        if shard_rank == rank:
            shard = Shard.from_tensor_and_offsets(
                torch.rand(*shard_md.shard_sizes),
                shard_offsets=shard_md.shard_offsets,
                rank=rank,
            )
            local_shards.append(shard)

    sharded_tensor_md = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([shard_size * len(shards_metadata)]),
        tensor_properties=TensorProperties_Shard.create_from_tensor(torch.zeros(1)),
    )

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
    )


class TestSavePlan(TestCase):
    @with_fake_comms(rank=1, world_size=4)
    def test_local_plan(self):
        tensor = torch.rand(10)
        val = [1, 2, 3]
        st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1)
        state_dict = {"tensor": tensor, "value": val, "st": st}
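        # Non-coordinator rank: only the tensor and the local shard are
        # planned; the replicated object under "value" is left to the
        # coordinator.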
        plan = create_default_local_save_plan(state_dict, False)
        self.assertEqual(2, len(plan.items))
        wi = plan.items[0]
        self.assertEqual(wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(wi.type, WriteItemType.TENSOR)
        self.assertEqual(wi.tensor_data.size, tensor.size())
        self.assertEqual(
            wi.tensor_data.properties,
            TensorProperties.create_from_tensor(torch.zeros(1)),
        )
        self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10]))

        st_wi = plan.items[1]
        self.assertEqual(st_wi.index, MetadataIndex("st", [8]))
        self.assertEqual(st_wi.type, WriteItemType.SHARD)
        self.assertEqual(st_wi.tensor_data.size, st.size())
        self.assertEqual(
            st_wi.tensor_data.properties,
            TensorProperties.create_from_tensor(torch.zeros(1)),
        )
        self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8]))
        self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8]))

        # On the coordinator rank, the plan should include replicated items as well
        plan = create_default_local_save_plan(state_dict, True)
        self.assertEqual(3, len(plan.items))

        tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
        self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
        self.assertEqual(
            tensor_wi.tensor_data.properties,
            TensorProperties.create_from_tensor(tensor),
        )
        self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))

        bytes_wi = next(wi for wi in plan.items if wi.type == WriteItemType.BYTE_IO)
        self.assertEqual(bytes_wi.index, MetadataIndex("value"))
        self.assertIsNone(bytes_wi.tensor_data)

    def test_global_plan(self):
        def create_data(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                state_dict = {"tensor": tensor, "value": val, "st": st}
                return create_default_local_save_plan(state_dict, rank == 0)

        all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
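        # Drop duplicate write items so each entry is written by exactly one
        # rank (e.g. the tensor replicated on every rank).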
        all_plans = dedup_save_plans(all_plans)
        final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)

        # The default global plan updates all indexes to include hints
        for new_plan, old_plan in zip(final_plans, all_plans):
            for new_item, old_item in zip(new_plan.items, old_plan.items):
                self.assertEqual(new_item.index, old_item.index)
                self.assertEqual(new_item.type, old_item.type)
                self.assertEqual(new_item.tensor_data, old_item.tensor_data)
                self.assertIn(new_item.index.fqn, metadata.state_dict_metadata)

                item_md = metadata.state_dict_metadata[new_item.index.fqn]
                if new_item.type == WriteItemType.BYTE_IO:
                    self.assertTrue(isinstance(item_md, BytesStorageMetadata))
                else:
                    self.assertTrue(isinstance(item_md, TensorStorageMetadata))
                    self.assertEqual(item_md.size, old_item.tensor_data.size)
                    self.assertEqual(
                        item_md.properties, old_item.tensor_data.properties
                    )

                    self.assertIsNotNone(new_item.index.index)
                    # Make sure the hint is correct
                    self.assertEqual(
                        item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
                    )

    def test_local_load_plan(self):
        def create_state_dict(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                return {"tensor": tensor, "value": val, "st": st}

        state_dict = create_state_dict(1)
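        # Derive checkpoint metadata from this rank's state_dict alone, as if
        # it were the complete saved checkpoint.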
        metadata = _create_default_local_metadata(state_dict)

        load_plan = create_default_local_load_plan(state_dict, metadata)
        # One read item is created per state_dict entry: st, tensor, and value
        self.assertEqual(3, len(load_plan.items))
        st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
        tensor_item = next(
            ri for ri in load_plan.items if ri.dest_index.fqn == "tensor"
        )
        bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")

        self.assertEqual(st_item.type, LoadItemType.TENSOR)
        # Identical save/load layout, so source and destination match exactly
        self.assertEqual(st_item.dest_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.dest_offsets, torch.Size([0]))
        self.assertEqual(st_item.storage_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.storage_offsets, torch.Size([0]))
        self.assertEqual(st_item.lengths, torch.Size([8]))

        self.assertEqual(tensor_item.type, LoadItemType.TENSOR)
        self.assertEqual(tensor_item.dest_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.dest_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.storage_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.storage_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.lengths, torch.Size([10]))

        self.assertEqual(bytes_item.type, LoadItemType.BYTE_IO)
        self.assertEqual(bytes_item.dest_index, MetadataIndex("value"))

    def test_load_with_resharding(self):
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=128 // world_size,
                    )
                }

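        # The full tensor has 128 elements; each rank owns one contiguous
        # shard of 128 // world_size elements.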
        # At world_size=8, rank 1 has a 16-element shard covering [16, 32[
        world8_state_dict = create_state_dict(rank=1, world_size=8)
        world8_metadata = _create_default_local_metadata(world8_state_dict)

        # At world_size=4, rank 1 has a 32-element shard covering [32, 64[
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # First scenario: going from world=8 to world=4. Each 32-element
        # 4-world shard is assembled from two 16-element 8-world shards.
        load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
        self.assertEqual(2, len(load_plan.items))
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16])
        )

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([16]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [48]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([16]))
        self.assertEqual(high_ri.lengths, torch.Size([16]))

        # Second scenario: going from world=4 to world=8. Rank 1's 8-world
        # shard is the upper half of rank 0's 4-world shard, so only half of
        # one stored shard is read.
        load_plan = create_default_local_load_plan(world8_state_dict, world4_metadata)
        self.assertEqual(1, len(load_plan.items))
        ri = load_plan.items[0]
        self.assertEqual(ri.storage_index, MetadataIndex("st", [0]))
        self.assertEqual(ri.storage_offsets, torch.Size([16]))
        self.assertEqual(ri.dest_index, MetadataIndex("st", [16]))
        self.assertEqual(ri.dest_offsets, torch.Size([0]))
        self.assertEqual(ri.lengths, torch.Size([16]))

    def test_load_with_world_size_diff_by_one(self):
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=120 // world_size,
                    )
                }

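        # The full tensor has 120 elements; shards are 30 elements at world=4
        # and 40 at world=3, so shard boundaries no longer line up.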
        # At world_size=4, rank 1 has a 30-element shard covering [30, 60[
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # At world_size=3, rank 1 has a 40-element shard covering [40, 80[
        world3_state_dict = create_state_dict(rank=1, world_size=3)

        load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
        self.assertEqual(2, len(load_plan.items))
        # reads the stored [30, 60[ shard to fill [40, 60[
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        # reads the stored [60, 90[ shard to fill [60, 80[
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20])
        )

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([20]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [60]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
        self.assertEqual(high_ri.lengths, torch.Size([20]))


class TestPlannerHelpers(TestCase):
    def test_create_read_item_from_chunks(self):
        tensor_md = TensorStorageMetadata(
            properties=TensorProperties.create_from_tensor(torch.empty([16])),
            size=torch.Size([16]),
            chunks=[
                ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8])),
                ChunkStorageMetadata(offsets=torch.Size([8]), sizes=torch.Size([8])),
            ],
        )

        chunk = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([7]))
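        # The requested chunk covers [4, 11[, straddling the stored chunks
        # [0, 8[ and [8, 16[, so it splits into two read items of lengths
        # 4 and 3.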
        read_items = create_read_items_for_chunk_list("foo", tensor_md, [chunk])

        self.assertEqual(2, len(read_items))
        self.assertEqual(MetadataIndex("foo", [4]), read_items[0].dest_index)
        self.assertEqual(torch.Size([0]), read_items[0].dest_offsets)

        self.assertEqual(MetadataIndex("foo", [0]), read_items[0].storage_index)
        self.assertEqual(torch.Size([4]), read_items[0].storage_offsets)

        self.assertEqual(torch.Size([4]), read_items[0].lengths)

        self.assertEqual(MetadataIndex("foo", [4]), read_items[1].dest_index)
        self.assertEqual(torch.Size([4]), read_items[1].dest_offsets)

        self.assertEqual(MetadataIndex("foo", [8]), read_items[1].storage_index)
        self.assertEqual(torch.Size([0]), read_items[1].storage_offsets)

        self.assertEqual(torch.Size([3]), read_items[1].lengths)


if __name__ == "__main__":
    run_tests()