DeepSpeed
Зеркало из https://github.com/microsoft/DeepSpeed
/
setup.py
333 строки · 12.2 Кб
1# Copyright (c) Microsoft Corporation.
2# SPDX-License-Identifier: Apache-2.0
3
4# DeepSpeed Team
5"""
6DeepSpeed library
7
8To build wheel on Windows:
91. Install pytorch, such as pytorch 1.12 + cuda 11.6.
102. Install visual cpp build tool.
113. Include cuda toolkit.
124. Launch cmd console with Administrator privilege for creating required symlink folders.
13
14
15Create a new wheel via the following command:
16build_win.bat
17
18The wheel will be located at: dist/*.whl
19"""
20
21import pathlib22import os23import shutil24import sys25import subprocess26from setuptools import setup, find_packages27from setuptools.command import egg_info28import time29import typing30import shlex31
# Torch is optional at setup time: without it the pure-Python package can
# still be installed, but pre-compiling ops is disabled (see checks below).
torch_available = True
try:
    import torch
except ImportError:
    torch_available = False
    print('[WARNING] Unable to import torch, pre-compiling ops will be disabled. ' \
        'Please visit https://pytorch.org/ to see how to properly install torch on your system.')
40from op_builder import get_default_compute_capabilities, OpBuilder41from op_builder.all_ops import ALL_OPS, accelerator_name42from op_builder.builder import installed_cuda_version43
44from accelerator import get_accelerator45
# Fetch rocm state.
is_rocm_pytorch = OpBuilder.is_rocm_pytorch()
rocm_version = OpBuilder.installed_rocm_version()

# ANSI escape codes used to highlight fatal messages in red on the console.
RED_START = '\033[31m'
RED_END = '\033[0m'
ERROR = f"{RED_START} [ERROR] {RED_END}"
def abort(msg):
    """Print *msg* as a highlighted error and terminate the build.

    The previous implementation ended with ``assert False, msg``; assertions
    are stripped when Python runs with ``-O``, which would silently turn this
    into a no-op. Raising ``SystemExit`` guarantees the abort always happens.

    Args:
        msg: Human-readable description of the fatal condition.

    Raises:
        SystemExit: always, carrying *msg*.
    """
    print(f"{ERROR} {msg}")
    raise SystemExit(msg)
def fetch_requirements(path):
    """Read a pip requirements file and return its lines, whitespace-stripped.

    Args:
        path: Path to a ``requirements*.txt`` file.

    Returns:
        list[str]: One entry per line of the file; blank lines are kept as
        empty strings, matching the original behavior.
    """
    with open(path, 'r') as req_file:
        return [line.strip() for line in req_file]
def is_env_set(key):
    """Report whether environment variable *key* exists and is non-empty.

    A variable that is present but explicitly set to ``""`` counts as unset.

    Returns:
        bool: True only for a set, non-empty value.
    """
    value = os.environ.get(key)
    return value is not None and value != ""
def get_env_if_set(key, default: typing.Any = ""):
    """Fetch environment variable *key*, falling back to *default*.

    Unlike ``os.environ.get(key, default)``, a variable that is present but
    set to the empty string ``""`` also falls back to *default*.

    Args:
        key: Environment variable name.
        default: Value returned when the variable is unset or empty.
    """
    value = os.environ.get(key)
    if value:
        return value
    return default
# Base dependencies plus optional feature groups ("extras") selectable at
# install time, e.g. ``pip install deepspeed[dev]``.
install_requires = fetch_requirements('requirements/requirements.txt')
extras_require = {
    '1bit': [],  # add cupy based on cuda/rocm version
    '1bit_mpi': fetch_requirements('requirements/requirements-1bit-mpi.txt'),
    'readthedocs': fetch_requirements('requirements/requirements-readthedocs.txt'),
    'dev': fetch_requirements('requirements/requirements-dev.txt'),
    'autotuning': fetch_requirements('requirements/requirements-autotuning.txt'),
    'autotuning_ml': fetch_requirements('requirements/requirements-autotuning-ml.txt'),
    'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'),
    'sparse': fetch_requirements('requirements/requirements-sparse_pruning.txt'),
    'inf': fetch_requirements('requirements/requirements-inf.txt'),
    'sd': fetch_requirements('requirements/requirements-sd.txt'),
    'triton': fetch_requirements('requirements/requirements-triton.txt'),
}
# Only install pynvml on nvidia gpus.
if torch_available and get_accelerator().device_name() == 'cuda' and not is_rocm_pytorch:
    install_requires.append('nvidia-ml-py')

