pytorch-lightning

Форк
0
183 строки · 8.2 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import os
15
import signal
16
import sys
17
from unittest import mock
18
from unittest.mock import ANY, Mock
19

20
import lightning.fabric
21
import pytest
22
from lightning.fabric.strategies.launchers.subprocess_script import (
23
    _HYDRA_AVAILABLE,
24
    _ChildProcessObserver,
25
    _SubprocessScriptLauncher,
26
)
27

28

29
def test_subprocess_script_launcher_interactive_compatible():
    """The launcher's ``is_interactive_compatible`` attribute must be falsy."""
    script_launcher = _SubprocessScriptLauncher(Mock(), num_processes=2, num_nodes=1)
    interactive_flag = script_launcher.is_interactive_compatible
    assert not interactive_flag
32

33

34
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_subprocess_script_launcher_can_launch(*_):
    """``launch`` must raise for a non-zero local rank and when ``procs`` is already populated."""
    environment = Mock(creates_processes_externally=False)
    environment.local_rank.return_value = 1
    script_launcher = _SubprocessScriptLauncher(environment, num_processes=2, num_nodes=1)

    rank_error = "attempted to launch new distributed processes with `local_rank > 0`"
    with pytest.raises(RuntimeError, match=rank_error):
        script_launcher.launch(Mock())

    # Pretend that this launcher already holds running processes.
    script_launcher.procs = [Mock()]
    relaunch_error = "The launcher can only create subprocesses once"
    with pytest.raises(RuntimeError, match=relaunch_error):
        script_launcher.launch(Mock())
48

49

50
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_subprocess_script_launcher_external_processes(_, popen_mock):
    """When the environment creates processes externally, ``launch`` calls the function and never uses ``Popen``."""
    environment = Mock(creates_processes_externally=True)
    target = Mock()

    script_launcher = _SubprocessScriptLauncher(environment, num_processes=4, num_nodes=2)
    script_launcher.launch(target, "positional-arg", keyword_arg=0)

    # The arguments given to ``launch`` are forwarded to the function unchanged.
    target.assert_called_with("positional-arg", keyword_arg=0)
    popen_mock.assert_not_called()
60

61

62
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_subprocess_script_launcher_launch_processes(_, popen_mock):
    """Check the ``Popen`` calls and the WORLD_SIZE / LOCAL_RANK variables after launching from local rank 0."""
    environment = Mock(creates_processes_externally=False, main_address="address", main_port=1234)
    environment.local_rank.return_value = 0

    script_launcher = _SubprocessScriptLauncher(environment, num_processes=4, num_nodes=2)
    expected_children = script_launcher.num_processes - 1

    # Only n-1 processes get spawned; the current process takes part as well.
    script_launcher.launch(Mock(), "positional-arg", keyword_arg=0)

    popen_calls = popen_mock.call_args_list
    assert len(popen_calls) == expected_children

    # Index 1 of a recorded call holds its keyword arguments.
    child_envs = [popen_call[1]["env"] for popen_call in popen_calls]
    expected_world_size = script_launcher.num_processes * script_launcher.num_nodes

    # world size in child processes
    child_world_sizes = [int(env["WORLD_SIZE"]) for env in child_envs]
    assert child_world_sizes == [expected_world_size] * expected_children

    # local rank in child processes
    child_local_ranks = [int(env["LOCAL_RANK"]) for env in child_envs]
    assert child_local_ranks == list(range(1, expected_children + 1))

    # the current process
    assert int(os.environ["WORLD_SIZE"]) == expected_world_size
    assert int(os.environ["LOCAL_RANK"]) == 0
92

93

