# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import ANY, Mock

import pytest
import torch
from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher

from tests_fabric.helpers.runif import RunIf


@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
def test_interactive_compatible(start_method):
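    """Test that the launcher reports interactive compatibility only for the 'fork' start method."""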
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    assert launcher.is_interactive_compatible == (start_method == "fork")


@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_forking_on_unsupported_platform(_):
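    """Test that requesting the 'fork' start method raises an error on platforms that do not support it."""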
    with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
        _MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing._check_missing_main_guard")
def test_start_method(_, mp_mock, start_method):
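    """Test that processes are started with the configured start method."""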
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    launcher.launch(function=Mock())
    mp_mock.get_context.assert_called_with(start_method)
    mp_mock.start_processes.assert_called_with(
        ANY,
        args=ANY,
        nprocs=ANY,
        start_method=start_method,
    )


@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing._check_missing_main_guard")
def test_restore_globals(_, mp_mock, start_method):
    """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    launcher.launch(function=Mock())
    function_args = mp_mock.start_processes.call_args[1]["args"]
    if start_method == "spawn":
        assert len(function_args) == 5
        assert isinstance(function_args[4], _GlobalStateSnapshot)
    else:
        assert len(function_args) == 4


@pytest.mark.usefixtures("reset_deterministic_algorithm")
def test_global_state_snapshot():
    """Test the capture() and restore() methods for the global state snapshot."""
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(123)

    # capture the state of globals
    snapshot = _GlobalStateSnapshot.capture()

    # simulate there is a process boundary and flags get reset here
    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(321)

    # restore the state of globals
    snapshot.restore()
    assert torch.are_deterministic_algorithms_enabled()
    assert not torch.backends.cudnn.benchmark
    assert torch.initial_seed() == 123


@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch("torch.cuda.is_initialized", return_value=True)
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
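    """Test that launching raises an error if CUDA is already initialized before new processes are created."""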
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
        launcher.launch(function=Mock())


def test_check_for_missing_main_guard():
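    """Test that launching with 'spawn' raises an error when the script does not guard the main entry point."""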
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
    with mock.patch(
        "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process",
        return_value=Mock(_inheriting=True),  # pretend that main is importing itself
    ), pytest.raises(RuntimeError, match="requires that your script guards the main"):
        launcher.launch(function=Mock())