pytorch-lightning

Форк
0
181 строка · 6.7 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import logging
15
import os
16
import shutil
17
import sys
18
from unittest import mock
19

20
import pytest
21
from lightning.fabric.plugins.environments import SLURMEnvironment
22
from lightning.fabric.utilities.warnings import PossibleUserWarning
23
from lightning_utilities.test.warning import no_warning_call
24

25
from tests_fabric.helpers.runif import RunIf
26

27

28
@mock.patch.dict(os.environ, {}, clear=True)
def test_default_attributes():
    """With no environment variables set, SLURMEnvironment falls back to its built-in defaults.

    Attributes without a sensible fallback (world size, local rank, node rank) have to be provided through SLURM
    environment variables, so querying them in an empty environment raises a ``KeyError``.
    """
    slurm_env = SLURMEnvironment()

    # attributes that have a fallback value
    assert slurm_env.creates_processes_externally
    assert slurm_env.main_address == "127.0.0.1"
    assert slurm_env.main_port == 12910
    assert slurm_env.job_name() is None
    assert slurm_env.job_id() is None

    # attributes that must be passed in as environment variables
    for required_getter in (slurm_env.world_size, slurm_env.local_rank, slurm_env.node_rank):
        with pytest.raises(KeyError):
            required_getter()
47

48

49
@mock.patch.dict(
    os.environ,
    {
        "SLURM_NODELIST": "1.1.1.1, 1.1.1.2",
        "SLURM_JOB_ID": "0001234",
        "SLURM_NTASKS": "20",
        "SLURM_NTASKS_PER_NODE": "10",
        "SLURM_LOCALID": "2",
        "SLURM_PROCID": "1",
        "SLURM_NODEID": "3",
        "SLURM_JOB_NAME": "JOB",
    },
    # Start from an empty environment: an inherited MASTER_ADDR / MASTER_PORT (e.g. exported by a launcher) is
    # returned by `main_address` / `main_port` directly and would mask the values derived from SLURM here.
    clear=True,
)
def test_attributes_from_environment_variables(caplog):
    """Test that the SLURM cluster environment takes the attributes from the environment variables.

    Also checks that the rank/world-size setters are no-ops that only emit a debug log message.
    """
    env = SLURMEnvironment()
    assert env.auto_requeue is True
    # first entry of SLURM_NODELIST
    assert env.main_address == "1.1.1.1"
    # port derived from the job id: 15000 + last digits of "0001234"
    assert env.main_port == 15000 + 1234
    assert env.job_id() == int("0001234")
    assert env.world_size() == 20
    assert env.global_rank() == 1
    assert env.local_rank() == 2
    assert env.node_rank() == 3
    assert env.job_name() == "JOB"

    # setter should be no-op
    with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
        env.set_global_rank(100)
    assert env.global_rank() == 1
    assert "setting global rank is not allowed" in caplog.text

    caplog.clear()

    with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
        env.set_world_size(100)
    assert env.world_size() == 20
    assert "setting world size is not allowed" in caplog.text
86

87

88
@pytest.mark.parametrize(
    ("slurm_node_list", "expected"),
    [
        ("127.0.0.1", "127.0.0.1"),
        ("alpha", "alpha"),
        ("alpha,beta,gamma", "alpha"),
        ("alpha beta gamma", "alpha"),
        ("1.2.3.[100-110]", "1.2.3.100"),
        ("1.2.3.[089, 100-110]", "1.2.3.089"),
        ("host[22]", "host22"),
        ("host[1,5-9]", "host1"),
        ("host[5-9,1]", "host5"),
        ("alpha, host[5-9], gamma", "alpha"),
        ("alpha[3,1], beta", "alpha3"),
    ],
)
def test_main_address_from_slurm_node_list(slurm_node_list, expected):
    """Test extracting the main node from different formats for the SLURM_NODELIST.

    The environment is cleared so that an inherited ``MASTER_ADDR``, which ``main_address`` returns directly when it
    is set, cannot mask the node-list parsing under test.
    """
    with mock.patch.dict(os.environ, {"SLURM_NODELIST": slurm_node_list}, clear=True):
        env = SLURMEnvironment()
        assert env.main_address == expected
109

110

111
def test_main_address_and_port_from_env_variable():
    """The address and port come from ``MASTER_ADDR`` / ``MASTER_PORT`` when those variables are present."""
    slurm_env = SLURMEnvironment()
    user_overrides = {"MASTER_ADDR": "1.2.3.4", "MASTER_PORT": "1234"}
    with mock.patch.dict(os.environ, user_overrides):
        # the port is read as a string from the environment but exposed as an int
        assert slurm_env.main_address == "1.2.3.4"
        assert slurm_env.main_port == 1234
