test_e2e_save_and_load.py
# Owner(s): ["oncall: distributed"]

import time
from enum import auto, Enum
from functools import partial

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as DCP
import torch.distributed.checkpoint.state_dict_saver as saver
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
    get_state_dict,
)
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.nn.parallel import DistributedDataParallel

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin

# Simple and boring model
class TestDummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Linear(8, 16)
        self.net2 = nn.Linear(16, 32)
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Linear(64, 8)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        return x

    def get_input(self):
        return torch.rand(8, 8, device="cuda")


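# An arbitrary, non-module object that follows the Stateful protocol DCP expects:
# anything exposing state_dict()/load_state_dict() can be checkpointed alongside
# models and optimizers.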
class TestStatefulObj:
    def __init__(self):
        self.data = torch.rand(10, 10, device="cuda")

    def state_dict(self):
        return {"data": self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict["data"]

    def __eq__(self, other):
        return torch.equal(self.data, other.data)


class ModelType(Enum):
    FSDP = auto()
    HSDP = auto()
    FSDP_TP = auto()
    DDP = auto()
    NONE = auto()  # no parallelization


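# Run `train_steps` optimizer steps under a fixed seed and return the final loss,
# so a single-device run and a parallelized run can be compared step for step.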
def _train(model, optim, train_steps=1):
    torch.manual_seed(0)
    loss = None
    for _ in range(train_steps):
        loss = model(model.get_input()).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()

    return loss


class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
    @property
    def backend(self):
        return "cpu:gloo,cuda:nccl"

    def _create_model(self, compile, model_type, state_dict_options=None):
        dummy_model = TestDummyModel().cuda()

        assert model_type in ModelType, f"{model_type} is not supported."
        if model_type == ModelType.FSDP:
            device_mesh = init_device_mesh(self.device_type, (self.world_size,))
            model = FSDP(
                dummy_model,
                device_mesh=device_mesh,
                use_orig_params=True,
            )
        elif model_type == ModelType.HSDP:
            device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
            model = FSDP(
                dummy_model,
                device_mesh=device_mesh,
                use_orig_params=True,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )
        elif model_type == ModelType.FSDP_TP:
            mesh_2d = init_device_mesh(
                self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
            )
            tp_mesh = mesh_2d["tp"]
            dp_mesh = mesh_2d["dp"]
            parallelize_plan = {
                "net1": ColwiseParallel(),
                "net2": RowwiseParallel(),
            }
            model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
            model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
        elif model_type == ModelType.DDP:
            model = DistributedDataParallel(dummy_model)
            model.get_input = partial(TestDummyModel.get_input, model)
        else:
            model = dummy_model

        if compile:
            # TODO: enable dynamic=True when dynamic shape support is enabled.
            # model = torch.compile(model)
            model = torch.compile(model, dynamic=False)

        optim = self._optim(model)
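        # For parallelized models, patch state_dict()/load_state_dict() on the model
        # and optimizer so they produce/consume the distributed (sharded) state dict,
        # honoring any StateDictOptions passed in; DCP then saves and loads them
        # like any other Stateful object.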
        if model_type is not ModelType.NONE:
            _patch_model_state_dict(model, options=state_dict_options)
            _patch_optimizer_state_dict(
                model, optimizers=optim, options=state_dict_options
            )

        return model, optim

    def _optim(self, model):
        return torch.optim.Adam(model.parameters(), lr=0.1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("compile", [True, False])
    # TODO: Previously, PairwiseParallel did not shard properly, so the ModelType.FSDP_TP test
    # passed where it should have failed. The failing test is disabled temporarily to unblock
    # the deprecation of PairwiseParallel.
    # @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.FSDP_TP])
    @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.DDP])
    def test_e2e(self, compile, model_type):
        self._run_e2e_test(compile, model_type)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_e2e_async(self):
        self._run_e2e_test(compile=False, model_type=ModelType.FSDP, async_op=True)

    def _run_e2e_test(self, compile, model_type, async_op=False):
        model, optim = self._create_model(compile, ModelType.NONE)
        _train(model, optim, train_steps=2)

        dist_model, dist_optim = self._create_model(compile, model_type)
        _train(dist_model, dist_optim, train_steps=2)

        original_stateful_obj = TestStatefulObj()  # tests arbitrary saving/loading
        sd = {
            "model": dist_model,
            "optimizer": dist_optim,
            "s": original_stateful_obj,
        }

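        # With async_op, _async_save returns a Future and performs the checkpoint
        # write in the background; poll it so the load below only starts once the
        # files are fully written.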
        if async_op:
            f = saver._async_save(sd, checkpoint_id=self.temp_dir)
            t = time.monotonic()
            while not f.done():
                time.sleep(1)
                print(f"still waiting... {time.monotonic() - t}")

            f.result()
        else:
            DCP.save(sd, checkpoint_id=self.temp_dir)

        loaded_stateful_obj = TestStatefulObj()
        dist_model, dist_optim = self._create_model(compile, model_type)

        DCP.load(
            state_dict={
                "model": dist_model,
                "optimizer": dist_optim,
                "s": loaded_stateful_obj,
            },
            checkpoint_id=self.temp_dir,
        )

        self.assertEqual(original_stateful_obj, loaded_stateful_obj)

        # train one more step on both models
        loss = _train(model, optim, train_steps=1)
        dist_loss = _train(dist_model, dist_optim, train_steps=1)
        self.assertEqual(loss, dist_loss)

        dist_msd, dist_osd = get_state_dict(dist_model, optimizers=dist_optim)
        model_sd, optim_sd = get_state_dict(model, optimizers=optim)

        self._verify_msd(model_sd, dist_msd)
        self._verify_osd_by_load(model, optim, self._optim(model), dist_osd)

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(4)
    def test_different_ordered_state_dict_keys(self):
        """Tests that the order of keys in the state dict does not matter when loading.
        If order was not accounted for, this test would cause a deadlock.
        """

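        # Rank 0 inserts "A" before "B" while the other ranks insert "B" before "A".
        # Foo.load_state_dict issues an all_gather and Bar.load_state_dict issues an
        # all_reduce, so if the loader replayed objects in raw insertion order the
        # ranks would launch mismatched collectives and hang; keys must be processed
        # in a consistent order across ranks.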
        world_size = self.world_size

        class Foo:
            def state_dict(self):
                return {}

            def load_state_dict(self, state_dict):
                tl = [
                    torch.ones(2, dtype=torch.int64, device="cuda")
                    for _ in range(world_size)
                ]
                t = (
                    torch.arange(2, dtype=torch.int64, device="cuda")
                    + 1
                    + 2 * dist.get_rank()
                )
                dist.all_gather(tl, t, async_op=False)

        class Bar:
            def state_dict(self):
                return {}

            def load_state_dict(self, state_dict):
                tensor = (
                    torch.arange(2, dtype=torch.int64, device="cuda")
                    + 1
                    + 2 * dist.get_rank()
                )
                dist.all_reduce(tensor, op=ReduceOp.SUM)

        if self.rank == 0:
            sd = {
                "A": Foo(),
                "B": Bar(),
            }
        else:
            sd = {
                "B": Bar(),
                "A": Foo(),
            }

        DCP.save(sd, checkpoint_id=self.temp_dir)
        DCP.load(sd, checkpoint_id=self.temp_dir)

    @with_temp_dir
    def test_no_dist(self):
        DCP.save({}, checkpoint_id=self.temp_dir, no_dist=True)
        DCP.load({}, checkpoint_id=self.temp_dir, no_dist=True)


class TestNoCPU(DTensorTestBase):
    @property
    def backend(self):
        return "nccl"

    @with_comms
    def test_no_cpu(self):
        with self.assertRaisesRegex(
            AssertionError, r"A CPU backend must be enabled for async save;.*?"
        ):
            f = saver._async_save({})
            f.result()


instantiate_parametrized_tests(TestE2ESaveAndLoad)
if __name__ == "__main__":
    run_tests()