pytorch-lightning

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest
import torch
from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from torch.nn.parallel import DistributedDataParallel

from tests_fabric.helpers.runif import RunIf


@pytest.mark.parametrize(
    ("process_group_backend", "device_str", "expected_process_group_backend"),
    [
        pytest.param("foo", "cpu", "foo"),
        pytest.param("foo", "cuda:0", "foo"),
        pytest.param(None, "cuda:0", "nccl"),
        pytest.param(None, "cpu", "gloo"),
    ],
)
def test_ddp_process_group_backend(process_group_backend, device_str, expected_process_group_backend):
    """Test settings for process group backend."""

    class MockDDPStrategy(DDPStrategy):
        def __init__(self, root_device, process_group_backend):
            self._root_device = root_device
            super().__init__(process_group_backend=process_group_backend)

        @property
        def root_device(self):
            return self._root_device

    strategy = MockDDPStrategy(process_group_backend=process_group_backend, root_device=torch.device(device_str))
    assert strategy._get_process_group_backend() == expected_process_group_backend


def test_ddp_no_backward_sync():
    """Test that the backward sync control calls `.no_sync()`, and only on a DDP-wrapped module."""
    strategy = DDPStrategy()
    assert isinstance(strategy._backward_sync_control, _DDPBackwardSyncControl)

    with pytest.raises(
        TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
    ), strategy._backward_sync_control.no_backward_sync(Mock()):
        pass

    module = MagicMock(spec=DistributedDataParallel)
    with strategy._backward_sync_control.no_backward_sync(module):
        pass

    module.no_sync.assert_called_once()


@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
def test_ddp_extra_kwargs(ddp_mock):
    """Test that additional kwargs passed to the DDPStrategy get passed down to the DistributedDataParallel wrapper."""
    module = torch.nn.Linear(1, 1)
    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])
    strategy.setup_module(module)
    ddp_mock.assert_called_with(module=module, device_ids=None)

    ddp_mock.reset_mock()

    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")], find_unused_parameters=True)
    strategy.setup_module(module)
    ddp_mock.assert_called_with(module=module, device_ids=None, find_unused_parameters=True)


def test_ddp_module_state_dict():
    """Test that the module state dict can be retrieved and loaded without the prefixed wrapper keys from DDP."""

    class DistributedDataParallelMock(MagicMock):
        def __instancecheck__(self, instance):
            # to make the strategy's `isinstance(model, DistributedDataParallel)` pass with a mock as class
            return True

    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])

    # Without DDP applied (no setup call)
    original_module = torch.nn.Linear(2, 3)
    original_state_dict = deepcopy(original_module.state_dict())
    retrieved_state_dict = strategy.get_module_state_dict(original_module)
    assert retrieved_state_dict.keys() == original_state_dict.keys()
    strategy.load_module_state_dict(original_module, retrieved_state_dict)

    # With DDP applied (setup called)
    with mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel", DistributedDataParallelMock):
        wrapped_module = strategy.setup_module(original_module)
        retrieved_state_dict = strategy.get_module_state_dict(wrapped_module)
    assert retrieved_state_dict.keys() == original_state_dict.keys()
    strategy.load_module_state_dict(wrapped_module, retrieved_state_dict)
    strategy.load_module_state_dict(wrapped_module, original_state_dict)


@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        (Precision(), torch.float32),
        (HalfPrecision("16-true"), torch.float16),
        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)),
        (DoublePrecision(), torch.float64),
    ],
)
@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
def test_module_init_context(precision, expected_dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""
    parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
    expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
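    # With torch >= 2.0 the init-context can create the parameters directly on the local CUDA device (rank 1 here);
    # on older torch versions they are expected to be created on the CPU first.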

    strategy = DDPStrategy(
        parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
    )
    assert strategy.local_rank == 1
    with strategy.module_init_context():
        module = torch.nn.Linear(2, 2)
    assert module.weight.device == module.bias.device == expected_device
    assert module.weight.dtype == module.bias.dtype == expected_dtype


@mock.patch.dict(os.environ, {"LOCAL_RANK": "0"})
@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
@mock.patch("torch.cuda.Stream")
@mock.patch("torch.cuda.stream")
def test_setup_with_cuda_stream(cuda_stream_mock, *_):
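    """Test that a CUDA stream is only set up when the root device is a CUDA device."""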
    model = torch.nn.Linear(2, 2)
    strategy = DDPStrategy(parallel_devices=[torch.device("cpu")], cluster_environment=LightningEnvironment())
    strategy.setup_module(model)
    cuda_stream_mock.assert_not_called()

    strategy = DDPStrategy(parallel_devices=[torch.device("cuda", 0)], cluster_environment=LightningEnvironment())
    strategy.setup_module(model)
    cuda_stream_mock.assert_called_once()


@mock.patch("torch.distributed.init_process_group")
def test_set_timeout(init_process_group_mock):
    """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
    test_timedelta = timedelta(seconds=30)
    strategy = DDPStrategy(timeout=test_timedelta, parallel_devices=[torch.device("cpu")])
    strategy.cluster_environment = LightningEnvironment()
    strategy.accelerator = Mock()
    strategy.setup_environment()
    process_group_backend = strategy._get_process_group_backend()
    global_rank = strategy.cluster_environment.global_rank()
    world_size = strategy.cluster_environment.world_size()
    init_process_group_mock.assert_called_with(
        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
    )