pytorch-lightning

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import DataParallelStrategy

from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _run_test_clip_gradients


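# The root device is expected to be the first entry in `parallel_devices`: that is the
# device torch.nn.DataParallel scatters inputs from and gathers outputs back to.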
def test_data_parallel_root_device():
    strategy = DataParallelStrategy()
    strategy.parallel_devices = [torch.device("cuda", 2), torch.device("cuda", 0), torch.device("cuda", 1)]
    assert strategy.root_device == torch.device("cuda", 2)


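# DP drives all devices from a single process, so the distributed rank attributes
# collapse to their trivial single-process values.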
def test_data_parallel_ranks():
    strategy = DataParallelStrategy()
    assert strategy.world_size == 1
    assert strategy.local_rank == 0
    assert strategy.global_rank == 0
    assert strategy.is_global_zero


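# Patching the DataParallel class referenced by the strategy lets the test check that
# `setup_module` builds the wrapper with the strategy's device ids, without real GPUs.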
@mock.patch("lightning.fabric.strategies.dp.DataParallel")
def test_data_parallel_setup_module(data_parallel_mock):
    strategy = DataParallelStrategy()
    strategy.parallel_devices = [0, 2, 1]
    module = torch.nn.Linear(2, 2)
    wrapped_module = strategy.setup_module(module)
    assert wrapped_module == data_parallel_mock(module=module, device_ids=[0, 2, 1])


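# Only the root device needs an explicit move here; torch.nn.DataParallel replicates
# the module onto the remaining devices on each forward pass.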
def test_data_parallel_module_to_device():
    strategy = DataParallelStrategy()
    strategy.parallel_devices = [torch.device("cuda", 2)]
    module = Mock()
    strategy.module_to_device(module)
    module.to.assert_called_with(torch.device("cuda", 2))


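# torch.nn.DataParallel keeps the original model under a `module` attribute, so a plain
# `state_dict()` on the wrapper would prefix every key with "module."; the strategy is
# expected to return keys matching the unwrapped module.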
def test_dp_module_state_dict():
    """Test that the module state dict gets retrieved without the prefixed wrapper keys from DP."""

    class DataParallelMock(MagicMock):
        def __instancecheck__(self, instance):
            # to make the strategy's `isinstance(model, DataParallel)` check pass with a mock as the class
            return True

    strategy = DataParallelStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])

    # Without DP applied (no setup call)
    original_module = torch.nn.Linear(2, 3)
    assert strategy.get_module_state_dict(original_module).keys() == original_module.state_dict().keys()

    # With DP applied (setup called)
    with mock.patch("lightning.fabric.strategies.dp.DataParallel", DataParallelMock):
        wrapped_module = strategy.setup_module(original_module)
        assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()


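# The clipping assertions themselves live in the shared single-device helper; this test
# only wires it up with a 2-GPU DP Fabric for each precision/clip-type combination.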
@pytest.mark.parametrize(
    "precision",
    [
        "32-true",
        "16-mixed",
        pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
    ],
)
@pytest.mark.parametrize("clip_type", ["norm", "val"])
@RunIf(min_cuda_gpus=2)
def test_clip_gradients(clip_type, precision):
    if clip_type == "norm" and precision == "16-mixed":
        pytest.skip(reason="Clipping by norm with 16-mixed is numerically unstable.")

    fabric = Fabric(accelerator="cuda", devices=2, precision=precision, strategy="dp")
    fabric.launch()
    _run_test_clip_gradients(fabric=fabric, clip_type=clip_type)