pytorch-lightning
107 строк · 4.7 Кб
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14from unittest import mock
15from unittest.mock import ANY, Mock
16
17import pytest
18import torch
19from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
20
21from tests_fabric.helpers.runif import RunIf
22
23
@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
def test_interactive_compatible(start_method):
    """Only the 'fork' start method should report itself as interactive-compatible."""
    expected_compatible = start_method == "fork"
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    assert launcher.is_interactive_compatible == expected_compatible
29
30
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_forking_on_unsupported_platform(_get_all_start_methods_mock):
    """Requesting 'fork' must raise a ValueError when the platform lists no available start methods."""
    expected_message = "The start method 'fork' is not available on this platform"
    with pytest.raises(ValueError, match=expected_message):
        _MultiProcessingLauncher(strategy=Mock(), start_method="fork")
35
36
@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing._check_missing_main_guard")
def test_start_method(_main_guard_check_mock, mp_mock, start_method):
    """The launcher must build its multiprocessing context and start workers with the configured start method.

    Stacked ``mock.patch`` decorators inject their mocks bottom-up, so the main-guard check mock arrives first and
    the patched ``mp`` module second.
    """
    mp_mock.get_all_start_methods.return_value = [start_method]

    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    launcher.launch(function=Mock())

    mp_mock.get_context.assert_called_with(start_method)
    expected_kwargs = {"args": ANY, "nprocs": ANY, "start_method": start_method}
    mp_mock.start_processes.assert_called_with(ANY, **expected_kwargs)
51
52
@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing._check_missing_main_guard")
def test_restore_globals(_main_guard_check_mock, mp_mock, start_method):
    """A global state snapshot is passed to the worker function only when the 'spawn' start method is used.

    With 'spawn' the worker receives five positional arguments, the last being a ``_GlobalStateSnapshot``;
    otherwise it receives four.
    """
    mp_mock.get_all_start_methods.return_value = [start_method]
    _MultiProcessingLauncher(strategy=Mock(), start_method=start_method).launch(function=Mock())

    worker_args = mp_mock.start_processes.call_args.kwargs["args"]
    is_spawn = start_method == "spawn"
    assert len(worker_args) == (5 if is_spawn else 4)
    if is_spawn:
        # The length check above guarantees that index -1 is index 4 here.
        assert isinstance(worker_args[-1], _GlobalStateSnapshot)
67
68
@pytest.mark.usefixtures("reset_deterministic_algorithm")
def test_global_state_snapshot():
    """``_GlobalStateSnapshot.capture()`` records torch globals and ``restore()`` puts them back.

    Covered globals: the deterministic-algorithms flag, ``cudnn.benchmark`` and the initial random seed.
    """
    # Establish the state the snapshot should record.
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(123)
    snapshot = _GlobalStateSnapshot.capture()

    # Mimic crossing a process boundary, where the globals hold different values.
    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(321)

    snapshot.restore()

    assert torch.are_deterministic_algorithms_enabled()
    assert not torch.backends.cudnn.benchmark
    assert torch.initial_seed() == 123
89
90
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch("torch.cuda.is_initialized", return_value=True)
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_check_for_bad_cuda_fork(mp_mock, _cuda_is_initialized_mock, start_method):
    """Launching with 'fork' or 'forkserver' must raise once CUDA reports itself as initialized.

    The bottom-most patch (``mp``) is injected first, followed by the ``torch.cuda.is_initialized`` mock.
    """
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)

    expected_message = "Lightning can't create new processes if CUDA is already initialized"
    with pytest.raises(RuntimeError, match=expected_message):
        launcher.launch(function=Mock())
99
100
def test_check_for_missing_main_guard():
    """Launching with 'spawn' must raise when the main module appears to be importing itself."""
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")

    # A current process with `_inheriting=True` pretends that main is importing itself.
    fake_current_process = Mock(_inheriting=True)
    patch_target = "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process"

    with mock.patch(patch_target, return_value=fake_current_process):
        with pytest.raises(RuntimeError, match="requires that your script guards the main"):
            launcher.launch(function=Mock())
108