# Owner(s): ["oncall: distributed"]

import os
import sys
from typing import cast, List, Optional, Union

import torch
import torch.distributed as dist
import torch.futures
import torch.nn

from torch.distributed._shard import sharded_tensor

from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

from torch.distributed.checkpoint import (
    CheckpointException,
    load_state_dict,
    save_state_dict,
    StorageReader,
    StorageWriter,
)

from torch.distributed.checkpoint.default_planner import _create_default_local_metadata

from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    Metadata,
    TensorStorageMetadata,
)

from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    SavePlan,
    SavePlanner,
)
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu

from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


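# A small test module that owns both a ShardedTensor and a regular Parameter
# and registers the ShardedTensor state_dict hook.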
class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sharded: ShardedTensor = sharded_tensor.zeros(self.spec(), 4, 4)
        self.regular = torch.nn.Parameter(torch.ones(4, 4))
        self.extra_sharded: Optional[ShardedTensor] = None
        self.extra_param: Optional[torch.nn.Parameter] = None
        self._register_state_dict_hook(state_dict_hook)

    def spec(self) -> ChunkShardingSpec:
        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        return ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )


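# Checks the Metadata produced by the default planner for sharded, replicated,
# and bytes entries of a state_dict.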
class TestDistributedCheckpointing(ShardedTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_tensor_metadata_with_missing_rank_spec(self) -> None:
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:1",
            ],
        )

        st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
        mapping = {}

        md = _create_default_local_metadata({"st": st})

        st_md = md.state_dict_metadata["st"]
        self.assertEqual(1, len(st_md.chunks))

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_default_metadata(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            "sharded": sharded_tensor.rand(
                spec,
                (
                    10,
                    10,
                ),
            ),
            "replicated": torch.rand(4, device=device),
            "bytes": [1, 2, 3, 4],
        }

        metadata = _create_default_local_metadata(state_dict)
        self.assertTrue("bytes" in metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["bytes"], BytesStorageMetadata
        )

        self.assertTrue("replicated" in metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["replicated"], TensorStorageMetadata
        )
        md = metadata.state_dict_metadata["replicated"]
        self.assertEqual(md.size, state_dict["replicated"].size())
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(1, len(md.chunks))

        self.assertTrue("sharded" in metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["sharded"], TensorStorageMetadata
        )
        md = metadata.state_dict_metadata["sharded"]
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(md.size, state_dict["sharded"].size())
        self.assertEqual(2, len(md.chunks))


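# Base class for the faulty storage layers below. `fail_conf` maps a hook name
# (e.g. "fail_write_data") to the list of ranks that should raise ValueError
# when that hook runs; async hooks return a Future carrying the exception instead.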
class TestStorageBase:
    def __init__(self, fail_conf):
        self.fail_conf = fail_conf
        self.rank = 0 if not dist.is_initialized() else dist.get_rank()

    def _get_ranks(self, name):
        return self.fail_conf[name] if name in self.fail_conf else None

    def _fail_rank(self, name):
        ranks = self._get_ranks(name)
        if ranks is not None and self.rank in ranks:
            raise ValueError(f"rank fail {self.rank} for {name}")

    def _fail_rank_async(self, name, result=None):
        ranks = self._get_ranks(name)
        fut = Future()
        if ranks is not None and self.rank in ranks:
            fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
        else:
            fut.set_result(result)
        return fut


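# StorageWriter that performs no real I/O; each stage of the save protocol only
# checks `fail_conf` and raises on the configured ranks.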
class FaultyStorageWriter(TestStorageBase, StorageWriter):
    def __init__(self, fail_conf):
        super().__init__(fail_conf)

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        return

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        self._fail_rank("fail_set_up_storage_writer")

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def write_data(
        self, plan: SavePlan, planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        self._fail_rank("fail_write_data")
        return self._fail_rank_async("fail_write_data_async", [])

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        self._fail_rank("fail_finish")

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return True


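# StorageReader counterpart: serves the Metadata passed to the constructor and
# injects failures on the configured ranks instead of reading any data.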
class FaultyStorageReader(TestStorageBase, StorageReader):
    def __init__(self, metadata, fail_conf):
        super().__init__(fail_conf)
        self.metadata = metadata

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        return

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self._fail_rank("fail_set_up_storage_reader")

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        self._fail_rank("fail_read_data")
        return self._fail_rank_async("fail_read_data_async")

    def read_metadata(self) -> Metadata:
        self._fail_rank("fail_read_metadata")
        return self.metadata

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return True


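# End-to-end error-handling tests: save and load through the faulty storage
# layers and verify that save_state_dict/load_state_dict surface per-rank
# failures as a CheckpointException.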
class TestDistributedFailure(ShardedTensorTestBase):
    def get_spec(self):
        return ChunkShardingSpec(
            dim=0,
            placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_writer_works(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        save_state_dict(state_dict, FaultyStorageWriter({}))

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_reader_works(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }
        metadata = _create_default_local_metadata(state_dict)

        load_state_dict(state_dict, FaultyStorageReader(metadata, {}))

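    # Runs `callback` and, when `kwargs` names ranks that should fail, asserts
    # that a CheckpointException is raised whose failures cover exactly those
    # ranks (and, unless ignore_exception_type is set, wrap a ValueError).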
    def _test_dist_failure(self, callback, kwargs):
        bad_ranks = next(iter(kwargs.values())) if len(kwargs) > 0 else []

        # Empty bad_ranks means it must work
        if len(bad_ranks) == 0:
            callback()
        else:
            with self.assertRaises(CheckpointException) as cm:
                callback()
            e = cast(CheckpointException, cm.exception)
            for rank, wrapped_ex in e.failures.items():
                ex = wrapped_ex[0]
                self.assertTrue(rank in bad_ranks, msg=f"{rank} did not fail")
                if not kwargs.get("ignore_exception_type", False):
                    self.assertEqual(ValueError, type(ex), str(ex))

            failed_ranks = e.failures.keys()
            for rank in bad_ranks:
                self.assertTrue(
                    rank in failed_ranks,
                    msg=f"{rank} was supposed to fail but did not",
                )

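    # _test_save/_test_load wire the faulty storage into save_state_dict /
    # load_state_dict; the remaining keyword arguments double as the failure
    # configuration checked by _test_dist_failure.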
    def _test_save(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()

        def _save():
            save_state_dict(
                state_dict,
                storage_writer=FaultyStorageWriter(kwargs),
                coordinator_rank=coordinator,
                no_dist=no_dist,
            )

        self._test_dist_failure(_save, kwargs)

    def _test_load(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()

        def _load():
            metadata = _create_default_local_metadata(state_dict)
            load_state_dict(
                state_dict,
                storage_reader=FaultyStorageReader(metadata, kwargs),
                coordinator_rank=coordinator,
                no_dist=no_dist,
            )

        self._test_dist_failure(_load, kwargs)

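    # Each call below injects a failure on specific ranks at one stage of the
    # save protocol, both with the default and with a non-zero coordinator rank.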
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_save_error_handling(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        self._test_save(state_dict, fail_set_up_storage_writer=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[2])
        self._test_save(state_dict, fail_write_data_async=[3])

        self._test_save(state_dict, coordinator=1, fail_set_up_storage_writer=[1])
        self._test_save(state_dict, coordinator=1, fail_finish=[1])

    def test_save_error_handling_no_dist(self) -> None:
        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}

        self.assertFalse(dist.is_initialized())

        self._test_save(state_dict, fail_set_up_storage_writer=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[0])
        self._test_save(state_dict, fail_write_data_async=[0])

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_load_error_handling(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        self._test_load(state_dict)
        self._test_load(state_dict, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_prepare_local_plan=[1])
        self._test_load(state_dict, fail_read_data=[3])
        self._test_load(state_dict, fail_read_data_async=[1])

        self._test_load(state_dict, coordinator=3, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
        self._test_load(state_dict, coordinator=2, fail_read_data=[0])
        self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
        self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])

    def test_load_error_handling_no_dist(self) -> None:
        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
        self._test_load(state_dict)
        self._test_load(state_dict, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_prepare_local_plan=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_data=[0])
        self._test_load(state_dict, fail_read_data_async=[0])


if __name__ == "__main__":
    run_tests()