116

117

118
def test_detect():
    """Test the detection of a SLURM environment configuration.

    Every case starts from an empty environment so that SLURM variables inherited from the machine running the tests
    (e.g. when the suite itself runs inside a SLURM allocation) cannot influence the result.
    """
    # no SLURM variables at all -> not detected
    with mock.patch.dict(os.environ, {}, clear=True):
        assert not SLURMEnvironment.detect()

    # SLURM_NTASKS present -> detected
    with mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}, clear=True):
        assert SLURMEnvironment.detect()

    # Interactive sessions (job name "bash" or "interactive") must not be detected. SLURM_NTASKS is set on purpose:
    # without it these cases would look like the empty-environment case above and would pass even if the
    # interactive-mode check were broken.
    for interactive_job_name in ("bash", "interactive"):
        with mock.patch.dict(
            os.environ, {"SLURM_NTASKS": "2", "SLURM_JOB_NAME": interactive_job_name}, clear=True
        ):
            assert not SLURMEnvironment.detect()
131

132

133
@RunIf(skip_windows=True)
@pytest.mark.skipif(shutil.which("srun") is not None, reason="must run on a machine where srun is not available")
def test_srun_available_and_not_used(monkeypatch):
    """Warn when `srun` exists on the machine but the script was not launched through it.

    The real `srun` must be absent (see the skip condition), so that once the mock is removed no warning is expected.
    """
    monkeypatch.setattr(sys, "argv", ["train.py", "--lr", "0.01"])
    warning_pattern = "`srun` .* available .* but is not used. HINT: .* srun python train.py --lr 0.01"
    which_target = "lightning.fabric.plugins.environments.slurm.shutil.which"

    # While `srun` appears to be installed, both constructing the environment and `detect()` should warn.
    with mock.patch(which_target, return_value="/usr/bin/srun"):
        for trigger in (SLURMEnvironment, SLURMEnvironment.detect):
            with pytest.warns(PossibleUserWarning, match=warning_pattern):
                trigger()

    # Without `srun`, neither call warns and SLURM is not detected.
    with no_warning_call(PossibleUserWarning, match=warning_pattern):
        SLURMEnvironment()
        assert not SLURMEnvironment.detect()
152

153

154
def test_srun_variable_validation():
    """Test that we raise useful errors when `srun` variables are misconfigured.

    Each case clears the environment so that an inherited ``SLURM_NTASKS_PER_NODE`` (e.g. when the tests themselves
    run inside a SLURM allocation) cannot make the misconfiguration look valid and mask the expected error.
    """
    # a single task is a valid configuration
    with mock.patch.dict(os.environ, {"SLURM_NTASKS": "1"}, clear=True):
        SLURMEnvironment()

    # multiple tasks via `--ntasks` without a per-node setting are rejected with a helpful message
    with mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}, clear=True), pytest.raises(
        RuntimeError, match="You set `--ntasks=2` in your SLURM"
    ):
        SLURMEnvironment()
162

163

164
# Clear the environment: an inherited interactive job name (see the last case of the test) would make
# `validate_settings` skip its checks, so the expected errors would never be raised.
@mock.patch.dict(os.environ, {"SLURM_NTASKS_PER_NODE": "4", "SLURM_NNODES": "2"}, clear=True)
def test_validate_user_settings():
    """Test that the environment can validate the number of devices and nodes set in Fabric/Trainer."""
    env = SLURMEnvironment()
    # matches SLURM_NTASKS_PER_NODE=4 and SLURM_NNODES=2 -> no error
    env.validate_settings(num_devices=4, num_nodes=2)

    with pytest.raises(ValueError, match="the number of tasks per node configured .* does not match"):
        env.validate_settings(num_devices=2, num_nodes=2)

    with pytest.raises(ValueError, match="the number of nodes configured in SLURM .* does not match"):
        env.validate_settings(num_devices=4, num_nodes=1)

    # in interactive mode, validation is skipped because processes get launched by Fabric/Trainer, not SLURM
    with mock.patch(
        "lightning.fabric.plugins.environments.slurm.SLURMEnvironment.job_name", return_value="interactive"
    ):
        env = SLURMEnvironment()
        env.validate_settings(num_devices=4, num_nodes=1)  # no error
182

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.