# Add specific cupy version to both onebit extension variants.
if torch_available and get_accelerator().device_name() == 'cuda':
    cupy = None
    if is_rocm_pytorch:
        rocm_major, rocm_minor = rocm_version
        # XXX cupy support for rocm 5 is not available yet.
        if rocm_major <= 4:
            cupy = f"cupy-rocm-{rocm_major}-{rocm_minor}"
    else:
        cuda_major_ver, cuda_minor_ver = installed_cuda_version()
        if (cuda_major_ver < 11) or ((cuda_major_ver == 11) and (cuda_minor_ver < 3)):
            # Older toolkits use exact-version cupy wheel names.
            cupy = f"cupy-cuda{cuda_major_ver}{cuda_minor_ver}"
        else:
            # Newer toolkits use the generic major-series wheel name.
            cupy = f"cupy-cuda{cuda_major_ver}x"

    if cupy:
        extras_require['1bit'].append(cupy)
        extras_require['1bit_mpi'].append(cupy)
# Make an [all] extra that installs all needed dependencies.
all_extras = set()
for requirement_list in extras_require.values():
    all_extras.update(requirement_list)
extras_require['all'] = list(all_extras)
cmdclass = {}

# For any pre-installed ops force disable ninja.
if torch_available:
    # Ninja is only used when DS_ENABLE_NINJA is set (and non-empty).
    use_ninja = is_env_set("DS_ENABLE_NINJA")
    cmdclass['build_ext'] = get_accelerator().build_extension().with_options(use_ninja=use_ninja)

# Torch major/minor kept as strings; "0"/"0" is the sentinel when torch is
# not importable.
if torch_available:
    TORCH_MAJOR = torch.__version__.split('.')[0]
    TORCH_MINOR = torch.__version__.split('.')[1]
else:
    TORCH_MAJOR = "0"
    TORCH_MINOR = "0"

if torch_available and not get_accelerator().device_name() == 'cuda':
    # Fix to allow docker builds, similar to https://github.com/NVIDIA/apex/issues/486.
    print("[WARNING] Torch did not find cuda available, if cross-compiling or running with cpu only "
          "you can ignore this message. Adding compute capability for Pascal, Volta, and Turing "
          "(compute capabilities 6.0, 6.1, 6.2)")
    # Respect a user-provided arch list; only fill in defaults when unset.
    if not is_env_set("TORCH_CUDA_ARCH_LIST"):
        os.environ["TORCH_CUDA_ARCH_LIST"] = get_default_compute_capabilities()
# C/C++/CUDA extensions to pre-compile; populated by the op loop below.
ext_modules = []

# Default to pre-install kernels to false so we rely on JIT on Linux, opposite on Windows.
BUILD_OP_PLATFORM = 1 if sys.platform == "win32" else 0
BUILD_OP_DEFAULT = int(get_env_if_set('DS_BUILD_OPS', BUILD_OP_PLATFORM))
print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}")

if BUILD_OP_DEFAULT:
    assert torch_available, "Unable to pre-compile ops without torch installed. Please install torch before attempting to pre-compile ops."
def command_exists(cmd):
    """Return True when *cmd* resolves to an executable command.

    On POSIX this asks bash's ``type`` builtin. The previous code built the
    command with ``shlex.split(f"bash -c type {cmd}")``, which produces
    ``['bash', '-c', 'type', cmd]`` — bash then runs ``type`` with no operand
    (the command name only became ``$0``), so the check succeeded for ANY
    input. The entire command string must be a single ``-c`` argument.

    Args:
        cmd: Name of the command to look up (e.g. ``"git"``).

    Returns:
        bool: Whether the command exists. The Windows branch keeps the
        original convention of treating exit status 1 as "exists".
    """
    if sys.platform == "win32":
        # NOTE(review): exit code 1 treated as success here is inherited from
        # the original code — preserved as-is. A missing command on Windows
        # raises instead of returning an exit code, so report False for that.
        try:
            result = subprocess.Popen(shlex.split(f'{cmd}'), stdout=subprocess.PIPE)
        except FileNotFoundError:
            return False
        return result.wait() == 1
    else:
        # Pass the lookup as one -c argument so `type` actually sees cmd.
        result = subprocess.Popen(["bash", "-c", f"type {cmd}"], stdout=subprocess.PIPE)
        return result.wait() == 0