94
@pytest.mark.skipif(not _HYDRA_AVAILABLE, reason="hydra-core is required")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_subprocess_script_launcher_hydra_in_use(_, popen_mock, monkeypatch):
    """Check which patched command builder feeds ``Popen`` for each Hydra availability / initialization state."""
    launcher_module = lightning.fabric.strategies.launchers.subprocess_script
    monkeypatch.setattr(launcher_module, "_basic_subprocess_cmd", Mock(return_value="basic_command"))
    monkeypatch.setattr(launcher_module, "_hydra_subprocess_cmd", Mock(return_value=("hydra_command", "hydra_cwd")))

    def simulate_launch():
        # Fresh environment and launcher per scenario, launching from local rank 0.
        environment = Mock(creates_processes_externally=False, main_address="address", main_port=1234)
        environment.local_rank.return_value = 0
        script_launcher = _SubprocessScriptLauncher(environment, num_processes=4, num_nodes=2)
        script_launcher.launch(Mock())

    # when hydra not available
    monkeypatch.setattr(launcher_module, "_HYDRA_AVAILABLE", False)
    simulate_launch()
    popen_mock.assert_called_with("basic_command", env=ANY, cwd=None)
    popen_mock.reset_mock()

    import hydra

    def mark_hydra_available(initialized):
        # Flag Hydra as available and replace HydraConfig with a mock reporting the given state.
        monkeypatch.setattr(launcher_module, "_HYDRA_AVAILABLE", True)
        config_mock = Mock()
        config_mock.initialized.return_value = initialized
        monkeypatch.setattr(hydra.core.hydra_config, "HydraConfig", config_mock)

    # when hydra available but not initialized
    mark_hydra_available(initialized=False)
    simulate_launch()
    popen_mock.assert_called_with("basic_command", env=ANY, cwd=None)
    popen_mock.reset_mock()

    # when hydra available and initialized
    mark_hydra_available(initialized=True)
    simulate_launch()
    popen_mock.assert_called_with("hydra_command", env=ANY, cwd="hydra_cwd")
    popen_mock.reset_mock()
138

139

140
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.os.kill")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.time.sleep")
def test_child_process_observer(sleep_mock, os_kill_mock):
    """Drive ``_ChildProcessObserver`` through running, succeeded and failed children, then through ``run``."""
    # Case 1: All processes are running and did not exit yet
    children = [Mock(returncode=None), Mock(returncode=None)]
    observer = _ChildProcessObserver(main_pid=1234, child_processes=children)
    # _run() is invoked directly, standing in for one pass of the observer's loop
    assert not observer._run()

    # Case 2: All processes have finished with exit code 0 (success)
    children = [Mock(returncode=0), Mock(returncode=0)]
    observer = _ChildProcessObserver(main_pid=1234, child_processes=children)
    assert observer._run()

    # Case 3: One process has finished with exit code 1 (failure)
    succeeded = Mock(returncode=0)
    failed = Mock(returncode=1)
    observer = _ChildProcessObserver(main_pid=1234, child_processes=[succeeded, failed])
    assert observer._run()

    # SIGKILL is only looked up when the platform is not Windows.
    if sys.platform == "win32":
        termination_signal = signal.SIGTERM
    else:
        termination_signal = signal.SIGKILL
    for child in (succeeded, failed):
        child.send_signal.assert_called_once_with(termination_signal)
    os_kill_mock.assert_called_once_with(1234, termination_signal)

    # The main routine stops
    observer = _ChildProcessObserver(main_pid=1234, child_processes=[Mock(), Mock()])
    observer._run = Mock()
    assert not observer._finished
    observer.run()
    assert observer._finished
    sleep_mock.assert_called_once_with(5)
172

173

174
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_validate_cluster_environment_user_settings(*_):
    """An error raised by the cluster environment's ``validate_settings`` must surface from ``launch``."""
    failing_validation = Mock(side_effect=RuntimeError("test"))
    environment = Mock(validate_settings=failing_validation, creates_processes_externally=True)
    script_launcher = _SubprocessScriptLauncher(environment, num_processes=2, num_nodes=1)

    with pytest.raises(RuntimeError, match="test"):
        script_launcher.launch(Mock())
184

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.