# pytorch test file (extraction artifact noted: 341 lines, 14.3 KB)
# Owner(s): ["oncall: distributed"]
2
3import sys
4
5import torch
6
7from torch.distributed._shard.sharded_tensor import (
8Shard,
9ShardedTensor,
10ShardedTensorMetadata,
11ShardMetadata,
12)
13from torch.distributed._shard.sharded_tensor.metadata import (
14TensorProperties as TensorProperties_Shard,
15)
16from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
17
18from torch.distributed.checkpoint.default_planner import (
19_create_default_local_metadata,
20create_default_global_save_plan,
21create_default_local_load_plan,
22create_default_local_save_plan,
23)
24from torch.distributed.checkpoint.metadata import (
25BytesStorageMetadata,
26ChunkStorageMetadata,
27MetadataIndex,
28TensorProperties,
29TensorStorageMetadata,
30)
31from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
32
33from torch.distributed.checkpoint.planner_helpers import (
34create_read_items_for_chunk_list,
35)
36
37from torch.testing._internal.common_utils import (
38run_tests,
39TEST_WITH_DEV_DBG_ASAN,
40TestCase,
41)
42
43from torch.testing._internal.distributed.distributed_utils import (
44with_dist,
45with_fake_comms,
46)
47
48
# dev-asan builds are incompatible with torch + multiprocessing spawn, so
# bail out of the whole module before any tests are collected.
if TEST_WITH_DEV_DBG_ASAN:
    message = "Skip dev-asan as torch + multiprocessing spawn have known issues"
    print(message, file=sys.stderr)
    sys.exit(0)
55
56
def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
    """Build a 1-D ShardedTensor with ``shards_per_rank`` contiguous shards per rank.

    Shard ``i`` covers offsets ``[i * shard_size, (i + 1) * shard_size)`` and is
    placed on rank ``i // shards_per_rank``; only the shards owned by ``rank``
    get local (random) tensor data. Requires distributed comms to be set up
    (real or faked) before being called.
    """
    all_shard_md = []
    owned_shards = []
    total_shards = world_size * shards_per_rank
    for shard_idx in range(total_shards):
        owner = shard_idx // shards_per_rank
        md = ShardMetadata(
            shard_offsets=[shard_idx * shard_size],
            shard_sizes=[shard_size],
            placement=f"rank:{owner}/cpu",
        )
        all_shard_md.append(md)
        if owner == rank:
            owned_shards.append(
                Shard.from_tensor_and_offsets(
                    torch.rand(*md.shard_sizes),
                    shard_offsets=md.shard_offsets,
                    rank=rank,
                )
            )

    # Properties are taken from a throwaway tensor: only dtype/layout/etc.
    # matter here, not the shape.
    global_md = ShardedTensorMetadata(
        shards_metadata=all_shard_md,
        size=torch.Size([shard_size * len(all_shard_md)]),
        tensor_properties=TensorProperties_Shard.create_from_tensor(torch.zeros(1)),
    )

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=owned_shards, sharded_tensor_metadata=global_md
    )
