pytorch-lightning
183 строки · 8.2 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import os
15import signal
16import sys
17from unittest import mock
18from unittest.mock import ANY, Mock
19
20import lightning.fabric
21import pytest
22from lightning.fabric.strategies.launchers.subprocess_script import (
23_HYDRA_AVAILABLE,
24_ChildProcessObserver,
25_SubprocessScriptLauncher,
26)
27
28
def test_subprocess_script_launcher_interactive_compatible():
    """The subprocess-script launcher must report that it is not interactive-compatible."""
    cluster_environment = Mock()
    subprocess_launcher = _SubprocessScriptLauncher(cluster_environment, num_processes=2, num_nodes=1)
    assert not subprocess_launcher.is_interactive_compatible
32
33
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_subprocess_script_launcher_can_launch(*_):
    """Test that the launcher refuses to spawn children from a non-zero local rank or more than once.

    Each scenario enables exactly one failing precondition, so the expected error does not depend on the order in
    which the launcher performs its checks.
    """
    cluster_env = Mock()
    cluster_env.creates_processes_externally = False
    cluster_env.local_rank.return_value = 1
    launcher = _SubprocessScriptLauncher(cluster_env, num_processes=2, num_nodes=1)

    with pytest.raises(RuntimeError, match="attempted to launch new distributed processes with `local_rank > 0`"):
        launcher.launch(Mock())

    # use a valid local rank so that only the "already launched" condition can trigger the error below
    cluster_env.local_rank.return_value = 0
    launcher.procs = [Mock()]  # there are already processes running
    with pytest.raises(RuntimeError, match="The launcher can only create subprocesses once"):
        launcher.launch(Mock())
48
49
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_subprocess_script_launcher_external_processes(_, popen_mock):
    """When the cluster environment creates processes externally, the launcher must not spawn any subprocess."""
    cluster_env = Mock(creates_processes_externally=True)
    wrapped_function = Mock()
    launcher = _SubprocessScriptLauncher(cluster_env, num_processes=4, num_nodes=2)

    launcher.launch(wrapped_function, "positional-arg", keyword_arg=0)

    # the function is invoked with the forwarded arguments, and `Popen` is never used
    wrapped_function.assert_called_with("positional-arg", keyword_arg=0)
    popen_mock.assert_not_called()
60
61
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_subprocess_script_launcher_launch_processes(_, popen_mock):
    """Test that the launcher spawns ``num_processes - 1`` children with the expected distributed env variables.

    The current process participates as local rank 0. ``os.environ`` is snapshotted and restored around the launch
    so that the variables written by the launcher do not leak into other tests.
    """
    cluster_env = Mock()
    cluster_env.creates_processes_externally = False
    cluster_env.local_rank.return_value = 0
    cluster_env.main_address = "address"
    cluster_env.main_port = 1234

    function = Mock()
    launcher = _SubprocessScriptLauncher(cluster_env, num_processes=4, num_nodes=2)
    num_new_processes = launcher.num_processes - 1
    expected_world_size = launcher.num_processes * launcher.num_nodes

    with mock.patch.dict(os.environ):
        # launches n-1 new processes, the current one will participate too
        launcher.launch(function, "positional-arg", keyword_arg=0)
        function.assert_called_once_with("positional-arg", keyword_arg=0)

        # each `Popen` call is recorded as an (args, kwargs) pair; the child environment is passed as `env=`
        child_envs = [kwargs["env"] for _args, kwargs in popen_mock.call_args_list]
        assert len(child_envs) == num_new_processes

        # world size in child processes
        assert [int(env["WORLD_SIZE"]) for env in child_envs] == [expected_world_size] * num_new_processes

        # local rank in child processes
        assert [int(env["LOCAL_RANK"]) for env in child_envs] == list(range(1, num_new_processes + 1))

        # rendezvous address taken from the cluster environment
        assert all(env["MASTER_ADDR"] == "address" for env in child_envs)
        assert all(int(env["MASTER_PORT"]) == 1234 for env in child_envs)

        # the current process
        assert int(os.environ["WORLD_SIZE"]) == expected_world_size
        assert int(os.environ["LOCAL_RANK"]) == 0
