# Owner(s): ["oncall: distributed"]

import os
import sys
from typing import cast, List, Optional, Union

import torch
import torch.distributed as dist
import torch.futures
import torch.nn

from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.checkpoint import (
    CheckpointException,
    load_state_dict,
    save_state_dict,
    StorageReader,
    StorageWriter,
)
from torch.distributed.checkpoint.default_planner import _create_default_local_metadata
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    Metadata,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    SavePlan,
    SavePlanner,
)
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

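# A module mixing a ShardedTensor with a regular Parameter. Registering
# state_dict_hook makes the ShardedTensor members show up in state_dict(),
# which is what the checkpoint planners consume.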
class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sharded: ShardedTensor = sharded_tensor.zeros(self.spec(), 4, 4)
        self.regular = torch.nn.Parameter(torch.ones(4, 4))
        self.extra_sharded: Optional[ShardedTensor] = None
        self.extra_param: Optional[torch.nn.Parameter] = None
        self._register_state_dict_hook(state_dict_hook)

    def spec(self) -> ChunkShardingSpec:
        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        return ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

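# Sanity checks for the local metadata that the default planner derives from a
# state_dict on a two-rank world.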
class TestDistributedCheckpointing(ShardedTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_tensor_metadata_with_missing_rank_spec(self) -> None:
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:1",
            ],
        )

        st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)

        md = _create_default_local_metadata({"st": st})

        st_md = md.state_dict_metadata["st"]
        self.assertEqual(1, len(st_md.chunks))

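    # _create_default_local_metadata should classify each state_dict entry:
    # BytesStorageMetadata for non-tensor values, TensorStorageMetadata (with
    # one chunk per shard) for plain and sharded tensors.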
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_default_metadata(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            "sharded": sharded_tensor.rand(spec, (10, 10)),
            "replicated": torch.rand(4, device=device),
            "bytes": [1, 2, 3, 4],
        }

        metadata = _create_default_local_metadata(state_dict)
        self.assertIn("bytes", metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["bytes"], BytesStorageMetadata
        )

        self.assertIn("replicated", metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["replicated"], TensorStorageMetadata
        )
        md = metadata.state_dict_metadata["replicated"]
        self.assertEqual(md.size, state_dict["replicated"].size())
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(1, len(md.chunks))

        self.assertIn("sharded", metadata.state_dict_metadata)
        self.assertIsInstance(
            metadata.state_dict_metadata["sharded"], TensorStorageMetadata
        )
        md = metadata.state_dict_metadata["sharded"]
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(md.size, state_dict["sharded"].size())
        self.assertEqual(2, len(md.chunks))

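# Base for the fault-injecting storage layers below. fail_conf maps a hook
# name to the list of ranks that should fail in that hook, e.g.
#     FaultyStorageWriter({"fail_write_data": [0, 2]})
# makes write_data raise ValueError on ranks 0 and 2 and succeed elsewhere.
# _fail_rank raises synchronously; _fail_rank_async reports the failure
# through the returned Future instead.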
class TestStorageBase:
    def __init__(self, fail_conf):
        self.fail_conf = fail_conf
        self.rank = 0 if not dist.is_initialized() else dist.get_rank()

    def _get_ranks(self, name):
        return self.fail_conf.get(name)

    def _fail_rank(self, name):
        ranks = self._get_ranks(name)
        if ranks is not None and self.rank in ranks:
            raise ValueError(f"rank fail {self.rank} for {name}")

    def _fail_rank_async(self, name, result=None):
        ranks = self._get_ranks(name)
        fut = Future()
        if ranks is not None and self.rank in ranks:
            fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
        else:
            fut.set_result(result)
        return fut

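# A StorageWriter that performs no actual I/O: every hook just triggers the
# configured per-rank failures, so an empty fail_conf yields a writer whose
# save path trivially succeeds.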
class FaultyStorageWriter(TestStorageBase, StorageWriter):
    def __init__(self, fail_conf):
        super().__init__(fail_conf)

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        return

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        self._fail_rank("fail_set_up_storage_writer")

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def write_data(
        self, plan: SavePlan, planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        self._fail_rank("fail_write_data")
        return self._fail_rank_async("fail_write_data_async", [])

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        self._fail_rank("fail_finish")

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return True

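# The read-side counterpart: it serves back the Metadata it was constructed
# with and injects failures the same way, without reading any real data.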
class FaultyStorageReader(TestStorageBase, StorageReader):
    def __init__(self, metadata, fail_conf):
        super().__init__(fail_conf)
        self.metadata = metadata

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        return

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self._fail_rank("fail_set_up_storage_reader")

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        self._fail_rank("fail_read_data")
        return self._fail_rank_async("fail_read_data_async")

    def read_metadata(self) -> Metadata:
        self._fail_rank("fail_read_metadata")
        return self.metadata

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return True

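# End-to-end failure handling: a failure in any storage hook on any rank must
# surface as a CheckpointException that maps each failing rank to its error.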
class TestDistributedFailure(ShardedTensorTestBase):
    def get_spec(self):
        return ChunkShardingSpec(
            dim=0,
            placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_writer_works(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        save_state_dict(state_dict, FaultyStorageWriter({}))

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_reader_works(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }
        metadata = _create_default_local_metadata(state_dict)

        load_state_dict(state_dict, FaultyStorageReader(metadata, {}))

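    # kwargs carries a single fail_* entry; its rank list is the set of ranks
    # expected to fail. An empty list means the callback must succeed, and the
    # exception check is two-sided: every reported failure must be an expected
    # rank, and every expected rank must be reported.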
    def _test_dist_failure(self, callback, kwargs):
        bad_ranks = next(iter(kwargs.values())) if len(kwargs) > 0 else []

        # Empty bad_ranks means it must work
        if len(bad_ranks) == 0:
            callback()
        else:
            with self.assertRaises(CheckpointException) as cm:
                callback()
            e = cast(CheckpointException, cm.exception)
            for rank, wrapped_ex in e.failures.items():
                ex = wrapped_ex[0]
                self.assertIn(rank, bad_ranks, msg=f"{rank} did not fail")
                if not kwargs.get("ignore_exception_type", False):
                    self.assertEqual(ValueError, type(ex), str(ex))

            failed_ranks = e.failures.keys()
            for rank in bad_ranks:
                self.assertIn(
                    rank,
                    failed_ranks,
                    msg=f"{rank} was supposed to fail but did not",
                )

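    # Drive a save (or load) through the faulty storage layer and verify the
    # outcome with _test_dist_failure. no_dist is derived from whether a
    # process group exists, so the same helpers back both the distributed
    # tests and the *_no_dist tests below.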
    def _test_save(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()

        def _save():
            save_state_dict(
                state_dict,
                storage_writer=FaultyStorageWriter(kwargs),
                coordinator_rank=coordinator,
                no_dist=no_dist,
            )

        self._test_dist_failure(_save, kwargs)

    def _test_load(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()

        def _load():
            metadata = _create_default_local_metadata(state_dict)
            load_state_dict(
                state_dict,
                storage_reader=FaultyStorageReader(metadata, kwargs),
                coordinator_rank=coordinator,
                no_dist=no_dist,
            )

        self._test_dist_failure(_load, kwargs)

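    # Four GPUs are required because failures are injected on ranks up to 3,
    # including with a non-default coordinator rank.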
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_save_error_handling(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        self._test_save(state_dict, fail_set_up_storage_writer=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[2])
        self._test_save(state_dict, fail_write_data_async=[3])

        self._test_save(state_dict, coordinator=1, fail_set_up_storage_writer=[1])
        self._test_save(state_dict, coordinator=1, fail_finish=[1])

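    # Without an initialized process group, save_state_dict runs in-process
    # (no_dist=True), so every failure is injected on the single local "rank" 0.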
    def test_save_error_handling_no_dist(self) -> None:
        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}

        self.assertFalse(dist.is_initialized())

        self._test_save(state_dict, fail_set_up_storage_writer=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[0])
        self._test_save(state_dict, fail_write_data_async=[0])

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_load_error_handling(self) -> None:
        state_dict = {
            "sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
            "replicated": torch.rand(10, 10),
            "bytes": [1, 2, 3, 4],
        }

        self._test_load(state_dict)
        self._test_load(state_dict, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_prepare_local_plan=[1])
        self._test_load(state_dict, fail_read_data=[3])
        self._test_load(state_dict, fail_read_data_async=[1])

        self._test_load(state_dict, coordinator=3, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
        self._test_load(state_dict, coordinator=2, fail_read_data=[0])
        self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
        self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])

    def test_load_error_handling_no_dist(self) -> None:
        state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
        self._test_load(state_dict)
        self._test_load(state_dict, fail_set_up_storage_reader=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_prepare_local_plan=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_data=[0])
        self._test_load(state_dict, fail_read_data_async=[0])


if __name__ == "__main__":
    run_tests()