import ctypes

import torch
from torch._streambase import _EventBase, _StreamBase

from .._utils import _dummy_type
8
if not hasattr(torch._C, "_XpuStreamBase"):
10
torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
11
torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
class Stream(torch._C._XpuStreamBase, _StreamBase):
    r"""Wrapper around a XPU stream.

    A XPU stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream, should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # Enter a device context only when a device was actually requested and
        # the caller is not re-wrapping an existing stream identified by an
        # explicit (stream_id, device_index) pair; in those cases the current
        # context is used as-is and the device switch is skipped.
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super().__new__(cls, priority=priority, **kwargs)
        with torch.xpu.device(device):
            return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event):
        r"""Make all future work submitted to the stream wait for an event.

        Args:
            event (torch.xpu.Event): an event to wait for.
        """
        # The event performs the actual wait against this stream.
        event.wait(self)

    def wait_stream(self, stream):
        r"""Synchronize with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to a given stream at the time of call complete.

        Args:
            stream (Stream): a stream to synchronize.
        """
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        r"""Record an event.

        Args:
            event (torch.xpu.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            The recorded event, so callers such as :meth:`wait_stream` can wait on it.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def query(self):
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        return super().query()

    def synchronize(self):
        r"""Wait for all the kernels in this stream to complete."""
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes looks up ``_as_parameter_`` when this object is passed to a
        # foreign function; it must be an attribute (hence the property), not a
        # method, so the stream is passed as its raw SYCL queue pointer.
        return ctypes.c_void_p(self.sycl_queue)

    def __eq__(self, o):
        # Only streams can compare equal; anything else is simply unequal.
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        # Defining __eq__ would otherwise make the class unhashable.
        return hash((self.sycl_queue, self.device))

    def __repr__(self):
        return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
class Event(torch._C._XpuEventBase, _EventBase):
    r"""Wrapper around a XPU event.

    XPU events are synchronization markers that can be used to monitor the
    device's progress, and to synchronize XPU streams.

    The underlying XPU events are lazily initialized when the event is first
    recorded. After creation, only streams on the same device may record the
    event. However, streams on any device can wait on the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
    """

    def __new__(cls, enable_timing=False):
        return super().__new__(cls, enable_timing=enable_timing)

    def record(self, stream=None):
        r"""Record the event in a given stream.

        Uses ``torch.xpu.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        # Fall back to the current stream only when the caller gave none;
        # an explicitly passed stream must be honored.
        if stream is None:
            stream = torch.xpu.current_stream()
        super().record(stream)

    def wait(self, stream=None):
        r"""Make all future work submitted to the given stream wait for this event.

        Use ``torch.xpu.current_stream()`` if no stream is specified.
        """
        if stream is None:
            stream = torch.xpu.current_stream()
        super().wait(stream)

    def query(self):
        r"""Check if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the time elapsed.

        Time reported in milliseconds after the event was recorded and
        before the end_event was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self):
        r"""Wait for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes looks up ``_as_parameter_`` as an attribute when this object is
        # passed to a foreign function, so it must be a property, not a method.
        return ctypes.c_void_p(self.sycl_event)

    def __repr__(self):
        # The native event is created lazily (see class docstring); a falsy
        # handle is reported as uninitialized rather than printed as 0x0.
        if self.sycl_event:
            return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})"
        return "torch.xpu.Event(uninitialized)"