# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import DataParallelStrategy

from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _run_test_clip_gradients


def test_data_parallel_root_device():
    strategy = DataParallelStrategy()
    strategy.parallel_devices = [torch.device("cuda", 2), torch.device("cuda", 0), torch.device("cuda", 1)]
    # The first entry in `parallel_devices` becomes the root device
    assert strategy.root_device == torch.device("cuda", 2)


def test_data_parallel_ranks():
    # DP runs in a single process, so the default rank and world size values apply
    strategy = DataParallelStrategy()
    assert strategy.world_size == 1
    assert strategy.local_rank == 0
    assert strategy.global_rank == 0
    assert strategy.is_global_zero


@mock.patch("lightning.fabric.strategies.dp.DataParallel")
def test_data_parallel_setup_module(data_parallel_mock):
    strategy = DataParallelStrategy()
    strategy.parallel_devices = [0, 2, 1]
    module = torch.nn.Linear(2, 2)
    wrapped_module = strategy.setup_module(module)
    # The module gets wrapped in DataParallel with the configured device ids
    assert wrapped_module == data_parallel_mock(module=module, device_ids=[0, 2, 1])


def test_data_parallel_module_to_device():
    strategy = DataParallelStrategy()
    strategy.parallel_devices = [torch.device("cuda", 2)]
    module = Mock()
    strategy.module_to_device(module)
    module.to.assert_called_with(torch.device("cuda", 2))


def test_dp_module_state_dict():
    """Test that the module state dict gets retrieved without the prefixed wrapper keys from DP."""

    class DataParallelMock(MagicMock):
        def __instancecheck__(self, instance):
            # to make the strategy's `isinstance(model, DataParallel)` pass with a mock as class
            return True

    strategy = DataParallelStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])

    # Without DP applied (no setup call)
    original_module = torch.nn.Linear(2, 3)
    assert strategy.get_module_state_dict(original_module).keys() == original_module.state_dict().keys()

    # With DP applied (setup called)
    with mock.patch("lightning.fabric.strategies.dp.DataParallel", DataParallelMock):
        wrapped_module = strategy.setup_module(original_module)
        assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()


@pytest.mark.parametrize(
    "precision",
    [
        "32-true",
        "16-mixed",
        pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
    ],
)
@pytest.mark.parametrize("clip_type", ["norm", "val"])
@RunIf(min_cuda_gpus=2)
def test_clip_gradients(clip_type, precision):
    if clip_type == "norm" and precision == "16-mixed":
        pytest.skip(reason="Clipping by norm with 16-mixed is numerically unstable.")

    fabric = Fabric(accelerator="cuda", devices=2, precision=precision, strategy="dp")
    fabric.launch()
    _run_test_clip_gradients(fabric=fabric, clip_type=clip_type)