# pytorch-lightning
# 97 lines · 3.8 KB
import os
import warnings
from unittest import mock
from unittest.mock import Mock

import lightning.fabric.utilities
import pytest
import torch
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states


@mock.patch.dict(os.environ, clear=True)
def test_default_seed():
    """With no seed argument and no ``PL_GLOBAL_SEED`` in the environment, the seed falls back to 0."""
    seed_module = lightning.fabric.utilities.seed
    chosen = seed_module.seed_everything()
    assert chosen == 0
    # The chosen seed is also recorded in the environment as a string.
    assert os.environ["PL_GLOBAL_SEED"] == "0"


@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_stays_same_with_multiple_seed_everything_calls():
    """Ensure that after the initial seed everything, the seed stays the same for the same run.

    The first call has no seed to work with and warns; the second call must pick up the value the first one stored
    in ``PL_GLOBAL_SEED`` silently, without emitting any warning.
    """
    with pytest.warns(UserWarning, match="No seed found"):
        lightning.fabric.utilities.seed.seed_everything()
    initial_seed = os.environ.get("PL_GLOBAL_SEED")

    # `pytest.warns(None)` was deprecated in pytest 7 and removed in pytest 8, so record warnings with the standard
    # library instead. The "always" filter prevents a warning emitted earlier from the same code location from being
    # hidden by the default once-per-location filtering, which would make the "no warning" check vacuous.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        lightning.fabric.utilities.seed.seed_everything()
    assert not record  # does not warn
    seed = os.environ.get("PL_GLOBAL_SEED")

    assert initial_seed == seed


@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
def test_correct_seed_with_environment_variable():
    """When no seed is passed explicitly, the value stored in ``PL_GLOBAL_SEED`` is used."""
    expected = 2020
    result = lightning.fabric.utilities.seed.seed_everything()
    assert result == expected


@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
def test_invalid_seed():
    """A non-numeric ``PL_GLOBAL_SEED`` is reported with a warning, and the seed is still fixed (to 0)."""
    seed_everything = lightning.fabric.utilities.seed.seed_everything
    with pytest.warns(UserWarning, match="Invalid seed found"):
        fallback = seed_everything()
    assert fallback == 0


@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.parametrize("seed", [10e9, -10e9])
def test_out_of_bounds_seed(seed):
    """A seed outside the accepted range is reported with a warning, and the seed is still fixed (to 0)."""
    seeding = lightning.fabric.utilities.seed
    with pytest.warns(UserWarning, match="is not in bounds"):
        resolved_seed = seeding.seed_everything(seed)
    # The rejected value is replaced by the fallback seed.
    assert resolved_seed == 0


@mock.patch.dict(os.environ, {}, clear=True)
def test_reset_seed_no_op():
    """Test that the reset_seed function is a no-op when seed_everything() was not used.

    The environment is cleared for the duration of the test. Otherwise a ``PL_GLOBAL_SEED`` exported by the
    surrounding shell, or leaked by another test, would make the precondition below fail spuriously.
    """
    assert "PL_GLOBAL_SEED" not in os.environ
    seed_before = torch.initial_seed()
    lightning.fabric.utilities.seed.reset_seed()
    # Neither the torch seed nor the environment may have been touched.
    assert torch.initial_seed() == seed_before
    assert "PL_GLOBAL_SEED" not in os.environ


@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.parametrize("workers", [True, False])
def test_reset_seed_everything(workers):
    """Test that we can reset the seed to the initial value set by seed_everything().

    ``seed_everything`` writes ``PL_GLOBAL_SEED`` and ``PL_SEED_WORKERS`` into ``os.environ``. Without patching the
    environment these variables would leak out of this test, so the second parametrization (and any later test)
    would fail the "not in os.environ" preconditions below.
    """
    assert "PL_GLOBAL_SEED" not in os.environ
    assert "PL_SEED_WORKERS" not in os.environ

    lightning.fabric.utilities.seed.seed_everything(123, workers)
    before = torch.rand(1)
    assert os.environ["PL_GLOBAL_SEED"] == "123"
    assert os.environ["PL_SEED_WORKERS"] == str(int(workers))

    # Re-seeding from the stored environment values must reproduce the same random draw.
    lightning.fabric.utilities.seed.reset_seed()
    after = torch.rand(1)
    assert os.environ["PL_GLOBAL_SEED"] == "123"
    assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
    assert torch.allclose(before, after)


def test_backward_compatibility_rng_states_dict():
    """Restoring RNG states must not crash on an older states dict that has no "torch.cuda" entry."""
    rng_states = _collect_rng_states()
    assert "torch.cuda" in rng_states
    # Simulate the older layout by dropping the CUDA entry before restoring.
    del rng_states["torch.cuda"]
    _set_rng_states(rng_states)


@mock.patch("lightning.fabric.utilities.seed.torch.cuda.is_available", Mock(return_value=False))
@mock.patch("lightning.fabric.utilities.seed.torch.cuda.get_rng_state_all")
def test_collect_rng_states_if_cuda_init_fails(get_rng_state_all_mock):
    """CUDA RNG states must only be queried when CUDA reports itself as available."""
    # Any call to the patched getter raises, so the test fails if it is reached while CUDA is "unavailable".
    get_rng_state_all_mock.side_effect = RuntimeError("The NVIDIA driver on your system is too old")
    collected = _collect_rng_states()
    cuda_states = collected["torch.cuda"]
    assert cuda_states == []