r"""
This package introduces support for the XPU backend, specifically tailored for
Intel GPU optimization.

This package is lazily initialized, so you can always import it, and use
:func:`is_available()` to determine if your system supports XPU.
"""
import threading
import traceback
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import device as _device
from torch._utils import _dummy_type, _LazySeedTracker

from ._utils import _get_device_index
from .streams import Event, Stream
24
_tls = threading.local()
25
_initialization_lock = threading.Lock()
27
Tuple[Callable[[], None], List[str]]
29
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
30
_device_t = Union[_device, str, int, None]
31
_lazy_seed_tracker = _LazySeedTracker()
32
default_generators: Tuple[torch._C.Generator] = ()
def _is_compiled() -> bool:
    r"""Return true if compile with XPU support."""
    return torch._C._has_xpu
if _is_compiled():
    # Bind the real C++ implementations when PyTorch was built with XPU.
    _XpuDeviceProperties = torch._C._XpuDeviceProperties
    _exchange_device = torch._C._xpu_exchangeDevice
    _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice
else:
    # Provide dummies so the module still imports on non-XPU builds.
    _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties")  # type: ignore[assignment, misc]

    def _exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")

    def _maybe_exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")
@lru_cache(maxsize=1)
def device_count() -> int:
    r"""Return the number of XPU device available."""
    if not _is_compiled():
        # Never raises: a build without XPU support simply has zero devices.
        return 0
    return torch._C._xpu_getDeviceCount()
def is_available() -> bool:
    r"""Return a bool indicating if XPU is currently available."""
    # This function never throws.
    return device_count() > 0
def is_bf16_supported():
    r"""Return a bool indicating if the current XPU device supports dtype bfloat16."""
    return True
def is_initialized():
    r"""Return whether PyTorch's XPU state has been initialized."""
    # A forked child must not inherit the parent's initialized state.
    return _initialized and not _is_in_bad_fork()
def _lazy_call(callable, **kwargs):
    # Run `callable` immediately if XPU is already initialized; otherwise
    # queue it to be replayed by _lazy_init(). Seed calls are tracked
    # separately so only the latest seed request survives.
    if is_initialized():
        callable()
    else:
        global _lazy_seed_tracker
        if kwargs.get("seed_all", False):
            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
        elif kwargs.get("seed", False):
            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
        else:
            # Don't store the actual traceback to avoid memory cycle.
            _queued_calls.append((callable, traceback.format_stack()))
def _lazy_init() -> None:
    r"""Initialize PyTorch's XPU state.

    This is a Python API about lazy initialization that avoids initializing
    XPU until the first time it is accessed. Does nothing if the XPU state is
    already initialized.
    """
    global _initialized, _queued_calls
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # Double-check inside the lock: another thread may have completed
        # initialization while we were waiting.
        if is_initialized():
            return
        # Stop promptly upon encountering a bad fork error.
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize XPU in forked subprocess. To use XPU with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError("Torch not compiled with XPU enabled")
        # This function inits XPU to avoid circular initialization.
        torch._C._xpu_init()
        # Some of the queued calls may reentrantly call _lazy_init(); the
        # flag lets those re-entrant calls return without re-initializing.
        _tls.is_initializing = True

        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"XPU call failed lazily at initialization with error: {str(e)}\n\n"
                        f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise Exception(msg) from e
        finally:
            delattr(_tls, "is_initializing")
    _initialized = True
class _DeviceGuard:
    # Lightweight RAII-style guard: on enter, switch to `index`; on exit,
    # restore the previously selected device.
    def __init__(self, index: int):
        self.idx = index
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = torch.xpu._exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        self.idx = torch.xpu._maybe_exchange_device(self.prev_idx)
        return False
class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int or str): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        self.idx = _get_device_index(device, optional=True)
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = torch.xpu._exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        self.idx = torch.xpu._maybe_exchange_device(self.prev_idx)
        return False
class device_of(device):
    r"""Context-manager that changes the current device to that of given object.

    You can use both tensors and storages as arguments. If a given object is
    not allocated on a XPU, this is a no-op.

    Args:
        obj (Tensor or Storage): object allocated on the selected device.
    """

    def __init__(self, obj):
        # Negative index makes the parent context-manager a no-op.
        idx = obj.get_device() if obj.is_xpu else -1
        super().__init__(idx)
def set_device(device: _device_t) -> None:
    r"""Set the current device.

    Args:
        device (torch.device or int or str): selected device. This function is a
            no-op if this argument is negative.
    """
    _lazy_init()
    device = _get_device_index(device)
    if device >= 0:
        torch._C._xpu_setDevice(device)
def get_device_name(device: Optional[_device_t] = None) -> str:
    r"""Get the name of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the name. This function is a no-op if this argument is a
            negative integer. It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).

    Returns:
        str: the name of the device
    """
    return get_device_properties(device).name
