intel-extension-for-pytorch
252 lines · 8.3 KB
1import torch2from typing import cast, Iterable, List, Union3from torch import Tensor4from .lazy_init import _lazy_init, _lazy_call5
6import contextlib7from typing import Generator8import warnings9
# Public names exported by ``from ... import *``: the XPU RNG-state helpers
# defined in this module.
__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
    "fork_rng",
]
22
23
def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
    r"""Returns the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    # Normalize whatever the caller passed into a torch.device.
    if isinstance(device, int):
        device = torch.device("xpu", device)
    elif isinstance(device, str):
        device = torch.device(device)
    device_index = device.index
    if device_index is None:
        # A bare "xpu" device carries no index; fall back to the current one.
        device_index = torch.xpu.current_device()
    return torch.xpu.default_generators[device_index].get_state()
46
def get_rng_state_all() -> List[Tensor]:
    r"""Returns a list of ByteTensor representing the random number states of all devices.

    Returns:
        List[Tensor]: one RNG state per visible XPU device, in device-index order.
    """
    # Comprehension instead of a manual append loop (same order, same result).
    return [get_rng_state(i) for i in range(torch.xpu.device_count())]
55
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Sets the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    # Snapshot the state eagerly so that later mutation of the caller's tensor
    # cannot leak into the (possibly deferred) callback below.
    state_snapshot = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, int):
        device = torch.device("xpu", device)
    elif isinstance(device, str):
        device = torch.device(device)

    def _apply_state() -> None:
        device_index = cast(torch.device, device).index
        if device_index is None:
            device_index = torch.xpu.current_device()
        torch.xpu.default_generators[device_index].set_state(state_snapshot)

    # Deferred until XPU is initialized; runs immediately if it already is.
    _lazy_call(_apply_state)
81
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Sets the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device"""
    # The i-th state is applied to device index i.
    for device_index, state in enumerate(new_states):
        set_rng_state(state, device_index)
90
def manual_seed(seed: int) -> None:
    r"""Sets the seed for generating random numbers for the current GPU.
    It's safe to call this function if XPU is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def _seed_current_device() -> None:
        generator = torch.xpu.default_generators[torch.xpu.current_device()]
        generator.manual_seed(seed)

    # Deferred until XPU initialization so calling this pre-init is harmless.
    _lazy_call(_seed_current_device)
112
def manual_seed_all(seed: int) -> None:
    r"""Sets the seed for generating random numbers on all GPUs.
    It's safe to call this function if XPU is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def _seed_every_device() -> None:
        for device_index in range(torch.xpu.device_count()):
            torch.xpu.default_generators[device_index].manual_seed(seed)

    # seed_all=True marks this callback as covering every device.
    _lazy_call(_seed_every_device, seed_all=True)
130
def seed() -> None:
    r"""Sets the seed for generating random numbers to a random number for the current GPU.
    It's safe to call this function if XPU is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
    """

    def _reseed_current_device() -> None:
        device_index = torch.xpu.current_device()
        torch.xpu.default_generators[device_index].seed()

    # Deferred until XPU initialization so calling this pre-init is harmless.
    _lazy_call(_reseed_current_device)
148
def seed_all() -> None:
    r"""Sets the seed for generating random numbers to a random number on all GPUs.
    It's safe to call this function if XPU is not available; in that
    case, it is silently ignored.
    """

    def _reseed_all_devices() -> None:
        device_total = torch.xpu.device_count()
        if device_total == 0:
            return
        # Draw one random seed on device 0, then replicate it to every other
        # device so all default generators agree.
        first_generator = torch.xpu.default_generators[0]
        first_generator.seed()
        shared_seed = first_generator.initial_seed()
        for device_index in range(1, device_total):
            torch.xpu.default_generators[device_index].manual_seed(shared_seed)

    _lazy_call(_reseed_all_devices)
169
def initial_seed() -> int:
    r"""Returns the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes XPU.
    """
    # current_device() performs lazy initialization, so no explicit
    # _lazy_init() call is needed here.
    device_index = torch.xpu.current_device()
    return torch.xpu.default_generators[device_index].initial_seed()
182
# Module-level flag used by fork_rng: ensures the "many devices" warning is
# emitted at most once per process.
_fork_rng_warned_already = False
185
@contextlib.contextmanager
def fork_rng(
    devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"
) -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of XPU IDs): XPU devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
    """

    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = torch.xpu.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            warnings.warn(
                (
                    "XPU reports that you have {num_devices} available devices, and you "
                    "have used {caller} without explicitly specifying which devices are being used. "
                    "For safety, we initialize *every* XPU device by default, which "
                    "can be quite slow if you have a lot of GPUs. If you know that you are only "
                    "making use of a few XPU devices, set the environment variable XPU_VISIBLE_DEVICES "
                    "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                    "you are actually using. For example, if you are using CPU only, "
                    "set XPU_VISIBLE_DEVICES= or devices=[]; if you are using "
                    "GPU 0 only, set XPU_VISIBLE_DEVICES=0 or devices=[0]. To initialize "
                    "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                    "to `range(torch.xpu.device_count())`."
                ).format(
                    num_devices=num_devices, caller=_caller, devices_kw=_devices_kw
                )
            )
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    # Snapshot CPU state plus one state per requested XPU device.
    saved_cpu_state = torch.get_rng_state()
    saved_xpu_states = [torch.xpu.get_rng_state(d) for d in devices]

    try:
        yield
    finally:
        # Restore everything even if the body raised.
        torch.set_rng_state(saved_cpu_state)
        for d, saved_state in zip(devices, saved_xpu_states):
            torch.xpu.set_rng_state(saved_state, d)