# Owner(s): ["oncall: distributed"]
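# Tests saving and loading state_dicts through FileSystemWriter / FileSystemReader:
# plain tensors, ShardedTensors (including resharding when the sharding spec used
# for loading differs from the one used for saving), and non-tensor objects.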

import sys
import tempfile
from typing import Dict

import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardingSpec,
    ShardMetadata,
)

from torch.distributed.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    load_state_dict,
    save_state_dict,
)

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
    MyShardedModel1,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


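# Writer thread counts exercised by the tests below; each test runs once per value
# via @parametrize("thread_count", _THREAD_COUNTS).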
_THREAD_COUNTS = {1, 2}


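# Test helper: asserts that two state_dicts have identical keys and element-wise
# equal values, comparing ShardedTensor entries shard by shard.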
def assert_state_dict_equal(
53
    self: TestCase,
54
    state_dict_1: Dict[str, torch.Tensor],
55
    state_dict_2: Dict[str, torch.Tensor],
56
) -> bool:
57
    self.assertEqual(
58
        len(state_dict_1), len(state_dict_2), "state_dict must be the same size"
59
    )
60
    self.assertEqual(
61
        set(state_dict_1.keys()),
62
        set(state_dict_2.keys()),
63
        "state_dict keys do not match",
64
    )
65

66
    for key, value_1 in state_dict_1.items():
67
        value_2 = state_dict_2[key]
68
        if isinstance(value_1, ShardedTensor):
69
            for local_shard_1, local_shard_2 in zip(
70
                value_1.local_shards(), value_2.local_shards()
71
            ):
72
                self.assertTrue(
73
                    torch.equal(local_shard_1.tensor, local_shard_2.tensor),
74
                    f"Key {key}'s shard does not match",
75
                )
76
        elif isinstance(value_1, torch.Tensor):
77
            self.assertTrue(
78
                torch.equal(value_1, value_2),
79
                f"Key {key}'s tensor does not match",
80
            )
81

82
    return True
83

84

85
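# Small non-sharded module; its state_dict contains only plain tensors.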
class MyTestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear_1 = torch.nn.Linear(5, 5)
        self.linear_2 = torch.nn.Linear(5, 1)
        self.emb = torch.nn.EmbeddingBag(5, 10)


# The ShardedModels are borrowed from test/distributed/_sharded_tensor/test_sharded_tensor.py
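# Holds a single 10x20 ShardedTensor built from the given sharding spec.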
class MyShardedModel3(torch.nn.Module):
    def __init__(
        self,
        spec: ShardingSpec,
    ) -> None:
        super().__init__()
        self.sharded_tensor: ShardedTensor = sharded_tensor.rand(
            spec, 10, 20, init_rrefs=False
        )


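# Single-process (no_dist=True) round trip of a plain-tensor state_dict.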
class TestDistributedStateDictSaveLoad(TestCase):
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_read_write_only_tensor(self, thread_count) -> None:
        with tempfile.TemporaryDirectory() as path:
            state_dict_to_save = MyTestModule().state_dict()

            fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
            save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=fs_writer,
                no_dist=True,
            )

            state_dict_to_load_to = MyTestModule().state_dict()

            with self.assertRaises(AssertionError):
                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

            # Load from file without any resharding
            fs_reader = FileSystemReader(path=path)
            load_state_dict(
                state_dict=state_dict_to_load_to,
                storage_reader=fs_reader,
                no_dist=True,
            )

            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)


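# Two-rank round trip of a state_dict containing ShardedTensors, using the same
# sharding spec for saving and loading.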
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

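    # state_dict_hook is what makes the module's ShardedTensor attributes show up
    # in state_dict(); it is registered on both the saving and the loading model.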
    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_read_write_shard_tensor(self, thread_count) -> None:
        paths = [tempfile.mkdtemp()]
        dist.broadcast_object_list(paths)

        path = paths[0]

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        model_to_save = MyShardedModel1(spec, init_rrefs=False)

        # Test save
        model_to_save._register_state_dict_hook(state_dict_hook)
        state_dict_to_save = model_to_save.state_dict()

        fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
        save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

        dist.barrier()

        # Create a new model
        model_to_load = MyShardedModel1(spec, init_rrefs=False)
        # This is not the correct hook for loading the state dict
        # model_to_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
        model_to_load._register_state_dict_hook(state_dict_hook)
        state_dict_to_load_to = model_to_load.state_dict()

        dist.barrier()

        with self.assertRaises(AssertionError):
            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

        # Test load.
        fs_reader = FileSystemReader(path=path)
        load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

        assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
        dist.barrier()


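# Resharding at load time: checkpoints are saved under one sharding layout and
# loaded under another, including switching between ShardedTensor and plain tensor.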
class TestDistributedReshardOnLoad(ShardedTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

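    # Rank 0 creates a temporary directory and broadcasts it so that every rank
    # reads and writes the same checkpoint location.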
    def get_file_path(self) -> str:
        paths = [tempfile.mkdtemp()] if dist.get_rank() == 0 else [None]
        dist.broadcast_object_list(paths)
        return paths[0]

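    # Gathers a ShardedTensor into a full dense tensor on rank 0 (None on other ranks).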
    def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
        res = torch.zeros(tensor.shape, device="cpu") if dist.get_rank() == 0 else None
        tensor.gather(out=res)
        return res

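    # Every ordered pair of distinct specs below is exercised as (save spec, load spec).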
    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_load_with_different_shard_plan(self, thread_count) -> None:
        path = self.get_file_path()

        # We hardcode the assumption of how many shards are around
        self.assertEqual(self.world_size, dist.get_world_size())

        specs = [
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0, 0],
                        shard_sizes=[2, 20],
                        placement="rank:0",
                    ),
                    ShardMetadata(
                        shard_offsets=[2, 0],
                        shard_sizes=[1, 20],
                        placement="rank:1",
                    ),
                    ShardMetadata(
                        shard_offsets=[3, 0],
                        shard_sizes=[3, 20],
                        placement="rank:0",
                    ),
                    ShardMetadata(
                        shard_offsets=[6, 0],
                        shard_sizes=[3, 20],
                        placement="rank:1",
                    ),
                    ShardMetadata(
                        shard_offsets=[9, 0],
                        shard_sizes=[1, 20],
                        placement="rank:0",
                    ),
                ]
            ),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0, 0],
                        shard_sizes=[8, 20],
                        placement="rank:1",
                    ),
                    ShardMetadata(
                        shard_offsets=[8, 0],
                        shard_sizes=[2, 20],
                        placement="rank:0",
                    ),
                ]
            ),
        ]

        for s0 in specs:
            for s1 in specs:
                if s0 == s1:
                    continue

                dist.barrier()

                model_to_save = MyShardedModel3(s0)
                model_to_save._register_state_dict_hook(state_dict_hook)
                state_dict_to_save = model_to_save.state_dict()

                fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
                save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

                dist.barrier()

                model_to_load = MyShardedModel3(s1)
                model_to_load._register_state_dict_hook(state_dict_hook)
                state_dict_to_load_to = model_to_load.state_dict()
                dist.barrier()

                fs_reader = FileSystemReader(path=path)
                load_state_dict(
                    state_dict=state_dict_to_load_to, storage_reader=fs_reader
                )

                dist.barrier()
                store_tensor = self.load_tensor(model_to_save.sharded_tensor)
                dist.barrier()
                load_tensor = self.load_tensor(model_to_load.sharded_tensor)

                if dist.get_rank() == 0:
                    self.assertTrue(
                        torch.allclose(store_tensor, load_tensor),
                        msg=f"{s0} vs {s1}",
                    )

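    # Save with a row-wise (dim=0) chunk spec and load into a column-wise (dim=1) one.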
    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_load_rowwise_to_colwise(self, thread_count) -> None:
        path = self.get_file_path()
        self.assertEqual(self.world_size, dist.get_world_size())

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        src_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        dst_spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        model_to_save = MyShardedModel3(src_spec).cuda(dist.get_rank())
        model_to_save._register_state_dict_hook(state_dict_hook)
        state_dict_to_save = model_to_save.state_dict()

        fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
        save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

        model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank())
        model_to_load._register_state_dict_hook(state_dict_hook)
        state_dict_to_load_to = model_to_load.state_dict()

        fs_reader = FileSystemReader(path=path)

        load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

        # We can't compare the two STs with torch.allclose directly since they have
        # different sharding specs, so gather them to full tensors on rank 0 first.
        store_tensor = self.load_tensor(model_to_save.sharded_tensor)
        load_tensor = self.load_tensor(model_to_load.sharded_tensor)

        if dist.get_rank() == 0:
            self.assertTrue(torch.allclose(store_tensor, load_tensor))

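    # Non-tensor entries (lists, strings) should round-trip through save/load unchanged.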
    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_save_load_bytes(self, thread_count) -> None:
        path = self.get_file_path()

        state_dict_to_save = {"bytes0": [1], "bytes1": "string"}

        fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
        save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

        state_dict_to_load = {"bytes0": [2], "bytes1": "other"}

        fs_reader = FileSystemReader(path=path)
        load_state_dict(state_dict=state_dict_to_load, storage_reader=fs_reader)

        self.assertEqual([1], state_dict_to_load["bytes0"])
        self.assertEqual("string", state_dict_to_load["bytes1"])

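    # A value saved as a ShardedTensor can be loaded into a plain tensor and vice versa.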
    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_switch_between_sharded_tensor_to_tensor(self, thread_count) -> None:
        path = self.get_file_path()
        tensor_size = 32

        specs = [
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0],
                        shard_sizes=[8],
                        placement="rank:1",
                    ),
                    ShardMetadata(
                        shard_offsets=[8],
                        shard_sizes=[tensor_size - 8],
                        placement="rank:0",
                    ),
                ]
            ),
            EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0],
                        shard_sizes=[10],
                        placement="rank:0",
                    ),
                    ShardMetadata(
                        shard_offsets=[10],
                        shard_sizes=[tensor_size - 10],
                        placement="rank:1",
                    ),
                ]
            ),
        ]

        for save_spec in specs:
            for load_spec in specs:
                save_dict = {
                    "sharded": sharded_tensor.rand(save_spec, tensor_size),
                    "replicated": torch.rand(tensor_size, device="cpu"),
                }
                dist.broadcast(save_dict["replicated"], src=0)

                fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
                save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

                # Freaky Friday the tensors: the entry saved as a ShardedTensor is
                # loaded into a plain tensor, and vice versa.
                load_dict = {
                    "sharded": torch.zeros(tensor_size, device="cpu"),
                    "replicated": sharded_tensor.zeros(load_spec, tensor_size),
                }

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

                save_dict_sharded = self.load_tensor(save_dict["sharded"])
                load_dict_replicated = self.load_tensor(load_dict["replicated"])

                if dist.get_rank() == 0:
                    self.assertTrue(
                        torch.allclose(save_dict_sharded, load_dict["sharded"]),
                        f"save-spec {save_spec} load-spec {load_spec}",
                    )

                    self.assertTrue(
                        torch.allclose(save_dict["replicated"], load_dict_replicated),
                        f"save-spec {save_spec} load-spec {load_spec}",
                    )


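# Generate the thread_count-parametrized variants of the tests above.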
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
instantiate_parametrized_tests(TestDistributedReshardOnLoad)

if __name__ == "__main__":
    run_tests()