# TransformerEngine / setup.py — 617 lines · 19.1 KB
1# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# See LICENSE for license information.
4
5"""Installation script."""
6
7import ctypes
8from functools import lru_cache
9import os
10from pathlib import Path
11import re
12import shutil
13import subprocess
14from subprocess import CalledProcessError
15import sys
16import sysconfig
17from typing import List, Optional, Tuple, Union
18
19import setuptools
20from setuptools.command.build_ext import build_ext
21
22from te_version import te_version
23
24# Project directory root
25root_path: Path = Path(__file__).resolve().parent
26
@lru_cache(maxsize=1)
def with_debug_build() -> bool:
    """Whether to build with a debug configuration.

    Returns True if ``--debug`` was passed on the command line (the
    flag is consumed, i.e. removed from ``sys.argv``) or if the
    ``NVTE_BUILD_DEBUG`` environment variable is a non-zero integer.
    """
    # Membership test + single remove instead of mutating sys.argv
    # while iterating over it.
    if "--debug" in sys.argv:
        sys.argv.remove("--debug")
        return True
    return int(os.getenv("NVTE_BUILD_DEBUG", "0")) != 0

# Call once in global scope since this function manipulates the
# command-line arguments. Future calls will use a cached value.
with_debug_build()
41
def found_cmake() -> bool:
    """Check if a valid CMake is available

    CMake 3.18 or newer is required.

    """

    # Check if CMake is available
    try:
        _cmake_bin = cmake_bin()
    except FileNotFoundError:
        return False

    # Query CMake for version info. The executable may exist but fail
    # to run (broken install), so treat execution errors as "not found"
    # rather than crashing the build script.
    try:
        output = subprocess.run(
            [_cmake_bin, "--version"],
            capture_output=True,
            check=True,
            universal_newlines=True,
        )
    except (CalledProcessError, OSError):
        return False

    # Parse "cmake version X.Y.Z" from the output
    match = re.search(r"version\s*([\d.]+)", output.stdout)
    if match is None:
        return False
    version = tuple(int(v) for v in match.group(1).split("."))
    return version >= (3, 18)
66
def cmake_bin() -> Path:
    """Get CMake executable

    Throws FileNotFoundError if not found.

    """
    candidate: Optional[Path] = None

    # Prefer the executable bundled with the CMake Python package
    try:
        import cmake
    except ImportError:
        pass
    else:
        bundled = Path(cmake.__file__).resolve().parent / "data" / "bin" / "cmake"
        if bundled.is_file():
            candidate = bundled

    # Fall back to searching the system path
    if candidate is None:
        from_path = shutil.which("cmake")
        if from_path is not None:
            candidate = Path(from_path).resolve()

    # Return executable if found
    if candidate is None:
        raise FileNotFoundError("Could not find CMake executable")
    return candidate
96
def found_ninja() -> bool:
    """Check if Ninja is available"""
    return shutil.which("ninja") is not None
