# pytorch
# 301 lines · 9.4 KB
1# Owner(s): ["oncall: distributed"]
2
3import time
4from enum import auto, Enum
5from functools import partial
6
7import torch
8import torch.distributed as dist
9import torch.distributed.checkpoint as DCP
10import torch.distributed.checkpoint.state_dict_saver as saver
11import torch.nn as nn
12import torch.nn.functional as F
13from torch.distributed._tensor.device_mesh import init_device_mesh
14from torch.distributed.checkpoint.state_dict import (
15_patch_model_state_dict,
16_patch_optimizer_state_dict,
17get_state_dict,
18)
19from torch.distributed.distributed_c10d import ReduceOp
20from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
21from torch.distributed.fsdp.api import ShardingStrategy
22from torch.distributed.tensor.parallel import (
23ColwiseParallel,
24parallelize_module,
25RowwiseParallel,
26)
27from torch.nn.parallel import DistributedDataParallel
28
29from torch.testing._internal.common_utils import (
30instantiate_parametrized_tests,
31parametrize,
32run_tests,
33)
34from torch.testing._internal.distributed._tensor.common_dtensor import (
35DTensorTestBase,
36skip_if_lt_x_gpu,
37with_comms,
38)
39from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
40from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin
41
42
# Simple and boring model
class TestDummyModel(torch.nn.Module):
    """Four-layer ReLU MLP (8 -> 16 -> 32 -> 64 -> 8) used as the model under test."""

    def __init__(self):
        super().__init__()
        # Fixed seed so every instantiation (and every rank) gets identical weights.
        torch.manual_seed(0)
        self.net1 = nn.Linear(8, 16)
        self.net2 = nn.Linear(16, 32)
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Linear(64, 8)

    def forward(self, x):
        """Apply each linear layer followed by ReLU, in order."""
        out = x
        for layer in (self.net1, self.net2, self.net3, self.net4):
            out = F.relu(layer(out))
        return out

    def get_input(self):
        """Return a random (8, 8) CUDA batch for a training step."""
        return torch.rand(8, 8, device="cuda")
