# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest
import torch
from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from torch.nn.parallel import DistributedDataParallel

from tests_fabric.helpers.runif import RunIf


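# For context on the parametrization below: when no ``process_group_backend`` is given,
# the strategy is expected to derive one from the root device, conventionally "nccl" for
# CUDA devices and "gloo" for CPU; an explicit backend ("foo" here) should win either way.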
@pytest.mark.parametrize(
    ("process_group_backend", "device_str", "expected_process_group_backend"),
    [
        pytest.param("foo", "cpu", "foo"),
        pytest.param("foo", "cuda:0", "foo"),
        pytest.param(None, "cuda:0", "nccl"),
        pytest.param(None, "cpu", "gloo"),
    ],
)
def test_ddp_process_group_backend(process_group_backend, device_str, expected_process_group_backend):
    """Test settings for process group backend."""

    class MockDDPStrategy(DDPStrategy):
        def __init__(self, root_device, process_group_backend):
            self._root_device = root_device
            super().__init__(process_group_backend=process_group_backend)

        @property
        def root_device(self):
            return self._root_device

    strategy = MockDDPStrategy(process_group_backend=process_group_backend, root_device=torch.device(device_str))
    assert strategy._get_process_group_backend() == expected_process_group_backend


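# For context: ``DistributedDataParallel.no_sync()`` is the PyTorch context manager that
# skips the gradient all-reduce for backward passes run inside it, so gradients accumulate
# locally. A minimal usage sketch, assuming hypothetical ``strategy``/``model``/``loss``
# objects like those in the test below:
#
#     with strategy._backward_sync_control.no_backward_sync(model):
#         loss.backward()  # no inter-process gradient synchronization on this step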
def test_ddp_no_backward_sync():
    """Test that the backward sync control calls `.no_sync()`, and only on a DDP-wrapped module."""
    strategy = DDPStrategy()
    assert isinstance(strategy._backward_sync_control, _DDPBackwardSyncControl)

    with pytest.raises(
        TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
    ), strategy._backward_sync_control.no_backward_sync(Mock()):
        pass

    module = MagicMock(spec=DistributedDataParallel)
    with strategy._backward_sync_control.no_backward_sync(module):
        pass

    module.no_sync.assert_called_once()


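# For context: keyword arguments not consumed by ``DDPStrategy`` itself are forwarded
# verbatim to the ``DistributedDataParallel`` constructor, e.g. (sketch):
#
#     DDPStrategy(find_unused_parameters=True)  # becomes DDP(..., find_unused_parameters=True)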
@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
def test_ddp_extra_kwargs(ddp_mock):
    """Test that additional kwargs passed to the DDPStrategy get passed down to the DistributedDataParallel wrapper."""
    module = torch.nn.Linear(1, 1)
    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])
    strategy.setup_module(module)
    ddp_mock.assert_called_with(module=module, device_ids=None)

    ddp_mock.reset_mock()

    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")], find_unused_parameters=True)
    strategy.setup_module(module)
    ddp_mock.assert_called_with(module=module, device_ids=None, find_unused_parameters=True)


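# For context: wrapping a module in DDP prefixes every state-dict key with "module."
# (e.g. "weight" becomes "module.weight"). The strategy is expected to strip that prefix
# so checkpoints remain interchangeable between wrapped and unwrapped modules.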
def test_ddp_module_state_dict():
    """Test that the module state dict can be retrieved and loaded without the prefixed wrapper keys from DDP."""

    class DistributedDataParallelMock(MagicMock):
        def __instancecheck__(self, instance):
            # to make the strategy's `isinstance(model, DistributedDataParallel)` pass with a mock as class
            return True

    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])

    # Without DDP applied (no setup call)
    original_module = torch.nn.Linear(2, 3)
    original_state_dict = deepcopy(original_module.state_dict())
    retrieved_state_dict = strategy.get_module_state_dict(original_module)
    assert retrieved_state_dict.keys() == original_state_dict.keys()
    strategy.load_module_state_dict(original_module, retrieved_state_dict)

    # With DDP applied (setup called)
    with mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel", DistributedDataParallelMock):
        wrapped_module = strategy.setup_module(original_module)
        retrieved_state_dict = strategy.get_module_state_dict(wrapped_module)
        assert retrieved_state_dict.keys() == original_state_dict.keys()
        strategy.load_module_state_dict(wrapped_module, retrieved_state_dict)
        strategy.load_module_state_dict(wrapped_module, original_state_dict)


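# For context: ``module_init_context()`` is expected to stack the precision plugin's dtype
# context with a device context, so parameters materialize directly on the local rank's
# device. Entering a device that way relies on ``torch.device`` being usable as a context
# manager, a torch 2.0 feature, which is why the CPU fallback below is expected for older torch.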
@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        (Precision(), torch.float32),
        (HalfPrecision("16-true"), torch.float16),
        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)),
        (DoublePrecision(), torch.float64),
    ],
)
@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
def test_module_init_context(precision, expected_dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""
    parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
    expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")

    strategy = DDPStrategy(
        parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
    )
    assert strategy.local_rank == 1
    with strategy.module_init_context():
        module = torch.nn.Linear(2, 2)
    assert module.weight.device == module.bias.device == expected_device
    assert module.weight.dtype == module.bias.dtype == expected_dtype


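# For context: when the root device is CUDA, ``setup_module`` constructs the DDP wrapper
# under a CUDA side stream via ``torch.cuda.stream(torch.cuda.Stream())``, presumably to keep
# DDP's initial bucket allocation and broadcast off the default stream; on CPU no stream
# should be touched. Both ``torch.cuda.Stream`` and ``torch.cuda.stream`` are mocked here so
# the test runs without a GPU.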
@mock.patch.dict(os.environ, {"LOCAL_RANK": "0"})
@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
@mock.patch("torch.cuda.Stream")
@mock.patch("torch.cuda.stream")
def test_setup_with_cuda_stream(cuda_stream_mock, *_):
    """Test that the DDP wrapper gets created under a CUDA stream context only when the root device is CUDA."""
    model = torch.nn.Linear(2, 2)
    strategy = DDPStrategy(parallel_devices=[torch.device("cpu")], cluster_environment=LightningEnvironment())
    strategy.setup_module(model)
    cuda_stream_mock.assert_not_called()

    strategy = DDPStrategy(parallel_devices=[torch.device("cuda", 0)], cluster_environment=LightningEnvironment())
    strategy.setup_module(model)
    cuda_stream_mock.assert_called_once()


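# For context: ``timeout`` bounds how long the rendezvous and collective operations may block
# before raising. A minimal sketch of the configuration asserted below:
#
#     DDPStrategy(timeout=timedelta(seconds=30))  # forwarded to init_process_group(..., timeout=...)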
@mock.patch("torch.distributed.init_process_group")
def test_set_timeout(init_process_group_mock):
    """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
    test_timedelta = timedelta(seconds=30)
    strategy = DDPStrategy(timeout=test_timedelta, parallel_devices=[torch.device("cpu")])
    strategy.cluster_environment = LightningEnvironment()
    strategy.accelerator = Mock()
    strategy.setup_environment()
    process_group_backend = strategy._get_process_group_backend()
    global_rank = strategy.cluster_environment.global_rank()
    world_size = strategy.cluster_environment.world_size()
    init_process_group_mock.assert_called_with(
        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
    )