# Owner(s): ["oncall: distributed"]

import io
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor

from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)

# A simple and boring model to test the interface and some corner cases that do
# not require a complicated wrapping strategy.
class DenseModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))

    def get_input(self):
        return torch.rand(4, 8, device="cuda")

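# Note: with a 2D device mesh of shape (2, world_size // 2), HYBRID_SHARD
# replicates parameters across mesh dim 0 and shards them on tensor dim 0
# across mesh dim 1, so the tests below expect every DTensor in the sharded
# state dicts to carry the placements (Replicate(), Shard(0)).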
# TODO: Consolidate DeviceMesh based FSDP and HSDP test cases.
class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
    def _create_model(self, device_mesh=None):
        if device_mesh:
            model = FSDP(
                DenseModel().cuda(),
                device_mesh=device_mesh,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )
        else:
            mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
            intra_node_pg = mesh_2d.get_group(mesh_dim=1)
            inter_node_pg = mesh_2d.get_group(mesh_dim=0)
            model = FSDP(
                DenseModel().cuda(),
                process_group=(intra_node_pg, inter_node_pg),
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )

        # Take one optimizer step so the Adam state (exp_avg, exp_avg_sq) is
        # materialized before the state dicts are inspected.
        optim = torch.optim.Adam(model.parameters(), lr=0.1)
        model(model.get_input()).sum().backward()
        optim.step()

        return model, optim

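    # The tests below exercise the two wrapping paths set up in _create_model:
    # initializing FSDP with a device_mesh yields DTensor-based sharded state
    # dicts, while the (intra_node_pg, inter_node_pg) process-group path yields
    # ShardedTensor-based ones.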
    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_hsdp_init_with_device_mesh(self):
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model, optim = self._create_model(mesh_2d)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict = model.state_dict()
        optim_state_dict = FSDP.optim_state_dict(model, optim)

        for v in state_dict.values():
            self.assertEqual(type(v), DTensor)
            self.assertEqual(len(v.placements), 2)
            self.assertEqual(v.placements, (Replicate(), Shard(0)))
            self.assertEqual(v.device_mesh, mesh_2d)

        for state in optim_state_dict["state"].values():
            for k, v in state.items():
                if k != "step":
                    self.assertEqual(type(v), DTensor)
                    self.assertEqual(len(v.placements), 2)
                    self.assertEqual(v.placements, (Replicate(), Shard(0)))
                    self.assertEqual(v.device_mesh, mesh_2d)

        state_dict_type = model.get_state_dict_type(model)
        # If device_mesh is used when initializing FSDP, the field _use_dtensor will
        # automatically be set to True.
        self.assertEqual(state_dict_type.state_dict_config._use_dtensor, True)
        self.assertEqual(state_dict_type.optim_state_dict_config._use_dtensor, True)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("offload_to_cpu", [True, False])
    def test_dtensor_sharded_tensor_state_dict_identical(self, offload_to_cpu):
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model, optim = self._create_model(mesh_2d)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=offload_to_cpu),
            optim_state_dict_config=ShardedOptimStateDictConfig(
                offload_to_cpu=offload_to_cpu
            ),
        )
        dtensor_sd = model.state_dict()
        dtensor_osd = FSDP.optim_state_dict(model, optim)

        ref_model, ref_optim = self._create_model()
        FSDP.set_state_dict_type(
            ref_model,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=offload_to_cpu),
            optim_state_dict_config=ShardedOptimStateDictConfig(
                offload_to_cpu=offload_to_cpu
            ),
        )
        sharded_tensor_sd = ref_model.state_dict()
        sharded_tensor_osd = FSDP.optim_state_dict(ref_model, ref_optim)

        # Check that the DTensor and ShardedTensor model state dict values are identical.
        for dtensor_sd_item, sharded_tensor_sd_item in zip(
            dtensor_sd.items(), sharded_tensor_sd.items()
        ):
            k1, v1 = dtensor_sd_item
            k2, v2 = sharded_tensor_sd_item
            self.assertEqual(k1, k2)

            self.assertEqual(type(v1), DTensor)
            self.assertEqual(type(v2), ShardedTensor)
            # Check whether the local tensors are the same.
            self.assertEqual(v1.to_local(), v2.local_tensor())
            # Check whether the devices are the same.
            self.assertEqual(v1.to_local().device, v2.local_tensor().device)

        # Check that the DTensor and ShardedTensor optim state dict values are identical.
        for dtensor_osd_state, sharded_tensor_osd_state in zip(
            dtensor_osd["state"].items(), sharded_tensor_osd["state"].items()
        ):
            # Check that the FQNs are the same.
            self.assertEqual(dtensor_osd_state[0], sharded_tensor_osd_state[0])
            for dtensor_hyper_param, sharded_tensor_hyper_param in zip(
                dtensor_osd_state[1].items(),
                sharded_tensor_osd_state[1].items(),
            ):
                k1, v1 = dtensor_hyper_param
                k2, v2 = sharded_tensor_hyper_param
                self.assertEqual(k1, k2)

                if k1 != "step":
                    self.assertEqual(type(v1), DTensor)
                    self.assertEqual(type(v2), ShardedTensor)
                    # Check whether the local tensors are the same.
                    self.assertEqual(v1.to_local(), v2.local_tensor())
                    # Check whether the devices are the same.
                    self.assertEqual(v1.to_local().device, v2.local_tensor().device)
                else:
                    self.assertEqual(v1, v2)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("offload_to_cpu", [True, False])
    def test_dtensor_sharded_optim_load_state_dict(self, offload_to_cpu):
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model, optim = self._create_model(mesh_2d)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            optim_state_dict_config=ShardedOptimStateDictConfig(
                offload_to_cpu=offload_to_cpu
            ),
        )

        checkpoint = io.BytesIO()
        torch.save(FSDP.optim_state_dict(model, optim), checkpoint)
        # Deepcopy to save the current optim_state_dict to compare with the
        # optim_state_dict loaded back below.
        ref_optim_state_dict = deepcopy(FSDP.optim_state_dict(model, optim))

        # Update the parameters so FSDP.optim_state_dict() will differ from ref_optim_state_dict.
        model(model.get_input()).sum().backward()
        optim.step()

        # Load ref_optim_state_dict back.
        checkpoint.seek(0)
        load_ref_optim_state_dict = torch.load(checkpoint)
        optim.load_state_dict(
            FSDP.optim_state_dict_to_load(model, optim, load_ref_optim_state_dict)
        )
        new_optim_state_dict = FSDP.optim_state_dict(model, optim)

        # Check whether new_optim_state_dict is the same as ref_optim_state_dict.
        for new_optim_state_dict_item, ref_optim_state_dict_item in zip(
            new_optim_state_dict["state"].items(),
            ref_optim_state_dict["state"].items(),
        ):
            # Check that the FQNs are the same.
            self.assertEqual(new_optim_state_dict_item[0], ref_optim_state_dict_item[0])
            for new_optim_hyper_param, ref_optim_hyper_param in zip(
                new_optim_state_dict_item[1].items(),
                ref_optim_state_dict_item[1].items(),
            ):
                k1, v1 = new_optim_hyper_param
                k2, v2 = ref_optim_hyper_param
                # Check whether the keys are the same.
                self.assertEqual(k1, k2)
                # Check whether the values (DTensor or step count) are the same.
                self.assertEqual(v1, v2)

                if k1 != "step":
                    self.assertEqual(type(v1), DTensor)
                    self.assertEqual(type(v2), DTensor)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("offload_to_cpu", [True, False])
    def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model, optim = self._create_model(mesh_2d)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=offload_to_cpu),
        )

        checkpoint = io.BytesIO()
        torch.save(model.state_dict(), checkpoint)
        # Deepcopy to save the current state_dict to compare with the state_dict loaded back below.
        ref_state_dict = deepcopy(model.state_dict())

        # Update the parameters so model.state_dict() will differ from ref_state_dict.
        model(model.get_input()).sum().backward()
        optim.step()

        # Load ref_state_dict back.
        checkpoint.seek(0)
        load_ref_state_dict = torch.load(checkpoint)
        model.load_state_dict(load_ref_state_dict)
        new_state_dict = model.state_dict()

        # Check whether new_state_dict is the same as ref_state_dict.
        for (k1, v1), (k2, v2) in zip(ref_state_dict.items(), new_state_dict.items()):
            # Check that the FQNs are the same.
            self.assertEqual(k1, k2)

            self.assertEqual(type(v1), DTensor)
            self.assertEqual(type(v2), DTensor)
            # Check that the DTensors are the same.
            self.assertEqual(v1, v2)

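    # When the root module is not itself FSDP-wrapped, only the optimizer states
    # of parameters under the FSDP-wrapped submodule are expected to come back as
    # DTensors; states of the non-FSDP parameters remain plain tensors.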
    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_root_module_is_not_FSDP(self):
        class FakeMPModel(torch.nn.Module):
            def __init__(self, device_mesh):
                super().__init__()
                torch.manual_seed(0)
                self.dense = FSDP(
                    DenseModel().cuda(),
                    use_orig_params=True,
                    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
                    device_mesh=device_mesh,
                )
                if dist.get_rank() == 0:
                    self.sparse0 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
                else:
                    self.sparse1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

            def forward(self, x):
                if dist.get_rank() == 0:
                    sparse = self.sparse0(x)
                else:
                    sparse = self.sparse1(x)
                dist.all_reduce(sparse)
                return self.dense(sparse)

        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model = FakeMPModel(device_mesh=mesh_2d).cuda()
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        batch = torch.rand(5, 8, device=torch.device("cuda"))
        model(batch).sum().backward()
        optim.step()
        osd = optim.state_dict()

        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            osd = FSDP.optim_state_dict(model, optim, osd)

        for param, state in osd["state"].items():
            if "dense" in param:
                self.assertIsInstance(state["exp_avg"], DTensor)
                self.assertIsInstance(state["exp_avg_sq"], DTensor)
                self.assertEqual(state["exp_avg"].placements, (Replicate(), Shard(0)))
                self.assertEqual(
                    state["exp_avg_sq"].placements, (Replicate(), Shard(0))
                )
            else:
                self.assertIsInstance(state["exp_avg"], torch.Tensor)
                self.assertIsInstance(state["exp_avg_sq"], torch.Tensor)


instantiate_parametrized_tests(TestHSDPWithDeviceMeshAndDTensor)
if __name__ == "__main__":
    run_tests()