85
86
class TestSavePlan(TestCase):
    """Tests for the default local/global save and load planners in
    torch.distributed.checkpoint."""

    @with_fake_comms(rank=1, world_size=4)
    def test_local_plan(self):
        """A non-coordinator local save plan covers tensors and local shards;
        the coordinator additionally picks up replicated (byte IO) items."""
        tensor = torch.rand(10)
        val = [1, 2, 3]
        st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1)
        state_dict = {"tensor": tensor, "value": val, "st": st}
        # is_coordinator=False: the replicated "value" entry is skipped,
        # leaving only the plain tensor and this rank's shard.
        plan = create_default_local_save_plan(state_dict, False)
        self.assertEqual(2, len(plan.items))
        wi = plan.items[0]
        self.assertEqual(wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(wi.type, WriteItemType.TENSOR)
        self.assertEqual(wi.tensor_data.size, tensor.size())
        # TensorProperties does not include the shape, so a tensor with the
        # same dtype/layout (here torch.zeros(1)) compares equal.
        self.assertEqual(
            wi.tensor_data.properties,
            TensorProperties.create_from_tensor(torch.zeros(1)),
        )
        self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10]))

        # rank=1 with one shard of size 8 per rank -> local shard starts at 8.
        st_wi = plan.items[1]
        self.assertEqual(st_wi.index, MetadataIndex("st", [8]))
        self.assertEqual(st_wi.type, WriteItemType.SHARD)
        self.assertEqual(st_wi.tensor_data.size, st.size())
        self.assertEqual(
            st_wi.tensor_data.properties,
            TensorProperties.create_from_tensor(torch.zeros(1)),
        )
        self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8]))
        self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8]))

        # Coordinator rank, should include replicated items as well
        plan = create_default_local_save_plan(state_dict, True)
        self.assertEqual(3, len(plan.items))

        tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
        self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
        self.assertEqual(
            tensor_wi.tensor_data.properties,
            TensorProperties.create_from_tensor(tensor),
        )
        self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))

        # Plain python objects are written as byte IO and carry no tensor data.
        bytes_wi = next(wi for wi in plan.items if wi.type == WriteItemType.BYTE_IO)
        self.assertEqual(bytes_wi.index, MetadataIndex("value"))
        self.assertIsNone(bytes_wi.tensor_data)

    def test_global_plan(self):
        """Global planning dedups the per-rank plans and emits metadata whose
        chunk hints line up with each write item's MetadataIndex."""

        def create_data(rank):
            # Build the local plan a given rank would produce; rank 0 acts as
            # the coordinator and therefore owns the replicated "value" item.
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                state_dict = {"tensor": tensor, "value": val, "st": st}
                return create_default_local_save_plan(state_dict, rank == 0)

        all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
        all_plans = dedup_save_plans(all_plans)
        final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)

        # The default global plan updates all indexes to include hints
        for new_plan, old_plan in zip(final_plans, all_plans):
            for new_item, old_item in zip(new_plan.items, old_plan.items):
                # Item identity and tensor payload description are preserved.
                self.assertEqual(new_item.index, old_item.index)
                self.assertEqual(new_item.type, old_item.type)
                self.assertEqual(new_item.tensor_data, old_item.tensor_data)
                self.assertIn(new_item.index.fqn, metadata.state_dict_metadata)

                item_md = metadata.state_dict_metadata[new_item.index.fqn]
                if new_item.type == WriteItemType.BYTE_IO:
                    self.assertTrue(isinstance(item_md, BytesStorageMetadata))
                else:
                    self.assertTrue(isinstance(item_md, TensorStorageMetadata))
                    self.assertEqual(item_md.size, old_item.tensor_data.size)
                    self.assertEqual(
                        item_md.properties, old_item.tensor_data.properties
                    )

                    self.assertIsNotNone(new_item.index.index)
                    # Make sure the hint is correct
                    self.assertEqual(
                        item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
                    )

    def test_local_load_plan(self):
        """Loading a state_dict against metadata created from itself yields
        exact one-to-one read items with zero offsets."""

        def create_state_dict(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                return {"tensor": tensor, "value": val, "st": st}

        state_dict = create_state_dict(1)
        metadata = _create_default_local_metadata(state_dict)

        load_plan = create_default_local_load_plan(state_dict, metadata)
        # This will create 3 entries
        self.assertEqual(3, len(load_plan.items))
        st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
        tensor_item = next(
            ri for ri in load_plan.items if ri.dest_index.fqn == "tensor"
        )
        bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")

        self.assertEqual(st_item.type, LoadItemType.TENSOR)
        # This is an exact copy
        self.assertEqual(st_item.dest_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.dest_offsets, torch.Size([0]))
        self.assertEqual(st_item.storage_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.storage_offsets, torch.Size([0]))
        self.assertEqual(st_item.lengths, torch.Size([8]))

        self.assertEqual(tensor_item.type, LoadItemType.TENSOR)
        self.assertEqual(tensor_item.dest_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.dest_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.storage_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.storage_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.lengths, torch.Size([10]))

        self.assertEqual(bytes_item.type, LoadItemType.BYTE_IO)
        self.assertEqual(bytes_item.dest_index, MetadataIndex("value"))

    def test_load_with_resharding(self):
        """Resharding across world sizes: the load planner splits/merges read
        items so destination shards are assembled from the stored chunks."""

        def create_state_dict(rank, world_size):
            # One shard per rank; total tensor length is fixed at 128, so the
            # shard size shrinks as the world grows.
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=128 // world_size,
                    )
                }

        # Rank 1 has a 16 bytes shard from [16, 32[
        world8_state_dict = create_state_dict(rank=1, world_size=8)
        world8_metadata = _create_default_local_metadata(world8_state_dict)

        # Rank 1 has a 32 bytes shard from [32, 64[
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # First scenario, going from world=8 to world=4, need to load 2 shards
        # Each 4-world shard has 32 elements, so it needs to load 2 shards
        load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
        self.assertEqual(2, len(load_plan.items))
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16])
        )

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([16]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [48]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([16]))
        self.assertEqual(high_ri.lengths, torch.Size([16]))

        # Second scenario, going from world=4 to world=8, need to load half of 1 shard
        # rank1 on 8-world needs to load the upper half of the rank0 4-world shard
        load_plan = create_default_local_load_plan(world8_state_dict, world4_metadata)
        self.assertEqual(1, len(load_plan.items))
        ri = load_plan.items[0]
        self.assertEqual(ri.storage_index, MetadataIndex("st", [0]))
        self.assertEqual(ri.storage_offsets, torch.Size([16]))
        self.assertEqual(ri.dest_index, MetadataIndex("st", [16]))
        self.assertEqual(ri.dest_offsets, torch.Size([0]))
        self.assertEqual(ri.lengths, torch.Size([16]))

    def test_load_with_world_size_diff_by_one(self):
        """Resharding with misaligned shard boundaries (world 4 -> 3): a single
        destination shard straddles two stored chunks."""

        def create_state_dict(rank, world_size):
            # Total tensor length is 120 so both 3 and 4 divide it evenly.
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=120 // world_size,
                    )
                }

        # rank 1 has a 30 bytes shard from [30, 60[
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # rank 1 has a 40 bytes shard from [40, 80[
        world3_state_dict = create_state_dict(rank=1, world_size=3)

        load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
        self.assertEqual(2, len(load_plan.items))
        # this is [30, 60] to load [40, 60]
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        # this is [60, 90] to load [60, 80]
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20])
        )

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([20]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [60]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
        self.assertEqual(high_ri.lengths, torch.Size([20]))
307
308class TestPlannerHelpers(TestCase):
309def test_create_read_item_from_chunks(self):
310tensor_md = TensorStorageMetadata(
311properties=TensorProperties.create_from_tensor(torch.empty([16])),
312size=torch.Size([16]),
313chunks=[
314ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8])),
315ChunkStorageMetadata(offsets=torch.Size([8]), sizes=torch.Size([8])),
316],
317)
318
319chunk = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([7]))
320read_items = create_read_items_for_chunk_list("foo", tensor_md, [chunk])
321
322self.assertEqual(2, len(read_items))
323self.assertEqual(MetadataIndex("foo", [4]), read_items[0].dest_index)
324self.assertEqual(torch.Size([0]), read_items[0].dest_offsets)
325
326self.assertEqual(MetadataIndex("foo", [0]), read_items[0].storage_index)
327self.assertEqual(torch.Size([4]), read_items[0].storage_offsets)
328
329self.assertEqual(torch.Size([4]), read_items[0].lengths)
330
331self.assertEqual(MetadataIndex("foo", [4]), read_items[1].dest_index)
332self.assertEqual(torch.Size([4]), read_items[1].dest_offsets)
333
334self.assertEqual(MetadataIndex("foo", [8]), read_items[1].storage_index)
335self.assertEqual(torch.Size([0]), read_items[1].storage_offsets)
336
337self.assertEqual(torch.Size([3]), read_items[1].lengths)
338
339
if __name__ == "__main__":
    # Standard PyTorch test-suite entry point.
    run_tests()
342