DeepSpeed

Mirror of https://github.com/microsoft/DeepSpeed
xpu_accelerator.py 
318 lines · 10.1 KB
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import intel_extension_for_pytorch as ipex  # noqa: F401 # type: ignore
import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
import functools

import importlib
import inspect


class XPU_Accelerator(DeepSpeedAccelerator):
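    # DeepSpeed accelerator implementation for Intel XPU devices: device, stream, RNG
    # and memory queries are delegated to torch.xpu (provided by intel_extension_for_pytorch),
    # and distributed communication uses the oneCCL ('ccl') backend.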

    def __init__(self):
        self._name = 'xpu'
        self._communication_backend_name = 'ccl'
        self._compile_backend = "inductor"
        self.aligned_tensors = []
        self.class_dict = None

    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        # Workaround: XPU event-based timing is only consolidated in IPEX 2.5,
        # so fall back to host timers on older versions.
        if ipex.__version__ < '2.5':
            return True
        else:
            return self.is_synchronized_device()

    def resolves_data_dependency(self):
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        return self.is_synchronized_device()

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return 'xpu'
        return 'xpu:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.xpu.device(device_index)

    def set_device(self, device_index):
        torch.xpu.set_device(device_index)

    def current_device(self):
        return torch.xpu.current_device()

    def current_device_name(self):
        return 'xpu:{}'.format(torch.xpu.current_device())

    def device_count(self):
        return torch.xpu.device_count()

    def synchronize(self, device_index=None):
        return torch.xpu.synchronize(device_index)

    # RNG APIs
    def random(self):
        return torch.xpu.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.xpu.set_rng_state(new_state)
        return torch.xpu.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.xpu.get_rng_state()
        return torch.xpu.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.xpu.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.xpu.manual_seed_all(seed)

    def initial_seed(self):
        return torch.xpu.initial_seed()

    def default_generator(self, device_index):
        return torch.xpu.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return torch.xpu.Stream

    def stream(self, stream):
        return torch.xpu.stream(stream)

    def current_stream(self, device_index=None):
        return torch.xpu.current_stream(device_index)

    def default_stream(self, device_index=None):
        # torch.xpu does not provide the same default-stream synchronization behavior as CUDA,
        # so use the current stream as a workaround.
        # see https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams
        return torch.xpu.current_stream(device_index)

    @property
    def Event(self):
        return torch.xpu.Event

    # Memory management
    def empty_cache(self):
        return torch.xpu.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.xpu.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.xpu.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.xpu.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.xpu.memory_reserved(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.xpu.max_memory_reserved(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.xpu.reset_max_memory_reserved(device_index)

    def memory_stats(self, device_index=None):
        return torch.xpu.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        return torch.xpu.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        return torch.xpu.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        return torch.xpu.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.xpu.get_device_properties(device_index).total_memory

    def available_memory(self, device_index=None):
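        # Estimate only: subtracts this process's allocations from the device total,
        # ignoring the caching allocator's reserved-but-unused blocks and other processes.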
        return self.total_memory(device_index) - self.memory_allocated(device_index)

    # Misc
    def amp(self):
        return torch.xpu.amp

    def is_available(self):
        return torch.xpu.is_available()

    def range_push(self, msg):
        # TODO: ITT ranges are not supported yet
        # return torch.profiler.itt.range_push(msg)
        return

    def range_pop(self):
        # TODO: ITT ranges are not supported yet
        # return torch.profiler.itt.range_pop()
        return

    def lazy_call(self, callback):
        if hasattr(torch.xpu, "_lazy_call"):
            return torch.xpu._lazy_call(callback)
        else:
            return torch.xpu.lazy_init._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return False

    # Graph operations
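    # Graph capture is not implemented for XPU: create_graph() returns None, replay_graph()
    # is a no-op, and capture_to_graph() yields a no-op context, so code written against the
    # graph API simply runs eagerly.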
    def create_graph(self):
        return None

    def capture_to_graph(self, graph, pool=None, stream=None):
        from deepspeed.runtime.utils import noop_context
        return noop_context()

    def replay_graph(self, graph):
        return

    # Data types
    def is_bf16_supported(self):
        return True

    def is_fp16_supported(self):
        return True

    def supported_dtypes(self):
        return [torch.float, torch.half, torch.bfloat16]

    # Tensor operations

    @property
    def BFloat16Tensor(self):
        return functools.partial(torch.tensor, dtype=torch.bfloat16, device=self._name)

    @property
    def ByteTensor(self):
        return functools.partial(torch.tensor, dtype=torch.uint8, device=self._name)

    @property
    def DoubleTensor(self):
        return functools.partial(torch.tensor, dtype=torch.double, device=self._name)

    @property
    def FloatTensor(self):
        return functools.partial(torch.tensor, dtype=torch.float, device=self._name)

    @property
    def HalfTensor(self):
        return functools.partial(torch.tensor, dtype=torch.half, device=self._name)

    @property
    def IntTensor(self):
        return functools.partial(torch.tensor, dtype=torch.int, device=self._name)

    @property
    def LongTensor(self):
        return functools.partial(torch.tensor, dtype=torch.long, device=self._name)

    def pin_memory(self, tensor, align_bytes=1):
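        # align_bytes == 1: pin through the regular torch pinned-memory path.
        # align_bytes == 0: allocate a page-locked buffer via the XPU AsyncIOBuilder handle,
        # copy the tensor into it, and record its address range in self.aligned_tensors so
        # that is_pinned() can recognize it later.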
        if align_bytes == 1:
            return tensor.pin_memory(device=self.current_device_name())
        elif align_bytes == 0:
            from deepspeed.ops.op_builder.xpu import AsyncIOBuilder
            self.aio_handle = AsyncIOBuilder().load().aio_handle(128 * 1024, 8, False, False, False)
            aligned_t = self.aio_handle.new_cpu_locked_tensor(tensor.numel(), tensor)
            aligned_t = aligned_t[:tensor.numel()].copy_(tensor)
            self.aligned_tensors.append([aligned_t.data_ptr(), aligned_t[-1].data_ptr()])
            return aligned_t

    def is_pinned(self, tensor):
        if tensor.is_pinned(device=self.current_device_name()):
            return True
        else:
            for begin, end in self.aligned_tensors:
                if begin <= tensor.data_ptr() and tensor.data_ptr() <= end:
                    return True
        return False

    def op_builder_dir(self):
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder.xpu"
        except ImportError:
            return "deepspeed.ops.op_builder.xpu"

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('xpu:'):
            return True
        else:
            return False

    def _lazy_init_class_dict(self):
        if self.class_dict:
            return

        op_builder_module = importlib.import_module(self.op_builder_dir())

        # get op builder class from op_builder/xpu/__init__.py
        self.class_dict = {}
        for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass):
            self.class_dict[class_name] = class_obj

    # create an instance of op builder and return, name specified by class_name
    def create_op_builder(self, class_name):
        builder_class = self.get_op_builder(class_name)
        return builder_class()

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return self.class_dict['NotImplementedBuilder']

    def build_extension(self):
        try:
            from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
        except ImportError:
            from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
        return DpcppBuildExtension

    def export_envs(self):
        return []

    def visible_devices_envs(self):
        return ['ZE_AFFINITY_MASK']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
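        # ZE_AFFINITY_MASK is the oneAPI Level Zero environment variable that limits which
        # devices a process can see (the XPU counterpart of CUDA_VISIBLE_DEVICES).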
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))

    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
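        # Only backends reported by torch._dynamo.list_backends() are accepted;
        # the default backend is 'inductor' (set in __init__).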
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend in supported_backends:
            self._compile_backend = backend
        else:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