def op_envvar(op_name):
    """Return the DS_BUILD_* environment-variable name controlling *op_name*."""
    builder = ALL_OPS[op_name]
    assert hasattr(builder, 'BUILD_VAR'), \
        f"{op_name} is missing BUILD_VAR field"
    return builder.BUILD_VAR
def op_enabled(op_name):
    """Return the effective build flag (0/1) for *op_name*.

    Reads the op's DS_BUILD_* variable, falling back to BUILD_OP_DEFAULT
    when it is unset or empty.
    """
    return int(get_env_if_set(op_envvar(op_name), BUILD_OP_DEFAULT))
# Decide per-op whether to pre-compile; the resulting mapping is written to
# deepspeed/git_version_info_installed.py further below.
install_ops = dict.fromkeys(ALL_OPS.keys(), False)
for op_name, builder in ALL_OPS.items():
    op_compatible = builder.is_compatible()

    # If op is requested but not compatible, warn (unless the user explicitly
    # set its build variable) and skip pre-compiling it.
    if op_enabled(op_name) and not op_compatible:
        env_var = op_envvar(op_name)
        if not is_env_set(env_var):
            builder.warning(f"Skip pre-compile of incompatible {op_name}; One can disable {op_name} with {env_var}=0")
        continue

    # If op is compatible but install is not enabled (JIT mode).
    if is_rocm_pytorch and op_compatible and not op_enabled(op_name):
        builder.hipify_extension()

    # If op install enabled, add builder to extensions.
    if op_enabled(op_name) and op_compatible:
        assert torch_available, f"Unable to pre-compile {op_name}, please first install torch"
        install_ops[op_name] = op_enabled(op_name)
        ext_modules.append(builder.builder())

print(f'Install Ops={install_ops}')
# Write out version/git info.
#
# Bug fix: the previous code used shlex.split("bash -c git rev-parse ...")
# which yields ['bash', '-c', 'git', 'rev-parse', ...]; bash then executes the
# bare command 'git' (with the remaining words as positional parameters),
# which prints usage and exits non-zero — so git_hash/git_branch were ALWAYS
# "unknown". The whole git command must be a single -c argument.
git_hash_cmd = ["bash", "-c", "git rev-parse --short HEAD"]
git_branch_cmd = ["bash", "-c", "git rev-parse --abbrev-ref HEAD"]
if command_exists('git') and not is_env_set('DS_BUILD_STRING'):
    try:
        result = subprocess.check_output(git_hash_cmd)
        git_hash = result.decode('utf-8').strip()
        result = subprocess.check_output(git_branch_cmd)
        git_branch = result.decode('utf-8').strip()
    except subprocess.CalledProcessError:
        # Not inside a git checkout (e.g. building from an sdist).
        git_hash = "unknown"
        git_branch = "unknown"
else:
    git_hash = "unknown"
    git_branch = "unknown"
if sys.platform == "win32":
    # On Windows the csrc/op_builder/accelerator trees are copied into the
    # package directory (the module docstring notes symlinks need admin
    # privileges there). Remove any stale copy or leftover symlink first —
    # rmtree handles a directory, unlink handles a symlink/file.
    shutil.rmtree('.\\deepspeed\\ops\\csrc', ignore_errors=True)
    pathlib.Path('.\\deepspeed\\ops\\csrc').unlink(missing_ok=True)
    shutil.copytree('.\\csrc', '.\\deepspeed\\ops\\csrc', dirs_exist_ok=True)
    shutil.rmtree('.\\deepspeed\\ops\\op_builder', ignore_errors=True)
    pathlib.Path('.\\deepspeed\\ops\\op_builder').unlink(missing_ok=True)
    shutil.copytree('.\\op_builder', '.\\deepspeed\\ops\\op_builder', dirs_exist_ok=True)
    shutil.rmtree('.\\deepspeed\\accelerator', ignore_errors=True)
    pathlib.Path('.\\deepspeed\\accelerator').unlink(missing_ok=True)
    shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator', dirs_exist_ok=True)
    # Windows builds use a dedicated manifest template.
    egg_info.manifest_maker.template = 'MANIFEST_win.in'
# Parse the DeepSpeed version string from version.txt.
# The previous `open(...).read()` left the file handle to be closed whenever
# the garbage collector got to it; a context manager releases it right away.
with open('version.txt', 'r') as version_file:
    version_str = version_file.read().strip()
# Build specifiers like .devX can be added at install time. Otherwise, add the git hash.
# Example: DS_BUILD_STRING=".dev20201022" python setup.py sdist bdist_wheel.