92
93
@pytest.mark.skipif(not _HYDRA_AVAILABLE, reason="hydra-core is required")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_subprocess_script_launcher_hydra_in_use(_, popen_mock, monkeypatch):
    """Test that the child command comes from the Hydra helper only when Hydra is available and initialized.

    Covered scenarios: Hydra not available, Hydra available but not initialized, Hydra available and initialized.
    Only the last one should use ``_hydra_subprocess_cmd`` and its working directory.
    """
    import hydra

    launcher_module = lightning.fabric.strategies.launchers.subprocess_script
    basic_command = Mock(return_value="basic_command")
    hydra_command = Mock(return_value=("hydra_command", "hydra_cwd"))
    monkeypatch.setattr(launcher_module, "_basic_subprocess_cmd", basic_command)
    monkeypatch.setattr(launcher_module, "_hydra_subprocess_cmd", hydra_command)

    def simulate_launch():
        cluster_env = Mock()
        cluster_env.creates_processes_externally = False
        cluster_env.local_rank.return_value = 0
        cluster_env.main_address = "address"
        cluster_env.main_port = 1234
        function = Mock()
        launcher = _SubprocessScriptLauncher(cluster_env, num_processes=4, num_nodes=2)
        # restore `os.environ` afterwards so the launcher's env variables do not leak into other tests
        with mock.patch.dict(os.environ):
            launcher.launch(function)

    def set_hydra_initialized(initialized):
        # replace Hydra's global config object with a mock reporting the given initialization state
        hydra_config_mock = Mock()
        hydra_config_mock.initialized.return_value = initialized
        monkeypatch.setattr(hydra.core.hydra_config, "HydraConfig", hydra_config_mock)

    # when hydra not available
    monkeypatch.setattr(launcher_module, "_HYDRA_AVAILABLE", False)
    simulate_launch()
    popen_mock.assert_called_with("basic_command", env=ANY, cwd=None)
    popen_mock.reset_mock()

    # when hydra available but not initialized
    monkeypatch.setattr(launcher_module, "_HYDRA_AVAILABLE", True)
    set_hydra_initialized(False)
    simulate_launch()
    popen_mock.assert_called_with("basic_command", env=ANY, cwd=None)
    popen_mock.reset_mock()
    # neither of the scenarios above may build the command through the Hydra helper
    hydra_command.assert_not_called()

    # when hydra available and initialized
    set_hydra_initialized(True)
    simulate_launch()
    popen_mock.assert_called_with("hydra_command", env=ANY, cwd="hydra_cwd")
138
139
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.os.kill")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.time.sleep")
def test_child_process_observer(sleep_mock, os_kill_mock):
    """Test that the child process observer terminates processes only after a child exits with a failure code.

    ``_run()`` is called directly to exercise a single iteration of the observer loop. ``run()`` is then tested
    separately with ``_run()`` mocked out.
    """
    # Case 1: All processes are running and did not exit yet
    processes = [Mock(returncode=None), Mock(returncode=None)]
    observer = _ChildProcessObserver(main_pid=1234, child_processes=processes)
    finished = observer._run()  # call _run() directly to simulate while loop
    assert not finished
    # nothing failed, so no process may be signaled or killed
    for process in processes:
        process.send_signal.assert_not_called()
    os_kill_mock.assert_not_called()

    # Case 2: All processes have finished with exit code 0 (success)
    processes = [Mock(returncode=0), Mock(returncode=0)]
    observer = _ChildProcessObserver(main_pid=1234, child_processes=processes)
    finished = observer._run()  # call _run() directly to simulate while loop
    assert finished
    # a successful exit must not trigger termination either
    for process in processes:
        process.send_signal.assert_not_called()
    os_kill_mock.assert_not_called()

    # Case 3: One process has finished with exit code 1 (failure)
    processes = [Mock(returncode=0), Mock(returncode=1)]
    observer = _ChildProcessObserver(main_pid=1234, child_processes=processes)
    finished = observer._run()  # call _run() directly to simulate while loop
    assert finished
    # SIGTERM on Windows, SIGKILL elsewhere
    expected_signal = signal.SIGTERM if sys.platform == "win32" else signal.SIGKILL
    processes[0].send_signal.assert_called_once_with(expected_signal)
    processes[1].send_signal.assert_called_once_with(expected_signal)
    os_kill_mock.assert_called_once_with(1234, expected_signal)

    # The main routine stops as soon as one iteration of `_run()` reports that it is finished
    observer = _ChildProcessObserver(main_pid=1234, child_processes=[Mock(), Mock()])
    observer._run = Mock(return_value=True)
    assert not observer._finished
    observer.run()
    assert observer._finished
    observer._run.assert_called_once()
    sleep_mock.assert_called_once_with(5)
172
173
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script._ChildProcessObserver")
def test_validate_cluster_environment_user_settings(*_):
    """Test that the launcher calls into the cluster environment to validate the user settings."""
    failing_validation = Mock(side_effect=RuntimeError("test"))
    cluster_env = Mock(validate_settings=failing_validation, creates_processes_externally=True)
    launcher = _SubprocessScriptLauncher(cluster_env, num_processes=2, num_nodes=1)

    # the error raised by `validate_settings` is expected to propagate out of `launch`
    with pytest.raises(RuntimeError, match="test"):
        launcher.launch(Mock())
184