import contextlib
import warnings
from typing import Generator

import torch
from torch._C import default_generator


def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.

    .. note:: This function only works for CPU. For CUDA, please use
        :func:`torch.manual_seed`, which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state
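
    Example (illustrative; uses :func:`torch.get_rng_state` from this module
    to round-trip the CPU RNG state)::

        >>> state = torch.get_rng_state()  # capture the current CPU RNG state
        >>> _ = torch.rand(3)              # advance the RNG
        >>> torch.set_rng_state(state)     # restore the captured state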
    """
    default_generator.set_state(new_state)


def get_rng_state() -> torch.Tensor:
    r"""Returns the random number generator state as a `torch.ByteTensor`.

    .. note:: The returned state is for the default generator on CPU only.

    See also: :func:`torch.random.fork_rng`.
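
    Example (illustrative; the state is an opaque byte tensor)::

        >>> s = torch.get_rng_state()
        >>> s.dtype
        torch.uint8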
    """
    return default_generator.get_state()


def manual_seed(seed) -> torch._C.Generator:
    r"""Sets the seed for generating random numbers on all devices. Returns a
    `torch.Generator` object.

    Args:
        seed (int): The desired seed. Value must be within the inclusive range
            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
            is raised. Negative inputs are remapped to positive values with the formula
            `0xffff_ffff_ffff_ffff + seed`.
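
    Example (illustrative; the returned object is the default CPU generator)::

        >>> g = torch.manual_seed(0)
        >>> isinstance(g, torch.Generator)
        True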
    """
    seed = int(seed)

    # Skip device seeding in a "bad fork": a subprocess forked after the device
    # runtime was initialized, where touching the device is unsafe.
    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return default_generator.manual_seed(seed)


def seed() -> int:
    r"""Sets the seed for generating random numbers to a non-deterministic
    random number on all devices. Returns a 64 bit number used to seed the RNG.
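
    Example (illustrative; the value is non-deterministic by design)::

        >>> s = torch.seed()
        >>> isinstance(s, int)
        True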
    """
    seed = default_generator.seed()

    # Propagate the freshly drawn seed to every device backend, unless we are
    # in a "bad fork" where the device runtime must not be touched.
    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return seed


def _seed_custom_device(seed) -> None:
    r"""Sets the seed to generate random numbers for custom device.

    Args:
        seed (int): The desired seed.

    See [Note: support the custom device with privateuse1]
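
    A custom backend is seeded here only if its device module exposes both
    ``_is_in_bad_fork()`` and ``manual_seed_all(seed)``. A minimal sketch,
    assuming a hypothetical backend named ``my_backend`` (the name and module
    below are placeholders, not a real backend)::

        torch.utils.rename_privateuse1_backend("my_backend")
        torch._register_device_module("my_backend", my_backend_module)
        torch.manual_seed(42)  # now also calls torch.my_backend.manual_seed_all(42)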
    """
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    if hasattr(torch, custom_backend_name):
        custom_device_mod = getattr(torch, custom_backend_name)
        _bad_fork_name = "_is_in_bad_fork"
        _seed_all_name = "manual_seed_all"
        if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
            custom_device_mod, _seed_all_name
        ):
            if not getattr(custom_device_mod, _bad_fork_name)():
                getattr(custom_device_mod, _seed_all_name)(seed)
        else:
            message = f"Seeding the `{custom_backend_name}` device has no effect, please add the APIs "
            message += f"`{_bad_fork_name}` and `{_seed_all_name}` to the `{custom_backend_name}` device module."
            warnings.warn(message, UserWarning, stacklevel=3)


def initial_seed() -> int:
    r"""Returns the initial seed for generating random numbers as a
    Python ``int``.

    .. note:: The returned seed is for the default generator on CPU only.
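
    Example (illustrative; reports the seed last set on the default CPU generator)::

        >>> _ = torch.manual_seed(1234)
        >>> torch.initial_seed()
        1234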
    """
    return default_generator.initial_seed()


_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    _caller="fork_rng",
    _devices_kw="devices",
    device_type="cuda",
) -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `cuda`. As for custom device,
            see details in [Note: support the custom device with privateuse1]
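
    Example (illustrative, CPU-only sketch; ``devices=[]`` keeps accelerator
    state untouched)::

        >>> state = torch.get_rng_state()
        >>> with torch.random.fork_rng(devices=[]):
        ...     _ = torch.rand(10)  # draws inside the block do not leak out
        >>> torch.equal(state, torch.get_rng_state())
        True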
    """
    device_type = torch.device(device_type).type
    device_mod = getattr(torch, device_type, None)
    if device_mod is None:
        raise RuntimeError(
            f"torch has no `{device_type}` module; register one with "
            + "`torch._register_device_module`."
        )
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = device_mod.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            message = (
                f"{device_type.upper()} reports that you have {num_devices} available devices, and "
                f"you have used {_caller} without explicitly specifying which devices are being used. "
                f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
                f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
                f" making use of a few {device_type.upper()} devices, set the environment variable "
                f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
                "with the set of devices you are actually using. For example, if you are using CPU only, "
                f"set {device_type.upper()}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
                f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
                f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
                f"`range(torch.{device_type}.device_count())`."
            )
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against the user passing us a generator; we need to traverse
        # this collection more than once, but a generator would be exhausted
        # after the first pass.
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    device_rng_states = [device_mod.get_rng_state(device) for device in devices]

    try:
        yield
    finally:
        # Restore the CPU and per-device RNG states captured above, even if the
        # body of the `with` block raised.
        torch.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            device_mod.set_rng_state(device_rng_state, device)