62
63
class TestStatefulObj:
    """Arbitrary non-module stateful object used to verify that DCP can save
    and load anything implementing the Stateful protocol
    (``state_dict`` / ``load_state_dict``)."""

    def __init__(self):
        self.data = torch.rand(10, 10, device="cuda")

    def state_dict(self):
        """Return the serializable state: a single tensor under key "data"."""
        return {"data": self.data}

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`."""
        self.data = state_dict["data"]

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of raising
        # AttributeError when `other` has no `data` attribute.
        if not isinstance(other, TestStatefulObj):
            return NotImplemented
        return torch.equal(self.data, other.data)
76
77
class ModelType(Enum):
    """Parallelism wrapper applied to the model under test."""

    FSDP = auto()
    HSDP = auto()
    FSDP_TP = auto()
    DDP = auto()
    NONE = auto()  # no parallelization
84
85
86def _train(model, optim, train_steps=1):
87torch.manual_seed(0)
88loss = None
89for _ in range(train_steps):
90loss = model(model.get_input()).sum()
91loss.backward()
92optim.step()
93optim.zero_grad()
94
95return loss
96
97
class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
    """End-to-end save/load tests for torch.distributed.checkpoint (DCP).

    Trains a parallelized model alongside a non-parallel reference, round-trips
    model/optimizer/arbitrary-stateful state through a DCP checkpoint, and
    verifies the reloaded state keeps training in lockstep with the reference.
    """

    @property
    def backend(self):
        # gloo serves CPU collectives (required by async save's CPU staging),
        # nccl serves CUDA collectives.
        return "cpu:gloo,cuda:nccl"

    def _create_model(self, compile, model_type, state_dict_options=None):
        """Build a TestDummyModel wrapped per ``model_type`` plus its optimizer.

        For every type except ``NONE``, the model/optimizer ``state_dict``
        methods are patched so DCP sees distributed-aware state dicts.
        """
        dummy_model = TestDummyModel().cuda()

        assert model_type in ModelType, f"{model_type} is not supported."
        if model_type == ModelType.FSDP:
            device_mesh = init_device_mesh(self.device_type, (self.world_size,))
            model = FSDP(
                dummy_model,
                device_mesh=device_mesh,
                use_orig_params=True,
            )
        elif model_type == ModelType.HSDP:
            # 2-way replication x (world_size // 2)-way sharding.
            device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
            model = FSDP(
                dummy_model,
                device_mesh=device_mesh,
                use_orig_params=True,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )
        elif model_type == ModelType.FSDP_TP:
            mesh_2d = init_device_mesh(
                self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
            )
            tp_mesh = mesh_2d["tp"]
            dp_mesh = mesh_2d["dp"]
            parallelize_plan = {
                "net1": ColwiseParallel(),
                "net2": RowwiseParallel(),
            }
            model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
            model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
        elif model_type == ModelType.DDP:
            model = DistributedDataParallel(dummy_model)
            # DDP wraps the module; re-expose get_input on the wrapper so
            # _train can call model.get_input() uniformly.
            model.get_input = partial(TestDummyModel.get_input, model)
        else:
            model = dummy_model

        if compile:
            # TODO: enable dynamic=True when dynamic shape support is enabled.
            # model = torch.compile(model)
            model = torch.compile(model, dynamic=False)

        optim = self._optim(model)
        if model_type is not ModelType.NONE:
            _patch_model_state_dict(model, options=state_dict_options)
            _patch_optimizer_state_dict(
                model, optimizers=optim, options=state_dict_options
            )

        return model, optim

    def _optim(self, model):
        """Fresh Adam optimizer over ``model``'s parameters."""
        return torch.optim.Adam(model.parameters(), lr=0.1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("compile", [True, False])
    # TODO: Previously PairwiseParallel does not shard properly, passing ModelType.FSDP_TP test where it
    # should have failed. Disabling the failed test temporarily to unblock the deprecation of PairwiseParallel.
    # @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.FSDP_TP])
    @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.DDP])
    def test_e2e(self, compile, model_type):
        self._run_e2e_test(compile, model_type)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_e2e_async(self):
        self._run_e2e_test(compile=False, model_type=ModelType.FSDP, async_op=True)

    def _run_e2e_test(self, compile, model_type, async_op=False):
        """Core e2e flow: train, save (sync or async), load into fresh objects,
        then verify continued training and state dicts match the reference."""
        # Non-parallel reference model trained identically to the dist model.
        model, optim = self._create_model(compile, ModelType.NONE)
        _train(model, optim, train_steps=2)

        dist_model, dist_optim = self._create_model(compile, model_type)
        _train(dist_model, dist_optim, train_steps=2)

        original_stateful_obj = TestStatefulObj()  # tests arbitrary saving/loading
        sd = {
            "model": dist_model,
            "optimizer": dist_optim,
            "s": original_stateful_obj,
        }

        if async_op:
            f = saver._async_save(sd, checkpoint_id=self.temp_dir)
            t = time.monotonic()
            while not f.done():
                time.sleep(1)
                print(f"still waiting... {time.monotonic() - t}")

            f.result()
        else:
            DCP.save(sd, checkpoint_id=self.temp_dir)

        # Load into freshly constructed objects so the equality checks below
        # prove the checkpoint (not leftover state) restored the values.
        # NOTE: the original test constructed these twice back-to-back; the
        # redundant duplicate has been removed.
        loaded_stateful_obj = TestStatefulObj()
        dist_model, dist_optim = self._create_model(compile, model_type)

        DCP.load(
            state_dict={
                "model": dist_model,
                "optimizer": dist_optim,
                "s": loaded_stateful_obj,
            },
            checkpoint_id=self.temp_dir,
        )

        self.assertEqual(original_stateful_obj, loaded_stateful_obj)

        # train one more step on both models
        loss = _train(model, optim, train_steps=1)
        dist_loss = _train(dist_model, dist_optim, train_steps=1)
        self.assertEqual(loss, dist_loss)

        dist_msd, dist_osd = get_state_dict(dist_model, optimizers=dist_optim)
        model_sd, optim_sd = get_state_dict(model, optimizers=optim)

        self._verify_msd(model_sd, dist_msd)
        self._verify_osd_by_load(model, optim, self._optim(model), dist_osd)

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(4)
    def test_different_ordered_state_dict_keys(self):
        """Tests that the order of keys in the state dict does not matter when loading
        If order was not accounted for, the following test would cause a deadlock.
        """

        world_size = self.world_size

        class Foo:
            def state_dict(self):
                return {}

            def load_state_dict(self, state_dict):
                tl = [
                    torch.ones(2, dtype=torch.int64, device="cuda")
                    for _ in range(world_size)
                ]
                t = (
                    torch.arange(2, dtype=torch.int64, device="cuda")
                    + 1
                    + 2 * dist.get_rank()
                )
                dist.all_gather(tl, t, async_op=False)

        class Bar:
            def state_dict(self):
                return {}

            def load_state_dict(self, state_dict):
                tensor = (
                    torch.arange(2, dtype=torch.int64, device="cuda")
                    + 1
                    + 2 * dist.get_rank()
                )
                dist.all_reduce(tensor, op=ReduceOp.SUM)

        # Rank 0 registers the keys in the opposite order from the other
        # ranks; a correct loader must still line up the collectives.
        if self.rank == 0:
            sd = {
                "A": Foo(),
                "B": Bar(),
            }
        else:
            sd = {
                "B": Bar(),
                "A": Foo(),
            }

        DCP.save(sd, checkpoint_id=self.temp_dir)
        DCP.load(sd, checkpoint_id=self.temp_dir)

    @with_temp_dir
    def test_no_dist(self):
        """Saving/loading an empty state dict must work without a process group."""
        DCP.save({}, checkpoint_id=self.temp_dir, no_dist=True)
        DCP.load({}, checkpoint_id=self.temp_dir, no_dist=True)
226
227@with_comms
228@with_temp_dir
229@skip_if_lt_x_gpu(4)
230def test_different_ordered_state_dict_keys(self):
231"""Tests that the order of keys in the state dict does not matter when loading
232If order was not accounted for, the following test would cause a deadlock.
233"""
234
235world_size = self.world_size
236
237class Foo:
238def state_dict(self):
239return {}
240
241def load_state_dict(self, state_dict):
242tl = [
243torch.ones(2, dtype=torch.int64, device="cuda")
244for _ in range(world_size)
245]
246t = (
247torch.arange(2, dtype=torch.int64, device="cuda")
248+ 1
249+ 2 * dist.get_rank()
250)
251dist.all_gather(tl, t, async_op=False)
252
253class Bar:
254def state_dict(self):
255return {}
256
257def load_state_dict(self, state_dict):
258tensor = (
259torch.arange(2, dtype=torch.int64, device="cuda")
260+ 1
261+ 2 * dist.get_rank()
262)
263dist.all_reduce(tensor, op=ReduceOp.SUM)
264
265if self.rank == 0:
266sd = {
267"A": Foo(),
268"B": Bar(),
269}
270else:
271sd = {
272"B": Bar(),
273"A": Foo(),
274}
275
276DCP.save(sd, checkpoint_id=self.temp_dir)
277DCP.load(sd, checkpoint_id=self.temp_dir)
278
279@with_temp_dir
280def test_no_dist(self):
281DCP.save({}, checkpoint_id=self.temp_dir, no_dist=True)
282DCP.load({}, checkpoint_id=self.temp_dir, no_dist=True)
283
284
class TestNoCPU(DTensorTestBase):
    """Async save must be rejected when no CPU backend is available."""

    @property
    def backend(self):
        # NCCL only: deliberately omit a CPU (gloo) backend.
        return "nccl"

    @with_comms
    def test_no_cpu(self):
        expected = r"A CPU backend must be enabled for async save;.*?"
        with self.assertRaisesRegex(AssertionError, expected):
            saver._async_save({}).result()
297
298
# Expand the @parametrize decorators into concrete test methods.
instantiate_parametrized_tests(TestE2ESaveAndLoad)
if __name__ == "__main__":
    run_tests()
302