# mypy: allow-untyped-defs
r"""
This package introduces support for the XPU backend, specifically tailored for
Intel GPU optimization.

This package is lazily initialized, so you can always import it, and use
:func:`is_available()` to determine if your system supports XPU.
"""
import threading
import traceback
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch._C
from torch import device as _device
from torch._utils import _dummy_type, _LazySeedTracker

from ._utils import _get_device_index
from .streams import Event, Stream


_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls: List[
    Tuple[Callable[[], None], List[str]]
] = []  # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
_lazy_seed_tracker = _LazySeedTracker()
default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]

def _is_compiled() -> bool:
    r"""Return true if compiled with XPU support."""
    return torch._C._has_xpu


if _is_compiled():
    _XpuDeviceProperties = torch._C._XpuDeviceProperties
    _exchange_device = torch._C._xpu_exchangeDevice
    _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice
else:
    # Define dummies if PyTorch was compiled without XPU
    _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties")  # type: ignore[assignment, misc]

    def _exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")

    def _maybe_exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")


@lru_cache(maxsize=1)
def device_count() -> int:
    r"""Return the number of XPU devices available."""
    if not _is_compiled():
        return 0
    return torch._C._xpu_getDeviceCount()


def is_available() -> bool:
    r"""Return a bool indicating if XPU is currently available."""
    # This function never throws.
    return device_count() > 0
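
# Example (a minimal usage sketch, assuming an XPU-enabled build): fall back
# to CPU when no XPU device is present.
#
#   >>> dev = torch.device("xpu" if torch.xpu.is_available() else "cpu")
#   >>> x = torch.ones(2, 2, device=dev)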


def is_bf16_supported():
    r"""Return a bool indicating if the current XPU device supports dtype bfloat16."""
    return True


def is_initialized():
    r"""Return whether PyTorch's XPU state has been initialized."""
    return _initialized and not _is_in_bad_fork()


def _lazy_call(callable, **kwargs):
    if is_initialized():
        callable()
    else:
        global _lazy_seed_tracker
        if kwargs.get("seed_all", False):
            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
        elif kwargs.get("seed", False):
            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
        else:
            # Don't store the actual traceback to avoid memory cycle
            _queued_calls.append((callable, traceback.format_stack()))


def init():
    r"""Initialize PyTorch's XPU state.
    This is a Python API for lazy initialization that avoids initializing
    XPU until the first time it is accessed. Does nothing if the XPU state is
    already initialized.
    """
    _lazy_init()


def _lazy_init():
    global _initialized, _queued_calls
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # This test was protected via GIL. Double-check whether XPU has
        # already been initialized.
        if is_initialized():
            return
        # Stop promptly upon encountering a bad fork error.
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize XPU in forked subprocess. To use XPU with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError("Torch not compiled with XPU enabled")
        # This function initializes the XPU backend and detects bad fork processing.
        torch._C._xpu_init()
        # Some of the queued calls may reentrantly call _lazy_init(); we need to
        # just return without initializing in that case.
        _tls.is_initializing = True

        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"XPU call failed lazily at initialization with error: {str(e)}\n\n"
                        f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise Exception(msg) from e  # noqa: TRY002
        finally:
            delattr(_tls, "is_initializing")
        _initialized = True
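
# Example (a minimal sketch, assuming an XPU-enabled build): initialization is
# lazy, but it can be forced with :func:`init` to pay the startup cost at a
# known point.
#
#   >>> torch.xpu.init()
#   >>> torch.xpu.is_initialized()
#   True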


class _DeviceGuard:
    def __init__(self, index: int):
        self.idx = index
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = torch.xpu._exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        self.idx = torch.xpu._maybe_exchange_device(self.prev_idx)
        return False


class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int or str): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        self.idx = _get_device_index(device, optional=True)
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = torch.xpu._exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        self.idx = torch.xpu._maybe_exchange_device(self.prev_idx)
        return False
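
# Example (a minimal sketch, assuming at least one XPU device): temporarily
# switch the current device; the previous device is restored on exit.
#
#   >>> with torch.xpu.device(0):
#   ...     y = torch.empty(8, device="xpu")  # allocated on device 0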


class device_of(device):
    r"""Context-manager that changes the current device to that of given object.

    You can use both tensors and storages as arguments. If a given object is
    not allocated on an XPU, this is a no-op.

    Args:
        obj (Tensor or Storage): object allocated on the selected device.
    """

    def __init__(self, obj):
        idx = obj.get_device() if obj.is_xpu else -1
        super().__init__(idx)
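
# Example (a minimal sketch, assuming ``x`` is a tensor on an XPU device):
# run allocation code on whatever device ``x`` lives on.
#
#   >>> with torch.xpu.device_of(x):
#   ...     z = torch.empty(4, device="xpu")  # same device as x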


def set_device(device: _device_t) -> None:
    r"""Set the current device.

    Args:
        device (torch.device or int or str): selected device. This function is a
            no-op if this argument is negative.
    """
    _lazy_init()
    device = _get_device_index(device)
    if device >= 0:
        torch._C._xpu_setDevice(device)
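
# Example (a minimal sketch, assuming a second XPU device exists):
#
#   >>> torch.xpu.set_device(1)
#   >>> torch.xpu.current_device()
#   1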


def get_device_name(device: Optional[_device_t] = None) -> str:
    r"""Get the name of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the name. This function is a no-op if this argument is a
            negative integer. It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).

    Returns:
        str: the name of the device
    """
    return get_device_properties(device).name


@lru_cache(None)
def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]:
    r"""Get the xpu capability of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the device capability. This function is a no-op if this
            argument is a negative integer. It uses the current device, given by
            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        Dict[str, Any]: the xpu capability dictionary of the device
    """
    props = get_device_properties(device)
    return {
        prop: getattr(props, prop) for prop in dir(props) if not prop.startswith("__")
    }
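
# Example (a minimal sketch, assuming an XPU device at index 0): the
# capability dict mirrors the public fields of ``_XpuDeviceProperties``.
#
#   >>> torch.xpu.get_device_name(0)   # e.g. an 'Intel(R) ...' string
#   >>> caps = torch.xpu.get_device_capability(0)
#   >>> "name" in caps
#   True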


def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
    r"""Get the properties of a device.

    Args:
        device (torch.device or int or str): device for which to return the
            properties.

    Returns:
        _XpuDeviceProperties: the properties of the device
    """
    _lazy_init()
    device = _get_device_index(device, optional=True)
    if device < 0 or device >= device_count():
        raise AssertionError("Invalid device index")
    return _get_device_properties(device)  # type: ignore[name-defined]  # noqa: F821


def current_device() -> int:
    r"""Return the index of the currently selected device."""
    _lazy_init()
    return torch._C._xpu_getDevice()
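
# Example (a minimal sketch, assuming an XPU-enabled build): enumerate all
# visible devices by index and name.
#
#   >>> for i in range(torch.xpu.device_count()):
#   ...     print(i, torch.xpu.get_device_name(i))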


def _get_device(device: Union[int, str, torch.device]) -> torch.device:
    r"""Return the torch.device type object from the passed in device.

    Args:
        device (torch.device or int or str): selected device.
    """
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("xpu", device)
    return device


class StreamContext:
    r"""Context-manager that selects a given stream.

    All XPU kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        Stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device.
    """
    cur_stream: Optional["torch.xpu.Stream"]

    def __init__(self, stream: Optional["torch.xpu.Stream"]):
        self.stream = stream
        self.idx = _get_device_index(None, True)
        if self.idx is None:
            self.idx = -1

    def __enter__(self):
        cur_stream = self.stream
        if cur_stream is None or self.idx == -1:
            return
        self.src_prev_stream = torch.xpu.current_stream(None)

        # If the stream is not on the current device, then set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                self.dst_prev_stream = torch.xpu.current_stream(cur_stream.device)
        torch.xpu.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        cur_stream = self.stream
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device and destination device
        if self.src_prev_stream.device != cur_stream.device:
            torch.xpu.set_stream(self.dst_prev_stream)
        torch.xpu.set_stream(self.src_prev_stream)


def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext:
    r"""Wrapper around the Context-manager StreamContext that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's ``None``.
    """
    return StreamContext(stream)
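
# Example (a minimal sketch, assuming an XPU-enabled build): queue work on a
# side stream, then wait for it before using the result.
#
#   >>> s = torch.xpu.Stream()
#   >>> with torch.xpu.stream(s):
#   ...     y = torch.ones(1024, device="xpu") * 2
#   >>> s.synchronize()  # block until work queued on s has finished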


def _set_stream_by_id(stream_id, device_index, device_type):
    r"""Set the stream specified by the stream id, device index and device type.

    Args:
        stream_id (int): not visible to the user, used to assign the specific stream.
        device_index (int): selected device index.
        device_type (int): selected device type.
    """
    torch._C._xpu_setStream(
        stream_id=stream_id,
        device_index=device_index,
        device_type=device_type,
    )


def set_stream(stream: Stream):
    r"""Set the current stream. This is a wrapper API to set the stream.
        Usage of this function is discouraged in favor of the ``stream``
        context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is None:
        return
    _lazy_init()
    _set_stream_by_id(
        stream_id=stream.stream_id,
        device_index=stream.device_index,
        device_type=stream.device_type,
    )


def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    streamdata = torch._C._xpu_getCurrentStream(
        _get_device_index(device, optional=True)
    )
    return Stream(
        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
    )


def synchronize(device: _device_t = None) -> None:
    r"""Wait for all kernels in all streams on an XPU device to complete.

    Args:
        device (torch.device or int, optional): device for which to synchronize.
            It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    device = _get_device_index(device, optional=True)
    return torch._C._xpu_synchronize(device)
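
# Example (a minimal sketch, assuming an XPU-enabled build): kernels launch
# asynchronously, so synchronize before reading a host-side timer.
#
#   >>> import time
#   >>> a = torch.randn(4096, 4096, device="xpu")
#   >>> t0 = time.perf_counter()
#   >>> b = a @ a
#   >>> torch.xpu.synchronize()
#   >>> elapsed = time.perf_counter() - t0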


def empty_cache() -> None:
    r"""Release all unoccupied cached memory currently held by the caching
    allocator so that it can be used in other XPU applications.

    .. note::
        :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU
        memory available for PyTorch. However, it may help reduce fragmentation
        of XPU memory in certain cases.
    """
    if is_initialized():
        torch._C._xpu_emptyCache()
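
# Example (a minimal sketch): after freeing large tensors, return the cached
# blocks so other applications can claim the memory.
#
#   >>> del y  # drop the last reference to a large XPU tensor
#   >>> torch.xpu.empty_cache()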


def _get_generator(device: torch.device) -> torch._C.Generator:
    r"""Return the XPU Generator object for the given device.

    Args:
        device (torch.device): selected device.
    """
    idx = device.index
    if idx is None:
        idx = current_device()
    return torch.xpu.default_generators[idx]


def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state offset of the specified GPU.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state offset.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    final_device = _get_device(device)

    def cb():
        default_generator = _get_generator(final_device)
        default_generator.set_offset(offset)

    _lazy_call(cb)


def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
    r"""Return the random number generator state offset of the specified GPU.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    final_device = _get_device(device)
    default_generator = _get_generator(final_device)
    return default_generator.get_offset()


from .random import *  # noqa: F403


__all__ = [
    "Event",
    "Stream",
    "StreamContext",
    "current_device",
    "current_stream",
    "default_generators",
    "device",
    "device_of",
    "device_count",
    "empty_cache",
    "get_device_capability",
    "get_device_name",
    "get_device_properties",
    "get_rng_state",
    "get_rng_state_all",
    "get_stream",
    "init",
    "initial_seed",
    "is_available",
    "is_bf16_supported",
    "is_initialized",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "set_device",
    "set_rng_state",
    "set_rng_state_all",
    "set_stream",
    "stream",
    "streams",
    "synchronize",
]