# pytorch
# 470 lines · 15.9 KB
# Owner(s): ["oncall: distributed"]

import sys
import tempfile
from typing import Dict

import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardingSpec,
    ShardMetadata,
)

from torch.distributed.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    load_state_dict,
    save_state_dict,
)

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
    MyShardedModel1,
)


# Dev-ASAN builds are known to crash with multiprocessing spawn, so the
# whole module is skipped up front.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


# Writer thread counts every parametrized test is exercised with.
_THREAD_COUNTS = {1, 2}
51
52def assert_state_dict_equal(
53self: TestCase,
54state_dict_1: Dict[str, torch.Tensor],
55state_dict_2: Dict[str, torch.Tensor],
56) -> bool:
57self.assertEqual(
58len(state_dict_1), len(state_dict_2), "state_dict must be the same size"
59)
60self.assertEqual(
61set(state_dict_1.keys()),
62set(state_dict_2.keys()),
63"state_dict keys do not match",
64)
65
66for key, value_1 in state_dict_1.items():
67value_2 = state_dict_2[key]
68if isinstance(value_1, ShardedTensor):
69for local_shard_1, local_shard_2 in zip(
70value_1.local_shards(), value_2.local_shards()
71):
72self.assertTrue(
73torch.equal(local_shard_1.tensor, local_shard_2.tensor),
74f"Key {key}'s shard does not match",
75)
76elif isinstance(value_1, torch.Tensor):
77self.assertTrue(
78torch.equal(value_1, value_2),
79f"Key {key}'s tensor does not match",
80)
81
82return True
83
84
class MyTestModule(torch.nn.Module):
    """Small module mixing dense (Linear) and sparse (EmbeddingBag) params.

    Used to produce a regular (non-sharded) state_dict for the no_dist
    save/load round-trip tests.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear_1 = torch.nn.Linear(5, 5)
        self.linear_2 = torch.nn.Linear(5, 1)
        self.emb = torch.nn.EmbeddingBag(5, 10)
91
92
# The ShardedModels are borrowed from test/distributed/_sharded_tensor/test_sharded_tensor.py
class MyShardedModel3(torch.nn.Module):
    """Module owning a single randomly-initialized 10x20 ShardedTensor.

    Args:
        spec: sharding spec describing how the 10x20 tensor is placed
            across ranks. Construction requires an initialized process
            group (``sharded_tensor.rand`` is collective).
    """

    def __init__(
        self,
        spec: ShardingSpec,
    ) -> None:
        super().__init__()
        self.sharded_tensor: ShardedTensor = sharded_tensor.rand(
            spec, 10, 20, init_rrefs=False
        )
103
104
class TestDistributedStateDictSaveLoad(TestCase):
    """Single-process (``no_dist=True``) checkpoint round-trip for plain tensors."""

    @parametrize("thread_count", _THREAD_COUNTS)
    def test_read_write_only_tensor(self, thread_count) -> None:
        with tempfile.TemporaryDirectory() as path:
            state_dict_to_save = MyTestModule().state_dict()

            fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
            save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=fs_writer,
                no_dist=True,
            )

            state_dict_to_load_to = MyTestModule().state_dict()

            # A freshly initialized module must differ from the saved one,
            # proving the subsequent equality is due to the load.
            with self.assertRaises(AssertionError):
                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

            # Load from file without any resharding
            fs_reader = FileSystemReader(path=path)
            load_state_dict(
                state_dict=state_dict_to_load_to,
                storage_reader=fs_reader,
                no_dist=True,
            )

            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
132
133
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
    """Two-rank checkpoint round-trip for a model containing a ShardedTensor."""

    @property
    def world_size(self) -> int:
        return 2

    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_read_write_shard_tensor(self, thread_count) -> None:
        # Rank 0's temp dir is broadcast so every rank writes to the same path.
        paths = [tempfile.mkdtemp()]
        dist.broadcast_object_list(paths)

        path = paths[0]

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        model_to_save = MyShardedModel1(spec, init_rrefs=False)

        # Test save
        model_to_save._register_state_dict_hook(state_dict_hook)
        state_dict_to_save = model_to_save.state_dict()

        fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
        save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

        dist.barrier()

        # Create a new model
        model_to_load = MyShardedModel1(spec, init_rrefs=False)
        # This is not the correct hook for loading the state dict
        # model_to_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
        model_to_load._register_state_dict_hook(state_dict_hook)
        state_dict_to_load_to = model_to_load.state_dict()

        dist.barrier()

        # The fresh model must differ from the saved one before loading.
        with self.assertRaises(AssertionError):
            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

        # Test load.
        fs_reader = FileSystemReader(path=path)
        load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

        assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
        dist.barrier()
185
186
class TestDistributedReshardOnLoad(ShardedTensorTestBase):
    """Checkpoint round-trips where the load-time sharding differs from save-time."""

    @property
    def world_size(self) -> int:
        return 2

    def get_file_path(self) -> str:
        """Create a temp dir on rank 0 and broadcast it so all ranks share it."""
        paths = [tempfile.mkdtemp()] if dist.get_rank() == 0 else [None]
        dist.broadcast_object_list(paths)
        return paths[0]

    def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
        """Gather a ShardedTensor onto rank 0; returns None on other ranks."""
        res = torch.zeros(tensor.shape, device="cpu") if dist.get_rank() == 0 else None
        tensor.gather(out=res)
        return res

    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_load_with_different_shard_plan(self, thread_count) -> None:
        path = self.get_file_path()

        # We hardcode the assumption of how many shards are around
        self.assertEqual(self.world_size, dist.get_world_size())

        specs = [
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0, 0],
                        shard_sizes=[2, 20],
                        placement="rank:0",
                    ),
                    ShardMetadata(
                        shard_offsets=[2, 0],
                        shard_sizes=[1, 20],
                        placement="rank:1",
                    ),
                    ShardMetadata(
                        shard_offsets=[3, 0],
                        shard_sizes=[3, 20],
                        placement="rank:0",
                    ),
                    ShardMetadata(
                        shard_offsets=[6, 0],
                        shard_sizes=[3, 20],
                        placement="rank:1",
                    ),
                    ShardMetadata(
                        shard_offsets=[9, 0],
                        shard_sizes=[1, 20],
                        placement="rank:0",
                    ),
                ]
            ),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0, 0],
                        shard_sizes=[8, 20],
                        placement="rank:1",
                    ),
                    ShardMetadata(
                        shard_offsets=[8, 0],
                        shard_sizes=[2, 20],
                        placement="rank:0",
                    ),
                ]
            ),
        ]

        # Try every ordered pair of distinct (save-spec, load-spec).
        for s0 in specs:
            for s1 in specs:
                if s0 == s1:
                    continue

                dist.barrier()

                model_to_save = MyShardedModel3(s0)
                model_to_save._register_state_dict_hook(state_dict_hook)
                state_dict_to_save = model_to_save.state_dict()

                fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
                save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

                dist.barrier()

                model_to_load = MyShardedModel3(s1)
                model_to_load._register_state_dict_hook(state_dict_hook)
                state_dict_to_load_to = model_to_load.state_dict()
                dist.barrier()

                fs_reader = FileSystemReader(path=path)
                load_state_dict(
                    state_dict=state_dict_to_load_to, storage_reader=fs_reader
                )

                dist.barrier()
                store_tensor = self.load_tensor(model_to_save.sharded_tensor)
                dist.barrier()
                load_tensor = self.load_tensor(model_to_load.sharded_tensor)

                # Gathered values only exist on rank 0.
                if dist.get_rank() == 0:
                    self.assertTrue(
                        torch.allclose(store_tensor, load_tensor),
                        msg=f"{s0} vs {s1}",
                    )

    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_load_rowwise_to_colwise(self, thread_count) -> None:
        path = self.get_file_path()
        self.assertEqual(self.world_size, dist.get_world_size())

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        src_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        dst_spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        model_to_save = MyShardedModel3(src_spec).cuda(dist.get_rank())
        model_to_save._register_state_dict_hook(state_dict_hook)
        state_dict_to_save = model_to_save.state_dict()

        fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
        save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

        model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank())
        model_to_load._register_state_dict_hook(state_dict_hook)
        state_dict_to_load_to = model_to_load.state_dict()

        fs_reader = FileSystemReader(path=path)

        load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

        # We can't use torch.allclose since each ST has a different sharding spec
        store_tensor = self.load_tensor(model_to_save.sharded_tensor)
        load_tensor = self.load_tensor(model_to_load.sharded_tensor)

        if dist.get_rank() == 0:
            self.assertTrue(torch.allclose(store_tensor, load_tensor))

    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_save_load_bytes(self, thread_count) -> None:
        path = self.get_file_path()

        # Non-tensor values are stored/restored as opaque byte payloads.
        state_dict_to_save = {"bytes0": [1], "bytes1": "string"}

        fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
        save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

        state_dict_to_load = {"bytes0": [2], "bytes1": "other"}

        fs_reader = FileSystemReader(path=path)
        load_state_dict(state_dict=state_dict_to_load, storage_reader=fs_reader)

        self.assertEqual([1], state_dict_to_load["bytes0"])
        self.assertEqual("string", state_dict_to_load["bytes1"])

    @with_comms(init_rpc=False, backend="gloo")
    @parametrize("thread_count", _THREAD_COUNTS)
    def test_switch_between_sharded_tensor_to_tensor(self, thread_count) -> None:
        path = self.get_file_path()
        tensor_size = 32

        specs = [
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0],
                        shard_sizes=[8],
                        placement="rank:1",
                    ),
                    ShardMetadata(
                        shard_offsets=[8],
                        shard_sizes=[tensor_size - 8],
                        placement="rank:0",
                    ),
                ]
            ),
            EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0],
                        shard_sizes=[10],
                        placement="rank:0",
                    ),
                    ShardMetadata(
                        shard_offsets=[10],
                        shard_sizes=[tensor_size - 10],
                        placement="rank:1",
                    ),
                ]
            ),
        ]

        for save_spec in specs:
            for load_spec in specs:
                save_dict = {
                    "sharded": sharded_tensor.rand(save_spec, tensor_size),
                    "replicated": torch.rand(tensor_size, device="cpu"),
                }
                # Keep the replicated tensor identical on every rank.
                dist.broadcast(save_dict["replicated"], src=0)

                fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
                save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

                # Freaky Friday the tensors
                load_dict = {
                    "sharded": torch.zeros(tensor_size, device="cpu"),
                    "replicated": sharded_tensor.zeros(load_spec, tensor_size),
                }

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

                save_dict_sharded = self.load_tensor(save_dict["sharded"])
                load_dict_replicated = self.load_tensor(load_dict["replicated"])

                if dist.get_rank() == 0:
                    self.assertTrue(
                        torch.allclose(save_dict_sharded, load_dict["sharded"]),
                        f"save-spec {save_spec} load-spec {load_spec}",
                    )

                    self.assertTrue(
                        torch.allclose(save_dict["replicated"], load_dict_replicated),
                        f"save-spec {save_spec} load-spec {load_spec}",
                    )
463
464
# Expand each @parametrize'd test into one concrete test per thread_count.
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
instantiate_parametrized_tests(TestDistributedReshardOnLoad)

if __name__ == "__main__":
    run_tests()
471