def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]:
    r"""Get the xpu capability of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the device capability. This function is a no-op if this
            argument is a negative integer. It uses the current device, given by
            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        Dict[str, Any]: the xpu capability dictionary of the device
    """
    props = get_device_properties(device)
    # Expose every public attribute of the properties struct as a dict entry.
    return {
        prop: getattr(props, prop) for prop in dir(props) if not prop.startswith("__")
    }
def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
    r"""Get the properties of a device.

    Args:
        device (torch.device or int or str): device for which to return the
            properties of the device.

    Returns:
        _XpuDeviceProperties: the properties of the device
    """
    _lazy_init()
    device = _get_device_index(device, optional=True)
    if device < 0 or device >= device_count():
        raise AssertionError("Invalid device index")
    return _get_device_properties(device)  # type: ignore[name-defined]
def current_device() -> int:
    r"""Return the index of a currently selected device."""
    _lazy_init()
    return torch._C._xpu_getDevice()
263
def _get_device(device: Union[int, str, torch.device]) -> torch.device:
264
r"""Return the torch.device type object from the passed in device.
267
device (torch.device or int or str): selected device.
269
if isinstance(device, str):
270
device = torch.device(device)
271
elif isinstance(device, int):
272
device = torch.device("xpu", device)
class StreamContext:
    r"""Context-manager that selects a given stream.

    All XPU kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        Stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch.xpu.Stream"]

    def __init__(self, stream: Optional["torch.xpu.Stream"]):
        self.stream = stream
        self.idx = _get_device_index(None, True)
        if self.idx is None:
            self.idx = -1

    def __enter__(self):
        cur_stream = self.stream
        # No-op when no stream was given or no device is selected.
        if cur_stream is None or self.idx == -1:
            return
        self.src_prev_stream = torch.xpu.current_stream(None)

        # If the stream lives on another device, remember that device's
        # current stream so __exit__ can restore it too.
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                self.dst_prev_stream = torch.xpu.current_stream(cur_stream.device)
        torch.xpu.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        cur_stream = self.stream
        if cur_stream is None or self.idx == -1:
            return
        # Restore the stream on the destination device, then on the source.
        if self.src_prev_stream.device != cur_stream.device:
            torch.xpu.set_stream(self.dst_prev_stream)
        torch.xpu.set_stream(self.src_prev_stream)
def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext:
    r"""Wrap around the Context-manager StreamContext that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's ``None``.
    """
    return StreamContext(stream)
def _set_stream_by_id(stream_id, device_index, device_type):
    r"""set stream specified by the stream id, device index and device type

    Args: stream_id (int): not visible to the user, used to assigned to the specific stream.
          device_index (int): selected device index.
          device_type (int): selected device type.
    """
    torch._C._xpu_setStream(
        stream_id=stream_id,
        device_index=device_index,
        device_type=device_type,
    )
def set_stream(stream: Stream):
    r"""Set the current stream.This is a wrapper API to set the stream.
        Usage of this function is discouraged in favor of the ``stream``
        context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is None:
        return
    _lazy_init()
    _set_stream_by_id(
        stream_id=stream.stream_id,
        device_index=stream.device_index,
        device_type=stream.device_type,
    )
def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    streamdata = torch._C._xpu_getCurrentStream(
        _get_device_index(device, optional=True)
    )
    # streamdata is a (stream_id, device_index, device_type) triple.
    return Stream(
        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
    )
def synchronize(device: _device_t = None) -> None:
    r"""Wait for all kernels in all streams on a XPU device to complete.

    Args:
        device (torch.device or int, optional): device for which to synchronize.
            It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    device = _get_device_index(device, optional=True)
    return torch._C._xpu_synchronize(device)
def empty_cache() -> None:
    r"""Release all unoccupied cached memory currently held by the caching
    allocator so that those can be used in other XPU application.

    .. note::
        :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU
        memory available for PyTorch. However, it may help reduce fragmentation
        of XPU memory in certain cases.
    """
    # Nothing to release if XPU was never initialized.
    if is_initialized():
        torch._C._xpu_emptyCache()
def _get_generator(device: torch.device) -> torch._C.Generator:
    r"""Return the XPU Generator object for the given device.

    Args:
        device (torch.device): selected device.
    """
    idx = device.index
    if idx is None:
        # A bare "xpu" device means the currently selected one.
        idx = current_device()
    return torch.xpu.default_generators[idx]
def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state offset of the specified GPU.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    final_device = _get_device(device)

    def cb():
        # Deferred so that setting the offset does not eagerly initialize XPU.
        default_generator = _get_generator(final_device)
        default_generator.set_offset(offset)

    _lazy_call(cb)
def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
    r"""Return the random number generator state offset of the specified GPU.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    final_device = _get_device(device)
    default_generator = _get_generator(final_device)
    return default_generator.get_offset()
460
"default_generators",
465
"get_device_capability",
467
"get_device_properties",