# mypy: allow-untyped-defs
import contextlib
import warnings
from typing import Generator

import torch
from torch._C import default_generator

def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.

    .. note:: This function only works for CPU. For CUDA, please use
        :func:`torch.manual_seed`, which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)

def get_rng_state() -> torch.Tensor:
    r"""Returns the random number generator state as a `torch.ByteTensor`.

    .. note:: The returned state is for the default generator on CPU only.

    See also: :func:`torch.random.fork_rng`.
    """
    return default_generator.get_state()

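# A minimal round-trip sketch (illustrative; default CPU generator only):
# capturing the state with get_rng_state() and restoring it with
# set_rng_state() makes a random draw repeat exactly.
#
#   >>> state = torch.get_rng_state()
#   >>> a = torch.rand(3)
#   >>> torch.set_rng_state(state)
#   >>> torch.equal(a, torch.rand(3))
#   True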

def manual_seed(seed) -> torch._C.Generator:
    r"""Sets the seed for generating random numbers on all devices. Returns a
    `torch.Generator` object.

    Args:
        seed (int): The desired seed. Value must be within the inclusive range
            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
            is raised. Negative inputs are remapped to positive values with the formula
            `0xffff_ffff_ffff_ffff + seed`.
    """
    seed = int(seed)
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return default_generator.manual_seed(seed)

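# Example sketch: re-seeding reproduces the same stream of draws.
#
#   >>> _ = torch.manual_seed(0)
#   >>> a = torch.rand(2)
#   >>> _ = torch.manual_seed(0)
#   >>> torch.equal(a, torch.rand(2))
#   True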


def seed() -> int:
    r"""Sets the seed for generating random numbers to a non-deterministic
    random number on all devices. Returns a 64 bit number used to seed the RNG.
    """
    seed = default_generator.seed()
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return seed

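# Example sketch: torch.seed() draws a fresh non-deterministic seed and
# returns it, so the same run can be replayed later with torch.manual_seed(s).
#
#   >>> s = torch.seed()
#   >>> a = torch.rand(2)
#   >>> _ = torch.manual_seed(s)
#   >>> torch.equal(a, torch.rand(2))
#   True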


def _seed_custom_device(seed) -> None:
    r"""Sets the seed to generate random numbers for a custom device.

    Args:
        seed (int): The desired seed.

    See [Note: support the custom device with privateuse1]
    """
    seed = int(seed)
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    if hasattr(torch, custom_backend_name):
        custom_device_mod = getattr(torch, custom_backend_name)
        _bad_fork_name = "_is_in_bad_fork"
        _seed_all_name = "manual_seed_all"
        if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
            custom_device_mod, _seed_all_name
        ):
            if not getattr(custom_device_mod, _bad_fork_name)():
                getattr(custom_device_mod, _seed_all_name)(seed)
        else:
            message = f"Setting the seed for the `{custom_backend_name}` device has no effect; please add the "
            message += f"`{_bad_fork_name}` and `{_seed_all_name}` APIs to the `{custom_backend_name}` device module."
            warnings.warn(message, UserWarning, stacklevel=3)

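# Sketch of the protocol assumed above (the backend name "foo" and the no-op
# bodies are hypothetical): a privateuse1 backend module must expose
# _is_in_bad_fork() and manual_seed_all(seed) for seeding to take effect.
#
#   >>> import types
#   >>> mod = types.ModuleType("foo")
#   >>> mod._is_in_bad_fork = lambda: False
#   >>> mod.manual_seed_all = lambda seed: None  # a real backend seeds its RNGs here
#   >>> torch.utils.rename_privateuse1_backend("foo")
#   >>> torch._register_device_module("foo", mod)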


def initial_seed() -> int:
    r"""Returns the initial seed for generating random numbers as a
    Python `long`.

    .. note:: The returned seed is for the default generator on CPU only.
    """
    return default_generator.initial_seed()

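# Example sketch: initial_seed() reports whatever value most recently seeded
# the default CPU generator.
#
#   >>> _ = torch.manual_seed(42)
#   >>> torch.initial_seed()
#   42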


_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    _caller="fork_rng",
    _devices_kw="devices",
    device_type="cuda",
) -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked.  This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `cuda`. For a custom device,
            see details in [Note: support the custom device with privateuse1]
    """

    device_type = torch.device(device_type).type
    device_mod = getattr(torch, device_type, None)
    if device_mod is None:
        raise RuntimeError(
            f"torch has no module of `{device_type}`, you should register "
            + "a module by `torch._register_device_module`."
        )
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = device_mod.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            message = (
                f"{device_type.upper()} reports that you have {num_devices} available devices, and "
                f"you have used {_caller} without explicitly specifying which devices are being used. "
                f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
                f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
                f" making use of a few {device_type.upper()} devices, set the environment variable "
                f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
                "with the set of devices you are actually using. For example, if you are using CPU only, "
                f"set {device_type.upper()}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
                f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0].  To initialize all devices "
                f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
                f"`range(torch.{device_type}.device_count())`."
            )
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    device_rng_states = [device_mod.get_rng_state(device) for device in devices]

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            device_mod.set_rng_state(device_rng_state, device)

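# Example sketch: draws made inside fork_rng() do not disturb the RNG stream
# outside it; devices=[] forks the CPU state only and suppresses the
# multi-device warning.
#
#   >>> before = torch.get_rng_state()
#   >>> with torch.random.fork_rng(devices=[]):
#   ...     _ = torch.rand(10)  # scratch draws; state is restored on exit
#   >>> torch.equal(before, torch.get_rng_state())
#   True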
