pytorch-lightning
99 строк · 3.7 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import os
15from unittest import mock
16
17import lightning.fabric
18import pytest
19from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
20from lightning.fabric.plugins.environments import XLAEnvironment
21
22from tests_fabric.helpers.runif import RunIf
23
24
@RunIf(tpu=True)
# preserve the current environment, otherwise xla will default to pjrt
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_default_attributes(monkeypatch):
    """Check the attributes reported by ``XLAEnvironment`` when the XLA runtime queries are stubbed out."""
    # the real runtime queries create side effects in other tests, so they are replaced with stubs
    stubs = {
        "world_size": lambda: 2,
        "global_ordinal": lambda: 0,
        "local_ordinal": lambda: 0,
    }
    if _XLA_GREATER_EQUAL_2_1:
        from torch_xla import runtime as xla_backend

        stubs["host_index"] = lambda: 1
    else:
        from torch_xla.experimental import pjrt as xla_backend

        # in this branch the host ordinal is supplied through the environment instead of a stub
        os.environ["XRT_HOST_ORDINAL"] = "1"

    for attr_name, stub in stubs.items():
        monkeypatch.setattr(xla_backend, attr_name, stub)

    env = XLAEnvironment()
    assert not env.creates_processes_externally
    observed = (env.world_size(), env.global_rank(), env.local_rank(), env.node_rank())
    assert observed == (2, 0, 0, 1)

    # accessing either property is expected to raise
    for unsupported in ("main_address", "main_port"):
        with pytest.raises(NotImplementedError):
            _ = getattr(env, unsupported)
57
58
@RunIf(tpu=True)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_attributes_from_environment_variables(monkeypatch):
    """Check that ``XLAEnvironment`` reports the stubbed runtime values and that its setters do not override them."""
    stubs = {
        "world_size": lambda: 2,
        "global_ordinal": lambda: 0,
        "local_ordinal": lambda: 2,
    }
    if _XLA_GREATER_EQUAL_2_1:
        from torch_xla import runtime as xla_backend

        stubs["host_index"] = lambda: 1
    else:
        from torch_xla.experimental import pjrt as xla_backend

        # in this branch the host ordinal is supplied through the environment instead of a stub
        os.environ["XRT_HOST_ORDINAL"] = "1"

    for attr_name, stub in stubs.items():
        monkeypatch.setattr(xla_backend, attr_name, stub)

    env = XLAEnvironment()

    # accessing either property is expected to raise
    for unsupported in ("main_address", "main_port"):
        with pytest.raises(NotImplementedError):
            _ = getattr(env, unsupported)

    observed = (env.world_size(), env.global_rank(), env.local_rank(), env.node_rank())
    assert observed == (2, 0, 2, 1)

    # the reported rank and world size must be unchanged after calling the setters
    env.set_global_rank(100)
    rank_after_set = env.global_rank()
    assert rank_after_set == 0

    env.set_world_size(100)
    size_after_set = env.world_size()
    assert size_after_set == 2
91
92
def test_detect(monkeypatch):
    """Check that ``XLAEnvironment.detect()`` is falsy or truthy in line with the patched ``is_available``."""
    accelerator_cls = lightning.fabric.accelerators.xla.XLAAccelerator

    # accelerator reported as unavailable -> nothing should be detected
    monkeypatch.setattr(accelerator_cls, "is_available", lambda: False)
    detected_when_unavailable = XLAEnvironment.detect()
    assert not detected_when_unavailable

    # accelerator reported as available -> the environment should be detected
    monkeypatch.setattr(accelerator_cls, "is_available", lambda: True)
    detected_when_available = XLAEnvironment.detect()
    assert detected_when_available
100