pytorch

Форк
0
/
streams.py 
169 строк · 5.3 Кб
1
import ctypes
2

3
import torch
4
from torch._streambase import _EventBase, _StreamBase
5
from .._utils import _dummy_type
6

7

8
if not hasattr(torch._C, "_XpuStreamBase"):
9
    # Define dummy base classes
10
    torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
11
    torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
12

13

14
class Stream(torch._C._XpuStreamBase, _StreamBase):
    r"""Wrapper around an XPU stream.

    An XPU stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream, should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # Entering a ``torch.xpu.device`` context is expensive, so skip it
        # when no device was requested, or when the caller already passed both
        # ``stream_id`` and ``device_index`` (presumably to wrap an existing
        # stream, so no device switch is needed).
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super().__new__(cls, priority=priority, **kwargs)
        with torch.xpu.device(device):
            return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event):
        r"""Make all future work submitted to the stream wait for an event.

        Args:
            event (torch.xpu.Event): an event to wait for.
        """
        event.wait(self)

    def wait_stream(self, stream):
        r"""Synchronize with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to a given stream at the time of call complete.

        Args:
            stream (Stream): a stream to synchronize.
        """
        # Record a marker on the other stream now, then make this stream wait on it.
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        r"""Record an event.

        Args:
            event (torch.xpu.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def query(self):
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        return super().query()

    def synchronize(self):
        r"""Wait for all the kernels in this stream to complete."""
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes protocol: lets this stream be passed directly to foreign
        # functions as the raw ``sycl_queue`` handle (a void pointer).
        return ctypes.c_void_p(self.sycl_queue)

    def __eq__(self, o):
        # Only other torch.xpu.Stream objects can compare equal; the actual
        # comparison is delegated to the native base class.
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        # Hash by the underlying queue handle together with the device.
        return hash((self.sycl_queue, self.device))

    def __repr__(self):
        # Fields are comma-separated so the output reads like a constructor call.
        return f"torch.xpu.Stream(device={self.device}, sycl_queue={self.sycl_queue:#x})"
97

98

99
class Event(torch._C._XpuEventBase, _EventBase):
    r"""Wrapper around an XPU event.

    XPU events are synchronization markers that can be used to monitor the
    device's progress, and to synchronize XPU streams.

    The underlying XPU events are lazily initialized when the event is first
    recorded. After creation, only streams on the same device may record the
    event. However, streams on any device can wait on the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
    """

    def __new__(cls, enable_timing=False):
        return super().__new__(cls, enable_timing=enable_timing)

    def record(self, stream=None):
        r"""Record the event in a given stream.

        Falls back to ``torch.xpu.current_stream()`` when ``stream`` is
        ``None``. The stream's device must match the event's device.
        """
        target = torch.xpu.current_stream() if stream is None else stream
        super().record(target)

    def wait(self, stream=None):
        r"""Make all future work submitted to the given stream wait for this event.

        Falls back to ``torch.xpu.current_stream()`` when ``stream`` is ``None``.
        """
        target = torch.xpu.current_stream() if stream is None else stream
        super().wait(target)

    def query(self):
        r"""Check if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the time elapsed between this event and ``end_event``.

        The time is reported in milliseconds, measured from when this event
        was recorded to when ``end_event`` was recorded.

        Args:
            end_event (torch.xpu.Event): the event marking the end of the interval.
        """
        return super().elapsed_time(end_event)

    def synchronize(self):
        r"""Wait for the event to complete.

        Blocks the calling CPU thread until all work currently captured in this
        event has completed.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes protocol: pass the raw ``sycl_event`` handle as a void pointer.
        return ctypes.c_void_p(self.sycl_event)

    def __repr__(self):
        # A falsy handle means the event has not been recorded yet.
        if not self.sycl_event:
            return "torch.xpu.Event(uninitialized)"
        return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})"
170

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.