100
def found_pybind11() -> bool:
    """Check if pybind11 is available"""

    # Check if Python package is installed
    try:
        import pybind11
    except ImportError:
        pass
    else:
        return True

    # Check if CMake can find pybind11
    if not found_cmake():
        return False
    try:
        # Use the resolved CMake executable rather than a bare "cmake":
        # CMake may have been found via the `cmake` Python package and
        # not be on PATH at all.
        subprocess.run(
            [
                str(cmake_bin()),
                "--find-package",
                "-DMODE=EXIST",
                "-DNAME=pybind11",
                "-DCOMPILER_ID=CXX",
                "-DLANGUAGE=CXX",
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except (CalledProcessError, OSError):
        return False
    return True
134
def cuda_version() -> Tuple[int, ...]:
    """CUDA Toolkit version as a (major, minor) tuple

    Throws FileNotFoundError if NVCC is not found.

    """

    # Try finding NVCC, preferring CUDA_HOME, then PATH, then the
    # conventional /usr/local/cuda location. Unlike a strict CUDA_HOME
    # check, a candidate that does not exist on disk falls through to
    # the next one instead of failing outright.
    nvcc_bin: Optional[Path] = None
    cuda_home = os.getenv("CUDA_HOME")
    if cuda_home:
        candidate = Path(cuda_home) / "bin" / "nvcc"
        if candidate.is_file():
            nvcc_bin = candidate
    if nvcc_bin is None:
        from_path = shutil.which("nvcc")
        if from_path is not None:
            nvcc_bin = Path(from_path)
    if nvcc_bin is None:
        # Last-ditch guess in /usr/local/cuda
        nvcc_bin = Path("/usr/local/cuda") / "bin" / "nvcc"
        if not nvcc_bin.is_file():
            raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")

    # Query NVCC for version info
    output = subprocess.run(
        [nvcc_bin, "-V"],
        capture_output=True,
        check=True,
        universal_newlines=True,
    )
    # Parse "release X.Y" from the output; fail loudly instead of
    # raising AttributeError on an unexpected format.
    match = re.search(r"release\s*([\d.]+)", output.stdout)
    if match is None:
        raise RuntimeError(f"Could not parse NVCC version from: {output.stdout}")
    return tuple(int(v) for v in match.group(1).split("."))
170
@lru_cache(maxsize=1)
def with_userbuffers() -> bool:
    """Check if userbuffers support is enabled"""
    enabled = bool(int(os.getenv("NVTE_WITH_USERBUFFERS", "0")))
    if enabled:
        # Userbuffers needs MPI headers/libs at build time
        assert os.getenv("MPI_HOME"), \
            "MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1"
    return enabled
179
@lru_cache(maxsize=1)
def frameworks() -> List[str]:
    """DL frameworks to build support for"""
    supported_frameworks = ["pytorch", "jax", "paddle"]
    requested: List[str] = []

    # Check environment variable
    env_value = os.getenv("NVTE_FRAMEWORK")
    if env_value:
        requested.extend(env_value.split(","))

    # Check command-line arguments, consuming any --framework=... flags
    for arg in sys.argv.copy():
        if arg.startswith("--framework="):
            requested.extend(arg.replace("--framework=", "").split(","))
            sys.argv.remove(arg)

    # Detect installed frameworks if not explicitly specified
    if not requested:
        for module_name, framework in (
            ("torch", "pytorch"),
            ("jax", "jax"),
            ("paddle", "paddle"),
        ):
            try:
                __import__(module_name)
            except ImportError:
                pass
            else:
                requested.append(framework)

    # Special framework names
    if "all" in requested:
        requested = supported_frameworks.copy()
    if "none" in requested:
        requested = []

    # Check that frameworks are valid
    requested = [framework.lower() for framework in requested]
    for framework in requested:
        if framework not in supported_frameworks:
            raise ValueError(
                f"Transformer Engine does not support framework={framework}"
            )

    return requested

# Call once in global scope since this function manipulates the
# command-line arguments. Future calls will use a cached value.
frameworks()
236
def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for build, runtime, and testing.

    """

    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0; python_version<'3.8'",
    ]
    test_reqs: List[str] = ["pytest"]

    def add_unique(reqs: List[str], vals: Union[str, List[str]]) -> None:
        """Append entries that are not already in the list"""
        entries = [vals] if isinstance(vals, str) else vals
        for entry in entries:
            if entry not in reqs:
                reqs.append(entry)

    # Requirements that may be installed outside of Python
    if not found_cmake():
        add_unique(setup_reqs, "cmake>=3.18")
    if not found_ninja():
        add_unique(setup_reqs, "ninja")

    # Framework-specific requirements
    active = frameworks()
    if "pytorch" in active:
        add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"])
        add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
    if "jax" in active:
        if not found_pybind11():
            add_unique(setup_reqs, "pybind11")
        add_unique(install_reqs, ["jax", "flax>=0.7.1"])
        add_unique(test_reqs, ["numpy", "praxis"])
    if "paddle" in active:
        add_unique(install_reqs, "paddlepaddle-gpu")
        add_unique(test_reqs, "numpy")

    return setup_reqs, install_reqs, test_reqs
280
281
class CMakeExtension(setuptools.Extension):
    """CMake extension module

    Represents a build target that is delegated to CMake instead of
    being compiled directly by setuptools.
    """

    def __init__(
        self,
        name: str,
        cmake_path: Path,
        cmake_flags: Optional[List[str]] = None,
    ) -> None:
        super().__init__(name, sources=[])  # No work for base class
        # Directory containing the CMakeLists.txt for this target
        self.cmake_path: Path = cmake_path
        # Extra flags forwarded to the CMake configure step
        self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags

    def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
        """Configure, build, and install the CMake project

        Raises RuntimeError if any CMake command fails.
        """

        # Make sure paths are str
        _cmake_bin = str(cmake_bin())
        cmake_path = str(self.cmake_path)
        build_dir = str(build_dir)
        install_dir = str(install_dir)

        # The build dir doubles as the working dir for the CMake
        # commands below, so make sure it exists.
        os.makedirs(build_dir, exist_ok=True)

        # CMake configure command
        build_type = "Debug" if with_debug_build() else "Release"
        configure_command = [
            _cmake_bin,
            "-S",
            cmake_path,
            "-B",
            build_dir,
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
            f"-DCMAKE_BUILD_TYPE={build_type}",
            f"-DCMAKE_INSTALL_PREFIX={install_dir}",
        ]
        configure_command += self.cmake_flags
        if found_ninja():
            configure_command.append("-GNinja")
        try:
            import pybind11
        except ImportError:
            pass
        else:
            # Help CMake locate pybind11 installed as a Python package
            pybind11_dir = Path(pybind11.__file__).resolve().parent
            pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
            configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

        # CMake build and install commands
        build_command = [_cmake_bin, "--build", build_dir]
        install_command = [_cmake_bin, "--install", build_dir]

        # Run CMake commands
        for command in [configure_command, build_command, install_command]:
            print(f"Running command {' '.join(command)}")
            try:
                subprocess.run(command, cwd=build_dir, check=True)
            except (CalledProcessError, OSError) as e:
                # Chain the original exception for easier debugging
                raise RuntimeError(f"Error when running CMake: {e}") from e
339
340
# PyTorch extension modules require special handling
# Pick the build_ext implementation that matches the target framework:
# PyTorch and Paddle ship their own BuildExtension classes that know
# how to compile CUDA sources; plain setuptools build_ext is the
# fallback when neither framework is being built.
if "pytorch" in frameworks():
    from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks():
    from paddle.utils.cpp_extension import BuildExtension
else:
    from setuptools.command.build_ext import build_ext as BuildExtension
348
349
class CMakeBuildExtension(BuildExtension):
    """Setuptools command with support for CMake extension modules"""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def run(self) -> None:
        """Build all extensions, delegating CMakeExtensions to CMake"""

        # Build CMake extensions
        for ext in self.extensions:
            if isinstance(ext, CMakeExtension):
                print(f"Building CMake extension {ext.name}")
                # Set up incremental builds for CMake extensions:
                # reuse a fixed build/cmake directory across runs.
                setup_dir = Path(__file__).resolve().parent
                build_dir = setup_dir / "build" / "cmake"
                build_dir.mkdir(parents=True, exist_ok=True)  # Ensure the directory exists
                # Install artifacts next to where setuptools expects
                # the extension module to appear.
                package_path = Path(self.get_ext_fullpath(ext.name))
                install_dir = package_path.resolve().parent
                ext._build_cmake(
                    build_dir=build_dir,
                    install_dir=install_dir,
                )

        # Paddle requires linker search path for libtransformer_engine.so
        paddle_ext = None
        if "paddle" in frameworks():
            for ext in self.extensions:
                if "paddle" in ext.name:
                    ext.library_dirs.append(self.build_lib)
                    paddle_ext = ext
                    break

        # Build non-CMake extensions as usual: temporarily hide CMake
        # extensions from the base class, then restore the full list.
        all_extensions = self.extensions
        self.extensions = [
            ext for ext in self.extensions
            if not isinstance(ext, CMakeExtension)
        ]
        super().run()
        self.extensions = all_extensions

        # Manually write stub file for Paddle extension
        if paddle_ext is not None:

            # Load libtransformer_engine.so to avoid linker errors
            for path in Path(self.build_lib).iterdir():
                if path.name.startswith("libtransformer_engine."):
                    ctypes.CDLL(str(path), mode=ctypes.RTLD_GLOBAL)

            # Figure out stub file path
            module_name = paddle_ext.name
            assert module_name.endswith("_pd_"), \
                "Expected Paddle extension module to end with '_pd_'"
            stub_name = module_name[:-4]  # remove '_pd_'
            stub_path = os.path.join(self.build_lib, stub_name + ".py")

            # Figure out library name
            # Note: This library doesn't actually exist. Paddle
            # internally reinserts the '_pd_' suffix.
            so_path = self.get_ext_fullpath(module_name)
            _, so_ext = os.path.splitext(so_path)
            lib_name = stub_name + so_ext

            # Write stub file
            print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
            from paddle.utils.cpp_extension.extension_utils import custom_write_stub
            custom_write_stub(lib_name, stub_path)
417
418
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Also builds JAX or userbuffers support if needed.

    """
    # Both conditions are always evaluated, matching the original
    # behavior (with_userbuffers() may assert on MPI_HOME).
    optional_flags = [
        ("jax" in frameworks(), "-DENABLE_JAX=ON"),
        (with_userbuffers(), "-DNVTE_WITH_USERBUFFERS=ON"),
    ]
    flags = [flag for enabled, flag in optional_flags if enabled]
    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / "transformer_engine",
        cmake_flags=flags,
    )
435
436def _all_files_in_dir(path):
437return list(path.iterdir())
438
def setup_pytorch_extension() -> setuptools.Extension:
    """Setup CUDA extension for PyTorch support"""

    # Source files
    src_dir = root_path / "transformer_engine" / "pytorch" / "csrc"
    extensions_dir = src_dir / "extensions"
    sources = [
        src_dir / "common.cu",
        src_dir / "ts_fp8_op.cpp",
        # We need to compile system.cpp because the pytorch extension uses
        # transformer_engine::getenv. This is a workaround to avoid direct
        # linking with libtransformer_engine.so, as the pre-built PyTorch
        # wheel from conda or PyPI was not built with CXX11_ABI, and will
        # cause undefined symbol issues.
        root_path / "transformer_engine" / "common" / "util" / "system.cpp",
    ]
    sources.extend(_all_files_in_dir(extensions_dir))

    # Header files
    include_dirs = [
        root_path / "transformer_engine" / "common" / "include",
        root_path / "transformer_engine" / "pytorch" / "csrc",
        root_path / "transformer_engine",
        root_path / "3rdparty" / "cudnn-frontend" / "include",
    ]

    # Compiler flags
    cxx_flags = ["-O3"]
    nvcc_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        version_gated_flags = (
            ((11, 2), ["--threads", "4"]),
            ((11, 0), ["-gencode", "arch=compute_80,code=sm_80"]),
            ((11, 8), ["-gencode", "arch=compute_90,code=sm_90"]),
        )
        for min_version, extra_flags in version_gated_flags:
            if version >= min_version:
                nvcc_flags.extend(extra_flags)

    # userbuffers support
    if with_userbuffers():
        mpi_home = os.getenv("MPI_HOME")
        if mpi_home:
            include_dirs.append(Path(mpi_home) / "include")
        cxx_flags.append("-DNVTE_WITH_USERBUFFERS")
        nvcc_flags.append("-DNVTE_WITH_USERBUFFERS")

    # Construct PyTorch CUDA extension
    from torch.utils.cpp_extension import CUDAExtension
    return CUDAExtension(
        name="transformer_engine_extensions",
        sources=[str(path) for path in sources],
        include_dirs=[str(path) for path in include_dirs],
        # libraries=["transformer_engine"], ### TODO (tmoon) Debug linker errors
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
517
518
def setup_paddle_extension() -> setuptools.Extension:
    """Setup CUDA extension for Paddle support"""

    # Source files
    src_dir = root_path / "transformer_engine" / "paddle" / "csrc"
    sources = [
        src_dir / "extensions.cu",
        src_dir / "common.cpp",
        src_dir / "custom_ops.cu",
    ]

    # Header files
    include_dirs = [
        root_path / "transformer_engine" / "common" / "include",
        root_path / "transformer_engine" / "paddle" / "csrc",
        root_path / "transformer_engine",
    ]

    # Compiler flags
    cxx_flags = ["-O3"]
    nvcc_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        version_gated_flags = (
            ((11, 2), ["--threads", "4"]),
            ((11, 0), ["-gencode", "arch=compute_80,code=sm_80"]),
            ((11, 8), ["-gencode", "arch=compute_90,code=sm_90"]),
        )
        for min_version, extra_flags in version_gated_flags:
            if version >= min_version:
                nvcc_flags.extend(extra_flags)

    # Construct Paddle CUDA extension
    from paddle.utils.cpp_extension import CUDAExtension
    ext = CUDAExtension(
        sources=[str(path) for path in sources],
        include_dirs=[str(path) for path in include_dirs],
        libraries=["transformer_engine"],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
    # Paddle's CUDAExtension does not take a name argument;
    # set it afterwards (the '_pd_' suffix is required by Paddle).
    ext.name = "transformer_engine_paddle_pd_"
    return ext
582
def main():
    """Configure and run the setuptools build"""

    # Submodules to install
    packages = setuptools.find_packages(
        include=["transformer_engine", "transformer_engine.*"],
    )

    # Dependencies
    setup_requires, install_requires, test_requires = setup_requirements()

    # Extension modules: framework-agnostic core first, then
    # framework-specific extensions.
    ext_modules = [setup_common_extension()]
    active_frameworks = frameworks()
    if "pytorch" in active_frameworks:
        ext_modules.append(setup_pytorch_extension())
    if "paddle" in active_frameworks:
        ext_modules.append(setup_paddle_extension())

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=te_version(),
        packages=packages,
        description="Transformer acceleration library",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        setup_requires=setup_requires,
        install_requires=install_requires,
        extras_require={"test": test_requires},
        license_files=("LICENSE",),
    )


if __name__ == "__main__":
    main()
618