pytorch-lightning
79 строк · 2.6 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging15import os16from unittest import mock17
18import pytest19from lightning.fabric.plugins.environments import KubeflowEnvironment20
21
@mock.patch.dict(os.environ, {}, clear=True)
def test_default_attributes():
    """Test the default attributes when no environment variables are set.

    With a completely empty environment, every attribute backed by a required environment variable must raise
    ``KeyError``. Each ``pytest.raises`` also matches the quoted variable name (``str(KeyError("X")) == "'X'"``),
    so a ``KeyError`` for some other key (e.g. the wrong variable being read) cannot make the test pass by accident.
    """
    env = KubeflowEnvironment()
    assert env.creates_processes_externally

    with pytest.raises(KeyError, match="'MASTER_ADDR'"):
        # MASTER_ADDR is required
        env.main_address
    with pytest.raises(KeyError, match="'MASTER_PORT'"):
        # MASTER_PORT is required
        env.main_port
    with pytest.raises(KeyError, match="'WORLD_SIZE'"):
        # WORLD_SIZE is required
        env.world_size()
    with pytest.raises(KeyError, match="'RANK'"):
        # RANK is required; the quotes keep this from matching e.g. LOCAL_RANK
        env.global_rank()
    # local rank does not depend on the environment
    assert env.local_rank() == 0
42
@mock.patch.dict(
    os.environ,
    {
        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
        "MASTER_ADDR": "1.2.3.4",
        "MASTER_PORT": "500",
        "WORLD_SIZE": "20",
        "RANK": "1",
    },
)
def test_attributes_from_environment_variables(caplog):
    """Test that the Kubeflow cluster environment takes the attributes from the environment variables.

    Also checks that the rank/world-size setters are no-ops that only emit a debug log message.
    """
    env = KubeflowEnvironment()
    assert env.main_address == "1.2.3.4"
    assert env.main_port == 500
    assert env.world_size() == 20
    assert env.global_rank() == 1
    assert env.local_rank() == 0
    assert env.node_rank() == 1

    # setter should be no-op: the value keeps coming from the environment and a debug message is logged
    with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
        env.set_global_rank(100)
    assert env.global_rank() == 1
    assert "setting global rank is not allowed" in caplog.text

    # drop the previous record so the next check only sees what set_world_size logs
    caplog.clear()

    with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
        env.set_world_size(100)
    assert env.world_size() == 20
    assert "setting world size is not allowed" in caplog.text
75
def test_detect_kubeflow():
    """Check that ``KubeflowEnvironment.detect`` refuses auto-detection by raising ``NotImplementedError``."""
    expected_message = "can't be detected automatically"
    with pytest.raises(NotImplementedError, match=expected_message):
        KubeflowEnvironment.detect()