DeepSpeed
Mirror of https://github.com/microsoft/DeepSpeed
382 lines · 12.5 KB
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import functools
import os
import pkgutil
import importlib
import sys

from .abstract_accelerator import DeepSpeedAccelerator

# During setup stage torch may not be installed; passing on a missing torch
# allows the op builder related APIs to still be executed.
try:
    import torch.cuda
except ImportError:
    pass

# Delay importing pynvml to avoid an import error when CUDA is not available.
pynvml = None


class CUDA_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'cuda'
        self._communication_backend_name = 'nccl' if sys.platform != 'win32' else 'gloo'
        self._compile_backend = "inductor"
        if pynvml is None:
            self._init_pynvml()

    def _init_pynvml(self):
        global pynvml
        try:
            import pynvml
        except ImportError:
            return
        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            pynvml = None
            return
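    # Note: pynvml is imported lazily through the module-level global above so
    # that importing this accelerator never fails on machines without NVML; if
    # nvmlInit() raises, the global is reset to None and available_memory()
    # below falls back to a torch-only estimate.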
    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        return self.is_synchronized_device()

    def resolves_data_dependency(self):
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        return self.is_synchronized_device()

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return 'cuda'
        return 'cuda:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.cuda.device(device_index)

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def current_device(self):
        return torch.cuda.current_device()

    def current_device_name(self):
        return 'cuda:{}'.format(torch.cuda.current_device())

    def device_count(self):
        return torch.cuda.device_count()

    def synchronize(self, device_index=None):
        return torch.cuda.synchronize(device_index)
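    # Illustrative usage sketch (not part of the upstream file), assuming at
    # least one visible GPU:
    #   acc = CUDA_Accelerator()
    #   acc.set_device(0)
    #   assert acc.current_device_name() == 'cuda:0'
    #   acc.synchronize()  # block until all queued kernels on device 0 finish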
    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.cuda.set_rng_state(new_state)

        return torch.cuda.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.cuda.get_rng_state()

        return torch.cuda.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.cuda.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.cuda.manual_seed_all(seed)

    def initial_seed(self):
        return torch.cuda.initial_seed()

    def default_generator(self, device_index):
        return torch.cuda.default_generators[device_index]
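    # Illustrative sketch: the get/set pair above supports checkpoint-and-replay
    # of per-device randomness (e.g. for activation recomputation):
    #   state = acc.get_rng_state(0)
    #   y1 = torch.rand(4, device='cuda:0')
    #   acc.set_rng_state(state, 0)
    #   y2 = torch.rand(4, device='cuda:0')  # identical to y1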
    # Streams/Events
    @property
    def Stream(self):
        return torch.cuda.Stream

    def stream(self, stream):
        return torch.cuda.stream(stream)

    def current_stream(self, device_index=None):
        return torch.cuda.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.cuda.default_stream(device_index)

    @property
    def Event(self):
        return torch.cuda.Event
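    # Illustrative sketch: launching work on a side stream via these wrappers:
    #   s = acc.Stream()
    #   with acc.stream(s):
    #       ...  # kernels launched here are enqueued on s
    #   acc.current_stream().wait_stream(s)  # torch.cuda.Stream API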
    # Memory management
    def empty_cache(self):
        return torch.cuda.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.cuda.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.cuda.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.cuda.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.cuda.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.cuda.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.cuda.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'memory_stats'):
            return torch.cuda.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'reset_peak_memory_stats'):
            return torch.cuda.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'memory_reserved'):
            return torch.cuda.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'max_memory_reserved'):
            return torch.cuda.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.cuda.get_device_properties(device_index).total_memory
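    # Illustrative sketch: a coarse headroom estimate from the calls above
    # (reserved includes the caching allocator's pool, unlike allocated):
    #   headroom = acc.total_memory(0) - acc.memory_reserved(0)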
    def _get_nvml_gpu_id(self, torch_gpu_id):
        """
        credit: https://discuss.pytorch.org/t/making-pynvml-match-torch-device-ids-cuda-visible-devices/103020

        Remap torch device id to nvml device id, respecting CUDA_VISIBLE_DEVICES.

        If the latter isn't set, return the same id.
        """
        # if CUDA_VISIBLE_DEVICES is used, automagically remap the id since pynvml ignores this env var
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
            return ids[torch_gpu_id]  # remap
        else:
            return torch_gpu_id
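    # Example of the remap (illustrative): with CUDA_VISIBLE_DEVICES="2,3",
    # torch enumerates devices 0 and 1 while pynvml still uses physical ids, so
    # _get_nvml_gpu_id(0) -> 2 and _get_nvml_gpu_id(1) -> 3.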
    def available_memory(self, device_index=None):
        if pynvml:
            if device_index is None:
                device_index = self.current_device()
            handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index))
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            return info.free
        else:
            return self.total_memory(device_index) - self.memory_allocated(device_index)
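    # Note: the torch-only fallback above overestimates free memory, since it
    # ignores the caching allocator's reserved-but-unallocated pool and memory
    # held by other processes; the NVML path reports device-wide free bytes.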
    # Data types
    def is_bf16_supported(self):
        if not torch.cuda.is_available():
            return True
        return torch.cuda.is_bf16_supported()

    def is_fp16_supported(self):
        if not torch.cuda.is_available():
            return True
        # See https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix
        # FP16 on compute capability 6.x is deprecated
        allow_deprecated_fp16 = os.environ.get('DS_ALLOW_DEPRECATED_FP16', '0') == '1'
        major, _ = torch.cuda.get_device_capability()
        if major >= 7:
            return True
        elif major == 6 and allow_deprecated_fp16:
            return True
        else:
            return False

    def supported_dtypes(self):
        supported_dtypes = [torch.float]
        if self.is_fp16_supported():
            supported_dtypes.append(torch.half)
        if self.is_bf16_supported():
            supported_dtypes.append(torch.bfloat16)
        return supported_dtypes
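    # Illustrative reading of the checks above: capability >= 7.0 (e.g. V100)
    # enables fp16; capability 6.x (e.g. P100) only enables fp16 when
    # DS_ALLOW_DEPRECATED_FP16=1 is set; bf16 support is delegated entirely to
    # torch.cuda.is_bf16_supported().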
    # Misc
    def amp(self):
        if hasattr(torch.cuda, 'amp'):
            return torch.cuda.amp
        return None

    def is_available(self):
        return torch.cuda.is_available()

    def range_push(self, msg):
        if hasattr(torch.cuda.nvtx, 'range_push'):
            return torch.cuda.nvtx.range_push(msg)

    def range_pop(self):
        if hasattr(torch.cuda.nvtx, 'range_pop'):
            return torch.cuda.nvtx.range_pop()

    def lazy_call(self, callback):
        return torch.cuda._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            return True
        else:
            return False
    # Graph operations
    def create_graph(self):
        return torch.cuda.CUDAGraph()

    def capture_to_graph(self, graph, pool=None, stream=None):
        return torch.cuda.graph(graph, pool, stream)

    def replay_graph(self, graph):
        graph.replay()
        return
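    # Illustrative capture/replay sketch using the three methods above:
    #   g = acc.create_graph()
    #   with acc.capture_to_graph(g):
    #       out = model(static_input)  # captured, not executed eagerly
    #   acc.replay_graph(g)            # re-launches the captured kernels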
    # Tensor operations

    @property
    def BFloat16Tensor(self):
        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='cuda')

    @property
    def ByteTensor(self):
        return functools.partial(torch.tensor, dtype=torch.uint8, device='cuda')

    @property
    def DoubleTensor(self):
        return functools.partial(torch.tensor, dtype=torch.double, device='cuda')

    @property
    def FloatTensor(self):
        return functools.partial(torch.tensor, dtype=torch.float, device='cuda')

    @property
    def HalfTensor(self):
        return functools.partial(torch.tensor, dtype=torch.half, device='cuda')

    @property
    def IntTensor(self):
        return functools.partial(torch.tensor, dtype=torch.int, device='cuda')

    @property
    def LongTensor(self):
        return functools.partial(torch.tensor, dtype=torch.long, device='cuda')
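    # Note: each property above returns a factory (a functools.partial over
    # torch.tensor), not a tensor class, e.g.:
    #   t = acc.FloatTensor([1.0, 2.0])  # float32 tensor on the current CUDA device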
    def pin_memory(self, tensor, align_bytes=1):
        return tensor.pin_memory()

    def is_pinned(self, tensor):
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('cuda:'):
            return True
        else:
            return False
    def op_builder_dir(self):
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder"
        except ImportError:
            return "deepspeed.ops.op_builder"

    # dict that holds class name <--> class type mapping i.e.
    # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
    # this dict will be filled at init stage
    class_dict = None

    def _lazy_init_class_dict(self):
        if self.class_dict is not None:
            return
        else:
            self.class_dict = {}
        # begin initialize for create_op_builder()
        # put all valid class name <--> class type mappings into class_dict
        op_builder_dir = self.op_builder_dir()
        op_builder_module = importlib.import_module(op_builder_dir)
        op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
        for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
            # avoid self references,
            # skip subdirectories which contain ops for other backends (cpu, npu, etc.).
            if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
                    os.path.join(op_builder_absolute_path, module_name)):
                module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                for member_name in module.__dir__():
                    if member_name.endswith(
                            'Builder'
                    ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder":  # avoid abstract classes
                        if member_name not in self.class_dict:
                            self.class_dict[member_name] = getattr(module, member_name)
        # end initialize for create_op_builder()
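    # In short: the loop above scans every module in the op_builder package,
    # imports it, and registers each concrete *Builder class by name, so new
    # builders are discovered without a hand-maintained registry.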
    # create an instance of op builder and return, name specified by class_name
    def create_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]()
        else:
            return None

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return None
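    # Illustrative usage: create_op_builder returns an instance while
    # get_op_builder returns the class itself:
    #   builder = acc.create_op_builder('AsyncIOBuilder')  # instance or None
    #   cls = acc.get_op_builder('AsyncIOBuilder')         # class or None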
    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        return ['NCCL']

    def visible_devices_envs(self):
        return ['CUDA_VISIBLE_DEVICES']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))

    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend in supported_backends:
            self._compile_backend = backend
        else:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
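

# --- Illustrative smoke test (not part of the upstream file) ---
# A minimal sketch, assuming torch with CUDA support is installed; every call
# below is defined in the class above.
if __name__ == "__main__":
    acc = CUDA_Accelerator()
    if acc.is_available():
        print(acc.device_name(0))                # 'cuda:0'
        print(acc.device_count())                # number of visible GPUs
        print(acc.supported_dtypes())            # e.g. [torch.float32, torch.float16, torch.bfloat16]
        print(acc.communication_backend_name())  # 'nccl' (or 'gloo' on Windows)
        print(acc.available_memory(0))           # free bytes (NVML) or a torch-only estimate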