# Building wheel for distribution, update version file.
if is_env_set('DS_BUILD_STRING'):
    # Build string env specified, probably building for distribution.
    with open('build.txt', 'w') as fd:
        fd.write(os.environ['DS_BUILD_STRING'])
    version_str += os.environ['DS_BUILD_STRING']
elif os.path.isfile('build.txt'):
    # build.txt exists, probably installing from distribution.
    with open('build.txt', 'r') as fd:
        version_str += fd.read().strip()
else:
    # None of the above, probably installing from source.
    version_str += f'+{git_hash}'

torch_version = ".".join([TORCH_MAJOR, TORCH_MINOR])
bf16_support = False
# Set cuda_version to 0.0 if cpu-only.
cuda_version = "0.0"
nccl_version = "0.0"
# Set hip_version to 0.0 if cpu-only.
hip_version = "0.0"
if torch_available and torch.version.cuda is not None:
    cuda_version = ".".join(torch.version.cuda.split('.')[:2])
    if sys.platform != "win32":
        # torch.cuda.nccl.version() may be an int (e.g. 2708) or a tuple.
        if isinstance(torch.cuda.nccl.version(), int):
            # This will break if minor version > 9.
            nccl_version = ".".join(str(torch.cuda.nccl.version())[:2])
        else:
            nccl_version = ".".join(map(str, torch.cuda.nccl.version()[:2]))
    if hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_available():
        bf16_support = torch.cuda.is_bf16_supported()
if torch_available and hasattr(torch.version, 'hip') and torch.version.hip is not None:
    hip_version = ".".join(torch.version.hip.split('.')[:2])
# Snapshot of the torch environment the wheel was built against.
torch_info = {
    "version": torch_version,
    "bf16_support": bf16_support,
    "cuda_version": cuda_version,
    "nccl_version": nccl_version,
    "hip_version": hip_version
}

print(f"version={version_str}, git_hash={git_hash}, git_branch={git_branch}")
# Record build-time metadata so the installed deepspeed package can report
# how (and from what) it was built.
with open('deepspeed/git_version_info_installed.py', 'w') as fd:
    fd.write(f"version='{version_str}'\n")
    fd.write(f"git_hash='{git_hash}'\n")
    fd.write(f"git_branch='{git_branch}'\n")
    fd.write(f"installed_ops={install_ops}\n")
    fd.write(f"accelerator_name='{accelerator_name}'\n")
    fd.write(f"torch_info={torch_info}\n")
print(f'install_requires={install_requires}')
print(f'ext_modules={ext_modules}')

# Parse README.md to make long_description for PyPI page.
thisdir = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(thisdir, 'README.md'), encoding='utf-8') as fin:
    readme_text = fin.read()

# Console entry scripts; Windows ships .bat wrappers alongside a reduced set.
if sys.platform == "win32":
    scripts = ['bin/deepspeed.bat', 'bin/ds', 'bin/ds_report.bat', 'bin/ds_report']
else:
    scripts = [
        'bin/deepspeed', 'bin/deepspeed.pt', 'bin/ds', 'bin/ds_ssh', 'bin/ds_report', 'bin/ds_bench', 'bin/dsr',
        'bin/ds_elastic', 'bin/ds_nvme_tune', 'bin/ds_io'
    ]
start_time = time.time()

# Standard setuptools entry point; the interesting inputs (version_str,
# ext_modules, cmdclass, extras_require) are assembled above.
setup(name='deepspeed',
      version=version_str,
      description='DeepSpeed library',
      long_description=readme_text,
      long_description_content_type='text/markdown',
      author='DeepSpeed Team',
      author_email='deepspeed-info@microsoft.com',
      url='http://deepspeed.ai',
      project_urls={
          'Documentation': 'https://deepspeed.readthedocs.io',
          'Source': 'https://github.com/microsoft/DeepSpeed',
      },
      install_requires=install_requires,
      extras_require=extras_require,
      packages=find_packages(include=['deepspeed', 'deepspeed.*']),
      include_package_data=True,
      scripts=scripts,
      classifiers=[
          'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7',
          'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9',
          'Programming Language :: Python :: 3.10'
      ],
      license='Apache Software License 2.0',
      ext_modules=ext_modules,
      cmdclass=cmdclass)

# Report wall-clock build time (includes any op pre-compilation).
end_time = time.time()
print(f'deepspeed build time = {end_time - start_time} secs')