# pytorch-lightning
# (129 lines · 4.8 KB)
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14import logging
15import os
16from unittest import mock
17from unittest.mock import MagicMock
18
19import lightning.fabric.plugins.environments.mpi
20import pytest
21from lightning.fabric.plugins.environments import MPIEnvironment
22
23
def test_dependencies(monkeypatch):
    """``MPIEnvironment`` can only be constructed when the mpi4py availability flag is set."""
    mpi_module = lightning.fabric.plugins.environments.mpi

    # With the flag cleared, construction must fail with ModuleNotFoundError.
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", False)
    with pytest.raises(ModuleNotFoundError):
        MPIEnvironment()

    # With the flag set and a MagicMock standing in for the mpi4py module, construction succeeds.
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)
    fake_mpi4py = MagicMock()
    with mock.patch.dict("sys.modules", {"mpi4py": fake_mpi4py}):
        MPIEnvironment()


def test_detect(monkeypatch):
    """``MPIEnvironment.detect`` depends on the mpi4py flag and on the reported world size."""
    mpi_module = lightning.fabric.plugins.environments.mpi

    # Without mpi4py the environment is never detected.
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", False)
    assert not MPIEnvironment.detect()

    # With a fake mpi4py, only a world size of 2 (of the sizes tried here) counts as detected.
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)
    fake_mpi4py = MagicMock()
    get_size = fake_mpi4py.MPI.COMM_WORLD.Get_size
    with mock.patch.dict("sys.modules", {"mpi4py": fake_mpi4py}):
        for world_size, expected in ((0, False), (1, False), (2, True)):
            get_size.return_value = world_size
            assert bool(MPIEnvironment.detect()) is expected, f"unexpected detect() result for world size {world_size}"


@mock.patch.dict(os.environ, {}, clear=True)
def test_default_attributes(monkeypatch):
    """With no environment variables set, node rank, main address and main port all start as ``None``."""
    # Pretend mpi4py is installed and serve a MagicMock in place of the module.
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    with mock.patch.dict("sys.modules", {"mpi4py": MagicMock()}):
        env = MPIEnvironment()
        for name in ("_node_rank", "_main_address", "_main_port"):
            assert getattr(env, name) is None, f"{name} should default to None"
        # The environment reports that processes are created externally.
        assert env.creates_processes_externally


def test_init_local_comm(monkeypatch):
    """Node rank and local rank are derived from the hostnames of all participating processes."""
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    fake_mpi4py = MagicMock()
    fake_gethostname = MagicMock()
    fake_mpi4py.MPI.COMM_WORLD.Get_size.return_value = 4

    patch_modules = mock.patch.dict("sys.modules", {"mpi4py": fake_mpi4py})
    patch_hostname = mock.patch("socket.gethostname", fake_gethostname)
    with patch_modules, patch_hostname:
        env = MPIEnvironment()
        comm_world = env._comm_world

        # gather() yields every hostname; this process runs on the first host.
        fake_gethostname.return_value = "host1"
        comm_world.gather.return_value = ["host1", "host2"]
        comm_world.bcast.return_value = ["host1", "host2"]
        assert env.node_rank() == 0

        # gather() yields None; the host list only arrives through bcast().
        env._node_rank = None  # reset the stored node rank so the next call recomputes it
        fake_gethostname.return_value = "host2"
        comm_world.gather.return_value = None
        comm_world.bcast.return_value = ["host1", "host2"]
        assert env.node_rank() == 1

        # Resolving the node rank also created the node-local communicator.
        assert env._comm_local is not None
        env._comm_local.Get_rank.return_value = 33
        assert env.local_rank() == 33


def test_world_comm(monkeypatch):
    """World size and global rank are taken from the world communicator's ``Get_size``/``Get_rank``."""
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    with mock.patch.dict("sys.modules", {"mpi4py": MagicMock()}):
        env = MPIEnvironment()
        comm_world = env._comm_world

        comm_world.Get_size.return_value = 8
        assert env.world_size() == 8

        comm_world.Get_rank.return_value = 3
        assert env.global_rank() == 3


def test_setters(monkeypatch, caplog):
    """Setting the world size or global rank is refused: a message is logged and the values are unchanged.

    Besides the log text, the test checks that ``global_rank()`` and ``world_size()`` still report
    the values coming from the world communicator after the setters were called.
    """
    # pretend mpi4py is available
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    mpi4py_mock = MagicMock()
    log_namespace = "lightning.fabric.plugins.environments"

    with mock.patch.dict("sys.modules", {"mpi4py": mpi4py_mock}):
        env = MPIEnvironment()
        env._comm_world.Get_size.return_value = 8
        env._comm_world.Get_rank.return_value = 3

        # the setters are expected to be no-ops that only emit a log message
        with caplog.at_level(logging.DEBUG, logger=log_namespace):
            env.set_global_rank(100)
        assert "setting global rank is not allowed" in caplog.text
        assert env.global_rank() == 3

        caplog.clear()

        with caplog.at_level(logging.DEBUG, logger=log_namespace):
            env.set_world_size(100)
        assert "setting world size is not allowed" in caplog.text
        assert env.world_size() == 8