pytorch-lightning

Форк
0
129 строк · 4.8 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import logging
15
import os
16
from unittest import mock
17
from unittest.mock import MagicMock
18

19
import lightning.fabric.plugins.environments.mpi
20
import pytest
21
from lightning.fabric.plugins.environments import MPIEnvironment
22

23

24
def test_dependencies(monkeypatch):
    """`MPIEnvironment` can only be constructed when `mpi4py` is reported as available."""
    mpi_module = lightning.fabric.plugins.environments.mpi

    # Flag says mpi4py is missing: construction must raise.
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", False)
    with pytest.raises(ModuleNotFoundError):
        MPIEnvironment()

    # Flag says mpi4py is present and a stand-in module is importable: construction succeeds.
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)
    fake_modules = {"mpi4py": MagicMock()}
    with mock.patch.dict("sys.modules", fake_modules):
        MPIEnvironment()
34

35

36
def test_detect(monkeypatch):
    """`MPIEnvironment.detect()` is falsy without mpi4py, and with it depends on the world size."""
    mpi_module = lightning.fabric.plugins.environments.mpi

    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", False)
    assert not MPIEnvironment.detect()

    # From here on, behave as if mpi4py were installed.
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)
    fake_mpi4py = MagicMock()
    get_size = fake_mpi4py.MPI.COMM_WORLD.Get_size

    # (reported world size, expected truthiness of detect())
    cases = ((0, False), (1, False), (2, True))
    with mock.patch.dict("sys.modules", {"mpi4py": fake_mpi4py}):
        for world_size, expected in cases:
            get_size.return_value = world_size
            assert bool(MPIEnvironment.detect()) is expected
54

55

56
def test_default_attributes(monkeypatch):
    """With no environment variables set, node rank, main address and main port stay unset."""
    with mock.patch.dict(os.environ, {}, clear=True):
        # Behave as if mpi4py were installed.
        monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
        with mock.patch.dict("sys.modules", {"mpi4py": MagicMock()}):
            env = MPIEnvironment()

        for attribute in ("_node_rank", "_main_address", "_main_port"):
            assert getattr(env, attribute) is None
        assert env.creates_processes_externally
69

70

71
def test_init_local_comm(monkeypatch):
    """Node rank is this host's index in the gathered/broadcast host list; local rank comes from the local comm."""
    # Behave as if mpi4py were installed.
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    fake_mpi4py = MagicMock()
    fake_gethostname = MagicMock()
    fake_mpi4py.MPI.COMM_WORLD.Get_size.return_value = 4

    with mock.patch.dict("sys.modules", {"mpi4py": fake_mpi4py}):
        with mock.patch("socket.gethostname", fake_gethostname):
            env = MPIEnvironment()
            world = env._comm_world

            # Process whose gather call returns the full host list.
            fake_gethostname.return_value = "host1"
            world.gather.return_value = ["host1", "host2"]
            world.bcast.return_value = ["host1", "host2"]
            assert env.node_rank() == 0

            # Process whose gather call yields None; the host list arrives only via bcast.
            # Clear the cached node rank so it is recomputed.
            env._node_rank = None
            fake_gethostname.return_value = "host2"
            world.gather.return_value = None
            world.bcast.return_value = ["host1", "host2"]
            assert env.node_rank() == 1

            assert env._comm_local is not None
            env._comm_local.Get_rank.return_value = 33
            assert env.local_rank() == 33
96

97

98
def test_world_comm(monkeypatch):
    """World size and global rank are reported from the world communicator."""
    # Behave as if mpi4py were installed.
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)

    with mock.patch.dict("sys.modules", {"mpi4py": MagicMock()}):
        env = MPIEnvironment()
        world = env._comm_world

        world.Get_size.return_value = 8
        assert env.world_size() == 8

        world.Get_rank.return_value = 3
        assert env.global_rank() == 3
110

111

112
def test_setters(monkeypatch, caplog):
    """Setting the global rank or world size is ignored: a debug message is logged and values are unchanged."""
    # Behave as if mpi4py were installed.
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    mpi4py_mock = MagicMock()

    with mock.patch.dict("sys.modules", {"mpi4py": mpi4py_mock}):
        env = MPIEnvironment()

    # Values reported by the world communicator; the setters must not override them.
    env._comm_world.Get_rank.return_value = 3
    env._comm_world.Get_size.return_value = 8

    # Setters should be no-ops: check both the log message and that the reported value is unchanged.
    with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
        env.set_global_rank(100)
    assert "setting global rank is not allowed" in caplog.text
    assert env.global_rank() == 3

    caplog.clear()

    with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
        env.set_world_size(100)
    assert "setting world size is not allowed" in caplog.text
    assert env.world_size() == 8
130

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.