DeepSpeed

Mirror of https://github.com/microsoft/DeepSpeed

mlu_accelerator.py · 300 lines · 9.1 KB
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import importlib
import inspect
import functools

from .abstract_accelerator import DeepSpeedAccelerator
import torch
# During the setup stage torch may not be installed; passing on a missing torch will
# still allow the op-builder-related APIs to be executed.


class MLU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'mlu'
        self._communication_backend_name = 'cncl'
        self._compile_backend = "inductor"
        self.class_dict = None

    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        return self.is_synchronized_device()

    def resolves_data_dependency(self):
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        return self.is_synchronized_device()

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return 'mlu'
        return 'mlu:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.mlu.device(device_index)

    def set_device(self, device_index):
        torch.mlu.set_device(device_index)

    def current_device(self):
        return torch.mlu.current_device()

    def current_device_name(self):
        return 'mlu:{}'.format(torch.mlu.current_device())

    def device_count(self):
        return torch.mlu.device_count()

    def synchronize(self, device_index=None):
        return torch.mlu.synchronize(device_index)

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.mlu.set_rng_state(new_state)

        return torch.mlu.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.mlu.get_rng_state()

        return torch.mlu.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.mlu.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.mlu.manual_seed_all(seed)

    def initial_seed(self, seed):
        return torch.mlu.initial_seed(seed)

    def default_generator(self, device_index):
        return torch.mlu.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return torch.mlu.Stream

    def stream(self, stream):
        return torch.mlu.stream(stream)

    def current_stream(self, device_index=None):
        return torch.mlu.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.mlu.default_stream(device_index)

    @property
    def Event(self):
        return torch.mlu.Event

    # Memory management
    def empty_cache(self):
        return torch.mlu.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.mlu.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.mlu.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.mlu.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.mlu.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.mlu.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.mlu.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.mlu, 'memory_stats'):
            return torch.mlu.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.mlu, 'reset_peak_memory_stats'):
            return torch.mlu.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.mlu, 'memory_reserved'):
            return torch.mlu.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.mlu, 'max_memory_reserved'):
            return torch.mlu.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.mlu.get_device_properties(device_index).total_memory

    def available_memory(self, device_index=None):
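        # Estimate: total device memory minus currently allocated memory; memory reserved by the
        # caching allocator or consumed by other processes is not accounted for.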
        return self.total_memory(device_index) - self.memory_allocated(device_index)

    # Data types
    def is_bf16_supported(self):
        return torch.mlu.is_bf16_supported()

    def is_fp16_supported(self):
        return True

    def supported_dtypes(self):
        supported_dtypes = [torch.float]
        if self.is_fp16_supported():
            supported_dtypes.append(torch.half)
        if self.is_bf16_supported():
            supported_dtypes.append(torch.bfloat16)
        return supported_dtypes

    # Misc
    def amp(self):
        if hasattr(torch.mlu, 'amp'):
            return torch.mlu.amp
        return None

    def is_available(self):
        return torch.mlu.is_available()

    def range_push(self, msg):
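        # cnpx is assumed to be Cambricon's range-annotation/profiling API (the MLU analogue of
        # torch.cuda.nvtx); the hasattr guard keeps this a no-op on builds without it.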
        if hasattr(torch.mlu.cnpx, 'range_push'):
            return torch.mlu.cnpx.range_push(msg)

    def range_pop(self):
        if hasattr(torch.mlu.cnpx, 'range_pop'):
            return torch.mlu.cnpx.range_pop()

    def lazy_call(self, callback):
        return torch.mlu._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return True

    # Graph operations
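    # torch.mlu.MLUGraph and torch.mlu.graph are assumed to mirror torch.cuda.CUDAGraph and
    # torch.cuda.graph for capturing and replaying work submitted to an MLU stream.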
    def create_graph(self):
        return torch.mlu.MLUGraph()

    def capture_to_graph(self, graph, pool=None, stream=None):
        return torch.mlu.graph(graph, pool, stream)

    def replay_graph(self, graph):
        graph.replay()
        return

    # Tensor operations

    @property
    def BFloat16Tensor(self):
        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='mlu')

    @property
    def ByteTensor(self):
        return functools.partial(torch.tensor, dtype=torch.uint8, device='mlu')

    @property
    def DoubleTensor(self):
        return functools.partial(torch.tensor, dtype=torch.double, device='mlu')

    @property
    def FloatTensor(self):
        return functools.partial(torch.tensor, dtype=torch.float, device='mlu')

    @property
    def HalfTensor(self):
        return functools.partial(torch.tensor, dtype=torch.half, device='mlu')

    @property
    def IntTensor(self):
        return functools.partial(torch.tensor, dtype=torch.int, device='mlu')

    @property
    def LongTensor(self):
        return functools.partial(torch.tensor, dtype=torch.long, device='mlu')

    def pin_memory(self, tensor):
        return tensor.pin_memory()

    def is_pinned(self, tensor):
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('mlu:'):
            return True
        else:
            return False

    def op_builder_dir(self):
        try:
            # Is op_builder from DeepSpeed or a third-party version? This import should only
            # succeed for DeepSpeed's own op_builder; if it does, we are in a local install
            # rather than the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder.mlu"
        except ImportError:
            return "deepspeed.ops.op_builder.mlu"

    def _lazy_init_class_dict(self):
        if self.class_dict:
            return

        op_builder_module = importlib.import_module(self.op_builder_dir())

        # Collect the op builder classes exposed by op_builder/mlu/__init__.py
        self.class_dict = {}
        for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass):
            self.class_dict[class_name] = class_obj

    # Create and return an instance of the op builder named by class_name
    def create_op_builder(self, class_name):
        builder_class = self.get_op_builder(class_name)
        return builder_class()

    # Return the op builder class named by class_name, or NotImplementedBuilder if unknown
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return self.class_dict['NotImplementedBuilder']

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
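        # Assumption: these name the Cambricon-related environment variables the DeepSpeed
        # launcher should forward to workers (NEUWARE_HOME: toolkit root, CNCL: communication
        # library settings).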
        return ['NEUWARE_HOME', 'CNCL', 'LD_LIBRARY', 'PATH']

    def visible_devices_envs(self):
        return ['MLU_VISIBLE_DEVICES']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))

    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend in supported_backends:
            self._compile_backend = backend
        else:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported backends are {supported_backends}")
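

if __name__ == "__main__":
    # Illustrative sketch only (assumption: a PyTorch build with Cambricon MLU support, i.e. the
    # torch_mlu extension providing torch.mlu, is installed). In DeepSpeed this class is normally
    # reached through deepspeed.accelerator.get_accelerator() rather than instantiated directly.
    accel = MLU_Accelerator()
    if accel.is_available():
        accel.set_device(0)
        print("device:", accel.current_device_name())
        print("device count:", accel.device_count())
        print("supported dtypes:", accel.supported_dtypes())
        # The tensor factory helpers place data on the current MLU device.
        x = accel.FloatTensor([1.0, 2.0, 3.0])
        print("on accelerator:", accel.on_accelerator(x))
    else:
        print("no MLU available; default device name:", accel.device_name())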
