import torch
from typing import cast, Iterable, List, Union
from torch import Tensor
from .lazy_init import _lazy_init, _lazy_call

import contextlib
from typing import Generator
import warnings

__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
    "fork_rng",
]


def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
    r"""Returns the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """

    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("xpu", device)
    idx = device.index
    if idx is None:
        idx = torch.xpu.current_device()
    default_generator = torch.xpu.default_generators[idx]
    return default_generator.get_state()


def get_rng_state_all() -> List[Tensor]:
    r"""Returns a list of ByteTensors representing the random number states of all devices."""

    results = []
    for i in range(torch.xpu.device_count()):
        results.append(get_rng_state(i))
    return results


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Sets the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state.
        device (torch.device or int, optional): The device to set the RNG state of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("xpu", device)

    def cb():
        idx = cast(torch.device, device).index
        if idx is None:
            idx = torch.xpu.current_device()
        default_generator = torch.xpu.default_generators[idx]
        default_generator.set_state(new_state_copy)

    _lazy_call(cb)


def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Sets the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device.
    """
    for i, state in enumerate(new_states):
        set_rng_state(state, i)

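# Usage sketch (illustrative comment, not executed on import): snapshot every
# device's RNG state, run some stochastic work, then restore the snapshot so
# the same draws can be replayed. Assumes at least one XPU device is present
# and that this module is exposed as ``torch.xpu``, as the calls above suggest.
#
#   states = torch.xpu.get_rng_state_all()
#   first = torch.randn(4, device="xpu")
#   torch.xpu.set_rng_state_all(states)
#   replay = torch.randn(4, device="xpu")  # identical to `first`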

def manual_seed(seed: int) -> None:
    r"""Sets the seed for generating random numbers for the current GPU.
    It's safe to call this function if XPU is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism.  To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def cb():
        idx = torch.xpu.current_device()
        default_generator = torch.xpu.default_generators[idx]
        default_generator.manual_seed(seed)

    _lazy_call(cb)


def manual_seed_all(seed: int) -> None:
    r"""Sets the seed for generating random numbers on all GPUs.
    It's safe to call this function if XPU is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def cb():
        for i in range(torch.xpu.device_count()):
            default_generator = torch.xpu.default_generators[i]
            default_generator.manual_seed(seed)

    _lazy_call(cb, seed_all=True)

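# Usage sketch (illustrative comment, not executed on import): seed every XPU
# device up front for reproducible multi-GPU runs; seeding only the current
# device is insufficient when work is spread across several of them.
#
#   torch.xpu.manual_seed_all(42)
#   a = torch.randn(4, device="xpu:0")
#   torch.xpu.manual_seed_all(42)
#   b = torch.randn(4, device="xpu:0")  # identical to `a`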

def seed() -> None:
    r"""Sets the seed for generating random numbers to a random number for the current GPU.
    It's safe to call this function if XPU is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU.  To initialize all GPUs, use :func:`seed_all`.
    """

    def cb():
        idx = torch.xpu.current_device()
        default_generator = torch.xpu.default_generators[idx]
        default_generator.seed()

    _lazy_call(cb)


def seed_all() -> None:
    r"""Sets the seed for generating random numbers to a random number on all GPUs.
    It's safe to call this function if XPU is not available; in that
    case, it is silently ignored.
    """

    def cb():
        # Draw a random seed on the first device, then replay that same seed
        # on every remaining device so all devices agree on one seed.
        random_seed = 0
        seeded = False
        for i in range(torch.xpu.device_count()):
            default_generator = torch.xpu.default_generators[i]
            if not seeded:
                default_generator.seed()
                random_seed = default_generator.initial_seed()
                seeded = True
            else:
                default_generator.manual_seed(random_seed)

    _lazy_call(cb)

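# Usage sketch (illustrative comment, not executed on import): after
# seed_all(), every device reports the same randomly drawn seed. Assumes XPU
# has already been initialized, so the queued seeding callback has run.
#
#   torch.xpu.seed_all()
#   seeds = {torch.xpu.default_generators[i].initial_seed()
#            for i in range(torch.xpu.device_count())}
#   assert len(seeds) == 1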

def initial_seed() -> int:
    r"""Returns the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes XPU.
    """

    # lazy initialization occurs in current_device
    idx = torch.xpu.current_device()
    default_generator = torch.xpu.default_generators[idx]
    return default_generator.initial_seed()

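# Usage sketch (illustrative comment, not executed on import): initial_seed()
# reports the seed the current device's generator was last seeded with;
# calling it also triggers XPU initialization, flushing any queued seeding.
#
#   torch.xpu.manual_seed(1234)
#   assert torch.xpu.initial_seed() == 1234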

_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(
    devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"
) -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of XPU IDs): XPU devices for which to fork
            the RNG.  CPU RNG state is always forked.  By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked.  This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
    """

    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = torch.xpu.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            warnings.warn(
                (
                    "XPU reports that you have {num_devices} available devices, and you "
                    "have used {caller} without explicitly specifying which devices are being used. "
                    "For safety, we initialize *every* XPU device by default, which "
                    "can be quite slow if you have a lot of GPUs.  If you know that you are only "
                    "making use of a few XPU devices, set the environment variable XPU_VISIBLE_DEVICES "
                    "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                    "you are actually using.  For example, if you are using CPU only, "
                    "set XPU_VISIBLE_DEVICES= or devices=[]; if you are using "
                    "GPU 0 only, set XPU_VISIBLE_DEVICES=0 or devices=[0].  To initialize "
                    "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                    "to `range(torch.xpu.device_count())`."
                ).format(
                    num_devices=num_devices, caller=_caller, devices_kw=_devices_kw
                )
            )
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    gpu_rng_states = []
    for device in devices:
        gpu_rng_states.append(torch.xpu.get_rng_state(device))

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, gpu_rng_state in zip(devices, gpu_rng_states):
            torch.xpu.set_rng_state(gpu_rng_state, device)
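
# Usage sketch (illustrative comment, not executed on import): run a
# stochastic block without disturbing the surrounding RNG streams; the CPU
# state and each listed XPU device's state are restored on exit. Assumes this
# module is exposed as ``torch.xpu``.
#
#   with torch.xpu.fork_rng(devices=[0]):
#       torch.xpu.manual_seed(7)
#       preview = torch.randn(4, device="xpu:0")
#   # CPU and xpu:0 generators are back where they were before the block.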
