# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from re import escape
from unittest import mock
from unittest.mock import ANY, Mock

import pytest
import torch
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator
from lightning.fabric.strategies import DeepSpeedStrategy
from torch.optim import Optimizer

from tests_fabric.helpers.runif import RunIf


@pytest.fixture()
def deepspeed_config():
    return {
        "optimizer": {"type": "SGD", "params": {"lr": 3e-5}},
        "scheduler": {
            "type": "WarmupLR",
            "params": {"last_batch_iteration": -1, "warmup_min_lr": 0, "warmup_max_lr": 3e-5, "warmup_num_steps": 100},
        },
    }


@pytest.fixture()
def deepspeed_zero_config(deepspeed_config):
    return {**deepspeed_config, "zero_allow_untested_optimizer": True, "zero_optimization": {"stage": 2}}


@RunIf(deepspeed=True)
def test_deepspeed_only_compatible_with_cuda():
    """Test that the DeepSpeed strategy raises an exception if an invalid accelerator is used."""
    strategy = DeepSpeedStrategy(accelerator=CPUAccelerator())
    with pytest.raises(RuntimeError, match="The DeepSpeed strategy is only supported on CUDA GPUs"):
        strategy.setup_environment()


@RunIf(deepspeed=True)
def test_deepspeed_with_invalid_config_path():
    """Test to ensure if we pass an invalid config path we throw an exception."""
55
    with pytest.raises(
56
        FileNotFoundError, match="You passed in a path to a DeepSpeed config but the path does not exist"
57
    ):
58
        DeepSpeedStrategy(config="invalid_path.json")
59

60

61
@RunIf(deepspeed=True)
62
def test_deepspeed_with_env_path(tmp_path, monkeypatch, deepspeed_config):
63
    """Test to ensure if we pass an env variable, we load the config from the path."""
64
    config_path = tmp_path / "temp.json"
    with open(config_path, "w") as f:
        f.write(json.dumps(deepspeed_config))
    monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", str(config_path))
    strategy = DeepSpeedStrategy()
    assert strategy.config == deepspeed_config


@RunIf(deepspeed=True)
def test_deepspeed_defaults():
    """Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed."""
    strategy = DeepSpeedStrategy()
    assert strategy.config is not None
    assert isinstance(strategy.config["zero_optimization"], dict)
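    # DeepSpeed manages gradient accumulation internally, so the strategy defines no backward-sync control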
    assert strategy._backward_sync_control is None


@RunIf(deepspeed=True)
def test_deepspeed_custom_activation_checkpointing_params():
    """Ensure if we modify the activation checkpointing parameters, the deepspeed config contains these changes."""
    ds = DeepSpeedStrategy(
        partition_activations=True,
        cpu_checkpointing=True,
        contiguous_memory_optimization=True,
        synchronize_checkpoint_boundary=True,
    )
    checkpoint_config = ds.config["activation_checkpointing"]
    assert checkpoint_config["partition_activations"]
    assert checkpoint_config["cpu_checkpointing"]
    assert checkpoint_config["contiguous_memory_optimization"]
    assert checkpoint_config["synchronize_checkpoint_boundary"]


@RunIf(deepspeed=True)
def test_deepspeed_config_zero_offload(deepspeed_zero_config):
    """Test the various ways optimizer-offloading can be configured."""
    # default config
    strategy = DeepSpeedStrategy(config=deepspeed_zero_config)
    assert "offload_optimizer" not in strategy.config["zero_optimization"]

    # default config
    strategy = DeepSpeedStrategy()
    assert "offload_optimizer" not in strategy.config["zero_optimization"]

    # default config with `offload_optimizer` argument override
    strategy = DeepSpeedStrategy(offload_optimizer=True)
    assert strategy.config["zero_optimization"]["offload_optimizer"] == {
        "buffer_count": 4,
        "device": "cpu",
        "nvme_path": "/local_nvme",
        "pin_memory": False,
    }

    # externally configured through config
    deepspeed_zero_config["zero_optimization"]["offload_optimizer"] = False
    strategy = DeepSpeedStrategy(config=deepspeed_zero_config)
    assert strategy.config["zero_optimization"]["offload_optimizer"] is False


@RunIf(deepspeed=True)
@mock.patch("deepspeed.initialize")
def test_deepspeed_setup_module(init_mock):
    """Test that the DeepSpeed strategy can set up the model for inference (no optimizer required)."""
    model = Mock()
    model.parameters.return_value = []
    strategy = DeepSpeedStrategy()
    strategy.parallel_devices = [torch.device("cuda", 1)]
    init_mock.return_value = [Mock()] * 4  # mock to make tuple unpacking work

    strategy.setup_module(model)
    init_mock.assert_called_with(
        args=ANY,
        config=strategy.config,
        model=model,
        model_parameters=ANY,
        optimizer=None,
        dist_init_required=False,
    )


@RunIf(deepspeed=True)
def test_deepspeed_requires_joint_setup():
    """Test that the DeepSpeed strategy does not support setting up model and optimizer independently."""
    strategy = DeepSpeedStrategy()
    with pytest.raises(
        NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
    ):
        strategy.setup_optimizer(Mock())


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_storage_options(tmp_path):
    """Test that the DeepSpeed strategy does not accept storage options for saving checkpoints."""
    strategy = DeepSpeedStrategy()
    with pytest.raises(TypeError, match=escape("DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not")):
        strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_one_deepspeed_engine_required(tmp_path):
    """Test that the DeepSpeed strategy can only save one DeepSpeedEngine per checkpoint."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()

    # missing DeepSpeedEngine
    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
        strategy.save_checkpoint(path=tmp_path, state={})
    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
        strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})

    # multiple DeepSpeedEngine
    model1 = Mock(spec=torch.nn.Module)
    model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
    model2 = Mock(spec=torch.nn.Module)
    model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
    with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
        strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
    """Test that the DeepSpeed engine and optimizer get separated from the client state."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()

    # Model only
    model = Mock(spec=DeepSpeedEngine, optimizer=None)
    model.modules.return_value = [model]
    strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
    # the client_state should not contain any deepspeed engine or deepspeed optimizer
    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")

    # Model and optimizer
    optimizer = Mock()
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
    model.modules.return_value = [model]
    strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
    # the client_state should not contain any deepspeed engine or deepspeed optimizer
    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
    """Test that the strategy warns if there are keys in the user dict that collide internally with DeepSpeed."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    optimizer = Mock()
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
    model.modules.return_value = [model]
    # `mp_world_size` is an internal key
    with pytest.warns(UserWarning, match="Your state has keys that collide with DeepSpeed's internal"):
        strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_validate_path(tmp_path):
    """Test that we validate the checkpoint path for a DeepSpeed checkpoint and give suggestions for user error."""
    strategy = DeepSpeedStrategy()
    with pytest.raises(FileNotFoundError, match="The provided path is not a valid DeepSpeed checkpoint"):
        strategy.load_checkpoint(path=tmp_path, state={"model": Mock()})

    # User tries to pass the subfolder as the path
    checkpoint_path = tmp_path / "checkpoint"
    checkpoint_path.mkdir()
    with pytest.raises(FileNotFoundError, match=f"Try to load using this parent directory instead: {tmp_path}"):
        strategy.load_checkpoint(path=checkpoint_path, state={"model": Mock()})

    # User tries to pass an individual file inside the checkpoint folder
    checkpoint_path = checkpoint_path / "zero_pp_rank_0_mp_rank_00_model_states.pt"
    checkpoint_path.touch()
    with pytest.raises(FileNotFoundError, match=f"Try to load using this parent directory instead: {tmp_path}"):
        strategy.load_checkpoint(path=checkpoint_path, state={"model": Mock()})


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_no_state(tmp_path):
    """Test that DeepSpeed can't load the full state without access to a model instance from the user."""
    strategy = DeepSpeedStrategy()
    with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state=None")):
        strategy.load_checkpoint(path=tmp_path, state=None)
    with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state={})")):
        strategy.load_checkpoint(path=tmp_path, state={})


@RunIf(deepspeed=True)
@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(_, tmp_path):
    """Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()

    # missing DeepSpeedEngine
    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
        strategy.load_checkpoint(path=tmp_path, state={"other": "data"})
    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
        strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})

    # multiple DeepSpeedEngine
    model1 = Mock(spec=torch.nn.Module)
    model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
    model2 = Mock(spec=torch.nn.Module)
    model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
    with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
        strategy.load_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
    """Test that the DeepSpeed strategy raises a custom error when client state couldn't be loaded by DeepSpeed."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    optimizer = Mock()
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
    model.modules.return_value = [model]

    # If the DeepSpeed engine fails to load the checkpoint file (e.g., file not found), it prints a warning and
    # returns None from its function call
    model.load_checkpoint.return_value = [None, None]

    # Check for our custom user error
    with pytest.raises(FileNotFoundError, match="The provided path is not a valid DeepSpeed checkpoint"):
        strategy.load_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})


@RunIf(deepspeed=True)
@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path):
    """Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    optimizer = Mock()
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
    model.modules.return_value = [model]

    # the client state contains the additional user data that was provided when saving, plus some deepspeed metadata
    loaded_client_state = {"user_data": {"iteration": 5}, "deepspeed_metadata": "data"}
306
    model.load_checkpoint.return_value = [None, loaded_client_state]
307

308
    state = {"model": model, "user_data": {"iteration": 0}}
309
    metadata = strategy.load_checkpoint(path=tmp_path, state=state)
310

311
    # the user's state gets updated with the loaded value
312
    assert state == {"model": model, "user_data": {"iteration": 5}}
313
    # additional metadata gets separated from client state
314
    assert metadata == {"deepspeed_metadata": "data"}
315

316

317
@RunIf(deepspeed=True)
318
@pytest.mark.parametrize("optimzer_state_requested", [True, False])
319
@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
320
def test_deepspeed_load_checkpoint_optimzer_state_requested(_, optimzer_state_requested, tmp_path):
321
    """Test that the DeepSpeed strategy loads the optimizer state only when requested."""
322
    from deepspeed import DeepSpeedEngine
323

324
    strategy = DeepSpeedStrategy()
325
    optimizer = Mock(spec=Optimizer)
326
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
327
    model.modules.return_value = [model]
328

329
    # required, otherwise mock cannot be unpacked
330
    model.load_checkpoint.return_value = [None, {}]
331

332
    state = {"model": model}
333
    if optimzer_state_requested:
334
        state["optimizer"] = optimizer
335

336
    strategy.load_checkpoint(path=tmp_path, state=state)
337
    model.load_checkpoint.assert_called_with(
338
        tmp_path,
339
        tag="checkpoint",
340
        load_optimizer_states=optimzer_state_requested,
341
        load_lr_scheduler_states=False,
342
        load_module_strict=True,
343
    )
344

345

346
@RunIf(deepspeed=True)
347
@pytest.mark.parametrize("stage", [1, 2, 3])
348
def test_deepspeed_load_checkpoint_raw_state_dict(stage, tmp_path):
349
    """Test that the `load_checkpoint` can load raw state dict checkpoints too."""
350
    strategy = DeepSpeedStrategy(stage=stage)
351

352
    model = torch.nn.Linear(3, 3)
353
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
354
    torch.save(model.state_dict(), tmp_path / "model.ckpt")
355
    torch.save(optimizer.state_dict(), tmp_path / "optimizer.ckpt")
356

357
    new_model = torch.nn.Linear(3, 3)
358
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=2.0)
359

360
    strategy.load_checkpoint(tmp_path / "model.ckpt", state=new_model, strict=False)
361
    assert torch.equal(new_model.weight, model.weight)
362
    strategy.load_checkpoint(tmp_path / "optimizer.ckpt", state=new_optimizer, strict=False)
    assert new_optimizer.state_dict()["param_groups"][0]["lr"] == 1.0


@RunIf(deepspeed=True)
def test_errors_grad_clipping():
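    """Test that manual gradient clipping raises an error, since DeepSpeed applies clipping via its config."""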
    strategy = DeepSpeedStrategy()
    with pytest.raises(
        NotImplementedError,
        match=(
            "DeepSpeed handles gradient clipping automatically within the optimizer. "
            "Make sure to set the `gradient_clipping` value in your Config."
        ),
    ):
        strategy.clip_gradients_norm(Mock(), Mock(), Mock(), Mock(), Mock())

    with pytest.raises(
        NotImplementedError,
        match=(
            "DeepSpeed handles gradient clipping automatically within the optimizer. "
            "Make sure to set the `gradient_clipping` value in your Config."
        ),
    ):
        strategy.clip_gradients_value(Mock(), Mock(), Mock())


@RunIf(deepspeed=True, mps=False)
def test_deepspeed_save_filter(tmp_path):
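    """Test that passing a checkpoint `filter` is rejected because DeepSpeed manages the state serialization internally."""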
    strategy = DeepSpeedStrategy()
    with pytest.raises(TypeError, match="manages the state serialization internally"):
        strategy.save_checkpoint(path=tmp_path, state={}, filter={})


@RunIf(deepspeed=True)
@pytest.mark.parametrize("device_indices", [[1], [1, 0], [0, 2], [3, 2, 1]])
def test_validate_parallel_devices_indices(device_indices):
    """Test that the strategy validates that it doesn't support selecting specific devices by index.

    DeepSpeed doesn't support it and needs the index to match to the local rank of the process.

    """
    accelerator = Mock(spec=CUDAAccelerator)
    strategy = DeepSpeedStrategy(
        accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices]
    )
    with pytest.raises(
        RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes")
    ):
        strategy.setup_environment()
    accelerator.setup_device.assert_called_once_with(torch.device("cuda", device_indices[0]))
