pytorch-lightning

Форк
0
97 строк · 3.8 Кб
1
import os
2
from unittest import mock
3
from unittest.mock import Mock
4

5
import lightning.fabric.utilities
6
import pytest
7
import torch
8
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
9

10

11
@mock.patch.dict(os.environ, clear=True)
12
def test_default_seed():
13
    """Test that the default seed is 0 when no seed provided and no environment variable set."""
14
    assert lightning.fabric.utilities.seed.seed_everything() == 0
15
    assert os.environ["PL_GLOBAL_SEED"] == "0"
16

17

18
@mock.patch.dict(os.environ, {}, clear=True)
19
def test_seed_stays_same_with_multiple_seed_everything_calls():
20
    """Ensure that after the initial seed everything, the seed stays the same for the same run."""
21
    with pytest.warns(UserWarning, match="No seed found"):
22
        lightning.fabric.utilities.seed.seed_everything()
23
    initial_seed = os.environ.get("PL_GLOBAL_SEED")
24

25
    with pytest.warns(None) as record:
26
        lightning.fabric.utilities.seed.seed_everything()
27
    assert not record  # does not warn
28
    seed = os.environ.get("PL_GLOBAL_SEED")
29

30
    assert initial_seed == seed
31

32

33
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
34
def test_correct_seed_with_environment_variable():
35
    """Ensure that the PL_GLOBAL_SEED environment is read."""
36
    assert lightning.fabric.utilities.seed.seed_everything() == 2020
37

38

39
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
40
def test_invalid_seed():
41
    """Ensure that we still fix the seed even if an invalid seed is given."""
42
    with pytest.warns(UserWarning, match="Invalid seed found"):
43
        seed = lightning.fabric.utilities.seed.seed_everything()
44
    assert seed == 0
45

46

47
@mock.patch.dict(os.environ, {}, clear=True)
48
@pytest.mark.parametrize("seed", [10e9, -10e9])
49
def test_out_of_bounds_seed(seed):
50
    """Ensure that we still fix the seed even if an out-of-bounds seed is given."""
51
    with pytest.warns(UserWarning, match="is not in bounds"):
52
        actual = lightning.fabric.utilities.seed.seed_everything(seed)
53
    assert actual == 0
54

55

56
def test_reset_seed_no_op():
57
    """Test that the reset_seed function is a no-op when seed_everything() was not used."""
58
    assert "PL_GLOBAL_SEED" not in os.environ
59
    seed_before = torch.initial_seed()
60
    lightning.fabric.utilities.seed.reset_seed()
61
    assert torch.initial_seed() == seed_before
62
    assert "PL_GLOBAL_SEED" not in os.environ
63

64

65
@pytest.mark.parametrize("workers", [True, False])
66
def test_reset_seed_everything(workers):
67
    """Test that we can reset the seed to the initial value set by seed_everything()"""
68
    assert "PL_GLOBAL_SEED" not in os.environ
69
    assert "PL_SEED_WORKERS" not in os.environ
70

71
    lightning.fabric.utilities.seed.seed_everything(123, workers)
72
    before = torch.rand(1)
73
    assert os.environ["PL_GLOBAL_SEED"] == "123"
74
    assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
75

76
    lightning.fabric.utilities.seed.reset_seed()
77
    after = torch.rand(1)
78
    assert os.environ["PL_GLOBAL_SEED"] == "123"
79
    assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
80
    assert torch.allclose(before, after)
81

82

83
def test_backward_compatibility_rng_states_dict():
84
    """Test that an older rng_states_dict without the "torch.cuda" key does not crash."""
85
    states = _collect_rng_states()
86
    assert "torch.cuda" in states
87
    states.pop("torch.cuda")
88
    _set_rng_states(states)
89

90

91
@mock.patch("lightning.fabric.utilities.seed.torch.cuda.is_available", Mock(return_value=False))
92
@mock.patch("lightning.fabric.utilities.seed.torch.cuda.get_rng_state_all")
93
def test_collect_rng_states_if_cuda_init_fails(get_rng_state_all_mock):
94
    """Test that the `torch.cuda` rng states are only requested if CUDA is available."""
95
    get_rng_state_all_mock.side_effect = RuntimeError("The NVIDIA driver on your system is too old")
96
    states = _collect_rng_states()
97
    assert states["torch.cuda"] == []
98

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.