# deepspeed (74 lines, 2.7 KB)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import CUDAOpBuilder, installed_cuda_version
8
class InferenceBuilder(CUDAOpBuilder):
    """Op builder for the DeepSpeed transformer-inference CUDA kernels.

    Compiles the fused inference ops (gelu/relu, layer/rms norm, softmax,
    dequantize, rotary embedding, transform, pointwise) into a single
    extension module.
    """

    # Environment variable that force-enables/disables building this op.
    BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE"
    NAME = "transformer_inference"

    def __init__(self, name=None):
        # Default to the canonical op name unless a caller overrides it.
        super().__init__(name=self.NAME if name is None else name)

    def absolute_name(self):
        """Return the fully-qualified import path of the compiled op."""
        return f'deepspeed.ops.transformer.inference.{self.NAME}_op'

    def is_compatible(self, verbose=True):
        """Check torch availability and CUDA capability/version constraints.

        Returns False (with a warning) when torch is missing, when the GPU
        is older than Pascal (compute capability < 6), or when an Ampere+
        GPU (capability >= 8) is paired with a pre-11 CUDA toolkit.
        """
        try:
            import torch
        except ImportError:
            self.warning("Please install torch if trying to pre-compile inference kernels")
            return False

        cuda_okay = True
        # Capability gating only applies to NVIDIA builds with a visible GPU;
        # ROCm and CPU-only environments skip these checks entirely.
        if not self.is_rocm_pytorch() and torch.cuda.is_available():
            sys_cuda_major, _ = installed_cuda_version()
            torch_cuda_major = int(torch.version.cuda.split('.')[0])
            cuda_capability = torch.cuda.get_device_properties(0).major
            if cuda_capability < 6:
                self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
                cuda_okay = False
            if cuda_capability >= 8:
                # Ampere+ kernels require CUDA 11 from both torch and the system.
                if torch_cuda_major < 11 or sys_cuda_major < 11:
                    self.warning("On Ampere and higher architectures please use CUDA 11+")
                    cuda_okay = False
        return super().is_compatible(verbose) and cuda_okay

    def filter_ccs(self, ccs):
        """Drop compute capabilities below 6 (pre-Pascal), warning about each drop."""
        kept, dropped = [], []
        for capability in ccs:
            # Leading digit is the major compute capability.
            (kept if int(capability[0]) >= 6 else dropped).append(capability)
        if dropped:
            self.warning(f"Filtered compute capabilities {dropped}")
        return kept

    def sources(self):
        """List the C++/CUDA translation units that make up the op."""
        src_root = 'csrc/transformer/inference/csrc'
        kernel_files = [
            'pt_binding.cpp',
            'gelu.cu',
            'relu.cu',
            'layer_norm.cu',
            'rms_norm.cu',
            'softmax.cu',
            'dequantize.cu',
            'apply_rotary_pos_emb.cu',
            'transform.cu',
            'pointwise_ops.cu',
        ]
        return [f'{src_root}/{name}' for name in kernel_files]

    def extra_ldflags(self):
        # curand is linked only for CUDA builds; ROCm needs no extra libraries here.
        return [] if self.is_rocm_pytorch() else ['-lcurand']

    def include_paths(self):
        """Header search paths for the op's sources."""
        return ['csrc/transformer/inference/includes', 'csrc/includes']