pytorch-lightning

Форк
0
99 строк · 3.7 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import os
15
from unittest import mock
16

17
import lightning.fabric
18
import pytest
19
from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
20
from lightning.fabric.plugins.environments import XLAEnvironment
21

22
from tests_fabric.helpers.runif import RunIf
23

24

25
@RunIf(tpu=True)
26
# keep existing environment or else xla will default to pjrt
27
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
28
def test_default_attributes(monkeypatch):
29
    """Test the default attributes when no environment variables are set."""
30
    # calling these creates side effects in other tests
31
    if _XLA_GREATER_EQUAL_2_1:
32
        from torch_xla import runtime
33

34
        monkeypatch.setattr(runtime, "world_size", lambda: 2)
35
        monkeypatch.setattr(runtime, "global_ordinal", lambda: 0)
36
        monkeypatch.setattr(runtime, "local_ordinal", lambda: 0)
37
        monkeypatch.setattr(runtime, "host_index", lambda: 1)
38
    else:
39
        from torch_xla.experimental import pjrt
40

41
        monkeypatch.setattr(pjrt, "world_size", lambda: 2)
42
        monkeypatch.setattr(pjrt, "global_ordinal", lambda: 0)
43
        monkeypatch.setattr(pjrt, "local_ordinal", lambda: 0)
44
        os.environ["XRT_HOST_ORDINAL"] = "1"
45

46
    env = XLAEnvironment()
47
    assert not env.creates_processes_externally
48
    assert env.world_size() == 2
49
    assert env.global_rank() == 0
50
    assert env.local_rank() == 0
51
    assert env.node_rank() == 1
52

53
    with pytest.raises(NotImplementedError):
54
        _ = env.main_address
55
    with pytest.raises(NotImplementedError):
56
        _ = env.main_port
57

58

59
@RunIf(tpu=True)
60
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
61
def test_attributes_from_environment_variables(monkeypatch):
62
    """Test that the default cluster environment takes the attributes from the environment variables."""
63
    if _XLA_GREATER_EQUAL_2_1:
64
        from torch_xla import runtime
65

66
        monkeypatch.setattr(runtime, "world_size", lambda: 2)
67
        monkeypatch.setattr(runtime, "global_ordinal", lambda: 0)
68
        monkeypatch.setattr(runtime, "local_ordinal", lambda: 2)
69
        monkeypatch.setattr(runtime, "host_index", lambda: 1)
70
    else:
71
        from torch_xla.experimental import pjrt
72

73
        monkeypatch.setattr(pjrt, "world_size", lambda: 2)
74
        monkeypatch.setattr(pjrt, "global_ordinal", lambda: 0)
75
        monkeypatch.setattr(pjrt, "local_ordinal", lambda: 2)
76
        os.environ["XRT_HOST_ORDINAL"] = "1"
77

78
    env = XLAEnvironment()
79
    with pytest.raises(NotImplementedError):
80
        _ = env.main_address
81
    with pytest.raises(NotImplementedError):
82
        _ = env.main_port
83
    assert env.world_size() == 2
84
    assert env.global_rank() == 0
85
    assert env.local_rank() == 2
86
    assert env.node_rank() == 1
87
    env.set_global_rank(100)
88
    assert env.global_rank() == 0
89
    env.set_world_size(100)
90
    assert env.world_size() == 2
91

92

93
def test_detect(monkeypatch):
94
    """Test the detection of a xla environment configuration."""
95
    monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: False)
96
    assert not XLAEnvironment.detect()
97

98
    monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: True)
99
    assert XLAEnvironment.detect()
100

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.