DeepSpeed
Mirror of https://github.com/microsoft/DeepSpeed
cuda_accelerator.py · 382 lines · 12.5 KB
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import functools
import os
import pkgutil
import importlib
import sys

from .abstract_accelerator import DeepSpeedAccelerator
# During the setup stage torch may not be installed yet; passing on a missing
# torch still allows the op builder related APIs to be executed.
try:
    import torch.cuda
except ImportError:
    pass

# Delay importing pynvml to avoid an import error when CUDA is not available.
pynvml = None


class CUDA_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'cuda'
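        # NCCL has no Windows support, so fall back to Gloo for collective
        # communication on win32.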
        self._communication_backend_name = 'nccl' if sys.platform != 'win32' else 'gloo'
        self._compile_backend = "inductor"
        if pynvml is None:
            self._init_pynvml()

    def _init_pynvml(self):
        global pynvml
        try:
            import pynvml
        except ImportError:
            return
        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            pynvml = None
            return

    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        return self.is_synchronized_device()

    def resolves_data_dependency(self):
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        return self.is_synchronized_device()

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return 'cuda'
        return 'cuda:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.cuda.device(device_index)

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def current_device(self):
        return torch.cuda.current_device()

    def current_device_name(self):
        return 'cuda:{}'.format(torch.cuda.current_device())

    def device_count(self):
        return torch.cuda.device_count()

    def synchronize(self, device_index=None):
        return torch.cuda.synchronize(device_index)

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.cuda.set_rng_state(new_state)

        return torch.cuda.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.cuda.get_rng_state()

        return torch.cuda.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.cuda.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.cuda.manual_seed_all(seed)

    def initial_seed(self):
        return torch.cuda.initial_seed()

    def default_generator(self, device_index):
        return torch.cuda.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return torch.cuda.Stream

    def stream(self, stream):
        return torch.cuda.stream(stream)

    def current_stream(self, device_index=None):
        return torch.cuda.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.cuda.default_stream(device_index)

    @property
    def Event(self):
        return torch.cuda.Event

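    # Caller-side usage sketch (not part of this class): queue kernels on a
    # side stream so they can overlap with work on the default stream, e.g.
    #   s = accelerator.Stream()
    #   with accelerator.stream(s):
    #       y = x * 2  # launched on stream `s`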
    # Memory management
    def empty_cache(self):
        return torch.cuda.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.cuda.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.cuda.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.cuda.reset_max_memory_allocated(device_index)

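    # Note: torch.cuda.memory_cached and max_memory_cached are deprecated
    # aliases of memory_reserved and max_memory_reserved in recent PyTorch
    # releases; the wrappers below are kept for backward compatibility.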
    def memory_cached(self, device_index=None):
        return torch.cuda.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.cuda.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.cuda.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'memory_stats'):
            return torch.cuda.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'reset_peak_memory_stats'):
            return torch.cuda.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'memory_reserved'):
            return torch.cuda.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'max_memory_reserved'):
            return torch.cuda.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.cuda.get_device_properties(device_index).total_memory

    def _get_nvml_gpu_id(self, torch_gpu_id):
        """
        credit: https://discuss.pytorch.org/t/making-pynvml-match-torch-device-ids-cuda-visible-devices/103020

        Remap a torch device id to an nvml device id, respecting CUDA_VISIBLE_DEVICES.

        If the latter isn't set, return the same id.
        """
        # if CUDA_VISIBLE_DEVICES is set, remap the id automatically, since pynvml ignores this env var
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
            return ids[torch_gpu_id]  # remap
        else:
            return torch_gpu_id

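    # Illustration (hypothetical values): with CUDA_VISIBLE_DEVICES="2,3",
    # torch numbers the visible GPUs 0 and 1 while pynvml still numbers the
    # physical devices, so _get_nvml_gpu_id(0) == 2 and _get_nvml_gpu_id(1) == 3.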
    def available_memory(self, device_index=None):
        if pynvml:
            if device_index is None:
                device_index = self.current_device()
            handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index))
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            return info.free
        else:
            return self.total_memory(device_index) - self.memory_allocated(device_index)

    # Data types
    def is_bf16_supported(self):
        if not torch.cuda.is_available():
            return True
        return torch.cuda.is_bf16_supported()

    def is_fp16_supported(self):
        if not torch.cuda.is_available():
            return True
        # See https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix
        # FP16 on compute capability 6.x is deprecated
        allow_deprecated_fp16 = os.environ.get('DS_ALLOW_DEPRECATED_FP16', '0') == '1'
        major, _ = torch.cuda.get_device_capability()
        if major >= 7:
            return True
        elif major == 6 and allow_deprecated_fp16:
            return True
        else:
            return False

    def supported_dtypes(self):
        supported_dtypes = [torch.float]
        if self.is_fp16_supported():
            supported_dtypes.append(torch.half)
        if self.is_bf16_supported():
            supported_dtypes.append(torch.bfloat16)
        return supported_dtypes

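    # For example, on an Ampere-class GPU (compute capability 8.x) the above
    # returns [torch.float, torch.half, torch.bfloat16].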
    # Misc
    def amp(self):
        if hasattr(torch.cuda, 'amp'):
            return torch.cuda.amp
        return None

    def is_available(self):
        return torch.cuda.is_available()

    def range_push(self, msg):
        if hasattr(torch.cuda.nvtx, 'range_push'):
            return torch.cuda.nvtx.range_push(msg)

    def range_pop(self):
        if hasattr(torch.cuda.nvtx, 'range_pop'):
            return torch.cuda.nvtx.range_pop()

    def lazy_call(self, callback):
        return torch.cuda._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            return True
        else:
            return False

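    # Note: the check above gates Triton use to compute capability >= 8.0,
    # i.e. Ampere-generation or newer GPUs.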
    # Graph operations
    def create_graph(self):
        return torch.cuda.CUDAGraph()

    def capture_to_graph(self, graph, pool=None, stream=None):
        return torch.cuda.graph(graph, pool, stream)

    def replay_graph(self, graph):
        graph.replay()
        return

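    # Caller-side usage sketch (assumed pattern, not part of this class):
    # capture a static workload once, then replay it cheaply.
    #   g = accelerator.create_graph()
    #   with accelerator.capture_to_graph(g):
    #       static_out = model(static_in)
    #   accelerator.replay_graph(g)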
    # Tensor operations

    @property
    def BFloat16Tensor(self):
        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='cuda')

    @property
    def ByteTensor(self):
        return functools.partial(torch.tensor, dtype=torch.uint8, device='cuda')

    @property
    def DoubleTensor(self):
        return functools.partial(torch.tensor, dtype=torch.double, device='cuda')

    @property
    def FloatTensor(self):
        return functools.partial(torch.tensor, dtype=torch.float, device='cuda')

    @property
    def HalfTensor(self):
        return functools.partial(torch.tensor, dtype=torch.half, device='cuda')

    @property
    def IntTensor(self):
        return functools.partial(torch.tensor, dtype=torch.int, device='cuda')

    @property
    def LongTensor(self):
        return functools.partial(torch.tensor, dtype=torch.long, device='cuda')

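    # These properties mirror the legacy torch.cuda.*Tensor constructors; e.g.
    # accelerator.FloatTensor([1.0, 2.0]) builds a float32 tensor on the
    # current CUDA device.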
    def pin_memory(self, tensor, align_bytes=1):
        return tensor.pin_memory()

    def is_pinned(self, tensor):
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('cuda:'):
            return True
        else:
            return False

    def op_builder_dir(self):
        try:
            # Is op_builder from deepspeed or a third-party version? This import
            # should only succeed if it's deepspeed's; if it does, we're also on
            # a local install rather than the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder"
        except ImportError:
            return "deepspeed.ops.op_builder"

    # dict that holds the class name <--> class type mapping, i.e.
    # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
    # this dict will be filled at init stage
    class_dict = None

    def _lazy_init_class_dict(self):
        if self.class_dict is not None:
            return
        else:
            self.class_dict = {}
            # begin initialize for create_op_builder()
            # put all valid class name <--> class type mappings into class_dict
            op_builder_dir = self.op_builder_dir()
            op_builder_module = importlib.import_module(op_builder_dir)
            op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
            for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
                # avoid self references,
                # skip subdirectories that contain ops for other backends (cpu, npu, etc.)
                if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
                        os.path.join(op_builder_absolute_path, module_name)):
                    module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                    for member_name in module.__dir__():
                        if member_name.endswith(
                                'Builder'
                        ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder":  # avoid abstract classes
                            if member_name not in self.class_dict:
                                self.class_dict[member_name] = getattr(module, member_name)
            # end initialize for create_op_builder()

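    # Example: once _lazy_init_class_dict() has run, class_dict maps names such
    # as 'AsyncIOBuilder' to their classes, so create_op_builder('AsyncIOBuilder')
    # below returns a fresh AsyncIOBuilder instance (or None for unknown names).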
    # create an instance of the op builder whose name is given by class_name
    def create_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]()
        else:
            return None

    # return the op builder class whose name is given by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return None

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

363
    def export_envs(self):
364
        return ['NCCL']
365

366
    def visible_devices_envs(self):
367
        return ['CUDA_VISIBLE_DEVICES']
368

369
    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
370
        for env in self.visible_devices_envs():
371
            current_env[env] = ",".join(map(str, local_accelerator_ids))
372

373
    def get_compile_backend(self):
374
        return self._compile_backend
375

376
    def set_compile_backend(self, backend):
377
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
378
        if backend in supported_backends:
379
            self._compile_backend = backend
380
        else:
381
            raise ValueError(
382
                f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
383

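
# A minimal usage sketch (not part of the upstream file), assuming a torch
# build with CUDA support and that this module is imported via deepspeed
# (it uses a relative import, so it cannot be run as a standalone script):
#
#   from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
#   accel = CUDA_Accelerator()
#   if accel.is_available():
#       print(accel.device_name(0))         # 'cuda:0'
#       print(accel.device_count())         # number of visible GPUs
#       print(accel.supported_dtypes())     # e.g. [torch.float, torch.half, torch.bfloat16]
#       print(accel.available_memory(0))    # free bytes (via pynvml when available)
#       print(accel.get_compile_backend())  # 'inductor' by default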