TransformerEngine

Форк
0
617 строк · 19.1 Кб
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
#
3
# See LICENSE for license information.
4

5
"""Installation script."""
6

7
import ctypes
8
from functools import lru_cache
9
import os
10
from pathlib import Path
11
import re
12
import shutil
13
import subprocess
14
from subprocess import CalledProcessError
15
import sys
16
import sysconfig
17
from typing import List, Optional, Tuple, Union
18

19
import setuptools
20
from setuptools.command.build_ext import build_ext
21

22
from te_version import te_version
23

24
# Project directory root
25
root_path: Path = Path(__file__).resolve().parent
26

27
@lru_cache(maxsize=1)
def with_debug_build() -> bool:
    """Whether to build with a debug configuration

    Enabled by a ``--debug`` command-line flag (which is consumed, i.e.
    removed from ``sys.argv``) or by setting ``NVTE_BUILD_DEBUG=1``.
    """
    if "--debug" in sys.argv:
        # Consume the flag so setuptools does not see it
        sys.argv.remove("--debug")
        return True
    return int(os.getenv("NVTE_BUILD_DEBUG", "0")) != 0
37

38
# Call once in global scope since this function manipulates the
39
# command-line arguments. Future calls will use a cached value.
40
with_debug_build()
41

42
def found_cmake() -> bool:
    """Check if valid CMake is available

    CMake 3.18 or newer is required.

    """

    # Check if CMake is available
    try:
        _cmake_bin = cmake_bin()
    except FileNotFoundError:
        return False

    # Query CMake for version info. A probe function should not crash the
    # build, so a CMake that exists but cannot be run counts as "not found".
    try:
        output = subprocess.run(
            [_cmake_bin, "--version"],
            capture_output=True,
            check=True,
            universal_newlines=True,
        )
    except (CalledProcessError, OSError):
        return False
    match = re.search(r"version\s*([\d.]+)", output.stdout)
    if match is None:
        # Unrecognized version output
        return False
    version = tuple(int(v) for v in match.group(1).split("."))
    return version >= (3, 18)
66

67
def cmake_bin() -> Path:
    """Get CMake executable

    Throws FileNotFoundError if not found.

    """
    candidate: Optional[Path] = None

    # Prefer the executable bundled with the CMake Python package
    try:
        import cmake
    except ImportError:
        pass
    else:
        bundled = Path(cmake.__file__).resolve().parent / "data" / "bin" / "cmake"
        if bundled.is_file():
            candidate = bundled

    # Fall back to searching the system path
    if candidate is None:
        which_result = shutil.which("cmake")
        if which_result is not None:
            candidate = Path(which_result).resolve()

    # Return executable if found
    if candidate is None:
        raise FileNotFoundError("Could not find CMake executable")
    return candidate
96

97
def found_ninja() -> bool:
    """Check if Ninja is available"""
    return shutil.which("ninja") is not None
100

101
def found_pybind11() -> bool:
    """Check if pybind11 is available

    Accepts either the pybind11 Python package or a CMake-discoverable
    pybind11 installation.
    """

    # Check if Python package is installed
    try:
        import pybind11
    except ImportError:
        pass
    else:
        return True

    # Check if CMake can find pybind11
    if not found_cmake():
        return False
    try:
        subprocess.run(
            [
                # Use the validated CMake executable: found_cmake() may have
                # located CMake via the `cmake` Python package, which is not
                # necessarily on PATH, so a bare "cmake" could fail here.
                str(cmake_bin()),
                "--find-package",
                "-DMODE=EXIST",
                "-DNAME=pybind11",
                "-DCOMPILER_ID=CXX",
                "-DLANGUAGE=CXX",
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except (CalledProcessError, OSError):
        pass
    else:
        return True
    return False
134

135
def cuda_version() -> Tuple[int, ...]:
    """CUDA Toolkit version as a (major, minor) tuple

    Throws FileNotFoundError if NVCC is not found.
    Throws RuntimeError if the NVCC version output cannot be parsed.

    """

    # Try finding NVCC: CUDA_HOME first, then PATH, then /usr/local/cuda
    nvcc_bin: Optional[Path] = None
    if os.getenv("CUDA_HOME"):
        # Check in CUDA_HOME
        cuda_home = Path(os.getenv("CUDA_HOME"))
        nvcc_bin = cuda_home / "bin" / "nvcc"
    if nvcc_bin is None:
        # Check if nvcc is in path
        which_nvcc = shutil.which("nvcc")
        if which_nvcc is not None:
            nvcc_bin = Path(which_nvcc)
    if nvcc_bin is None:
        # Last-ditch guess in /usr/local/cuda
        cuda_home = Path("/usr/local/cuda")
        nvcc_bin = cuda_home / "bin" / "nvcc"
    if not nvcc_bin.is_file():
        raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")

    # Query NVCC for version info
    output = subprocess.run(
        [nvcc_bin, "-V"],
        capture_output=True,
        check=True,
        universal_newlines=True,
    )
    match = re.search(r"release\s*([\d.]+)", output.stdout)
    if match is None:
        # Guard: an unparsable banner previously crashed with AttributeError
        raise RuntimeError(f"Could not parse NVCC version from output: {output.stdout}")
    return tuple(int(v) for v in match.group(1).split("."))
170

171
@lru_cache(maxsize=1)
def with_userbuffers() -> bool:
    """Check if userbuffers support is enabled

    Enabled by setting ``NVTE_WITH_USERBUFFERS=1``, which additionally
    requires ``MPI_HOME`` to be set.
    """
    if int(os.getenv("NVTE_WITH_USERBUFFERS", "0")):
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, silently skipping this validation. AssertionError is
        # kept for backward compatibility with any existing callers.
        if not os.getenv("MPI_HOME"):
            raise AssertionError("MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1")
        return True
    return False
179

180
@lru_cache(maxsize=1)
def frameworks() -> List[str]:
    """DL frameworks to build support for

    Sources, in order: the ``NVTE_FRAMEWORK`` environment variable
    (comma-separated), ``--framework=...`` command-line arguments (which
    are consumed from ``sys.argv``), and finally auto-detection of
    installed packages if nothing was specified explicitly. The special
    names "all" and "none" expand to all supported frameworks and to an
    empty list, respectively.
    """
    _frameworks: List[str] = []
    supported_frameworks = ["pytorch", "jax", "paddle"]

    # Check environment variable
    if os.getenv("NVTE_FRAMEWORK"):
        _frameworks.extend(os.getenv("NVTE_FRAMEWORK").split(","))

    # Check command-line arguments
    # Iterate over a copy since matched arguments are removed from sys.argv
    for arg in sys.argv.copy():
        if arg.startswith("--framework="):
            _frameworks.extend(arg.replace("--framework=", "").split(","))
            sys.argv.remove(arg)

    # Detect installed frameworks if not explicitly specified
    if not _frameworks:
        try:
            import torch
        except ImportError:
            pass
        else:
            _frameworks.append("pytorch")
        try:
            import jax
        except ImportError:
            pass
        else:
            _frameworks.append("jax")
        try:
            import paddle
        except ImportError:
            pass
        else:
            _frameworks.append("paddle")

    # Special framework names
    # NOTE: "none" wins over "all" if both are given, since it is applied last
    if "all" in _frameworks:
        _frameworks = supported_frameworks.copy()
    if "none" in _frameworks:
        _frameworks = []

    # Check that frameworks are valid
    _frameworks = [framework.lower() for framework in _frameworks]
    for framework in _frameworks:
        if framework not in supported_frameworks:
            raise ValueError(
                f"Transformer Engine does not support framework={framework}"
            )

    return _frameworks
232

233
# Call once in global scope since this function manipulates the
234
# command-line arguments. Future calls will use a cached value.
235
frameworks()
236

237
def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for build, runtime, and testing.

    """

    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0; python_version<'3.8'",
    ]
    test_reqs: List[str] = ["pytest"]

    def add_unique(reqs: List[str], vals: Union[str, List[str]]) -> None:
        """Append entries to *reqs*, skipping ones already present"""
        entries = [vals] if isinstance(vals, str) else vals
        for entry in entries:
            if entry not in reqs:
                reqs.append(entry)

    # Requirements that may be installed outside of Python
    if not found_cmake():
        add_unique(setup_reqs, "cmake>=3.18")
    if not found_ninja():
        add_unique(setup_reqs, "ninja")

    # Framework-specific requirements, keyed by framework name
    framework_install_reqs = {
        "pytorch": ["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"],
        "jax": ["jax", "flax>=0.7.1"],
        "paddle": ["paddlepaddle-gpu"],
    }
    framework_test_reqs = {
        "pytorch": ["numpy", "onnxruntime", "torchvision"],
        "jax": ["numpy", "praxis"],
        "paddle": ["numpy"],
    }
    active = frameworks()
    if "jax" in active and not found_pybind11():
        add_unique(setup_reqs, "pybind11")
    for framework in ("pytorch", "jax", "paddle"):
        if framework in active:
            add_unique(install_reqs, framework_install_reqs[framework])
            add_unique(test_reqs, framework_test_reqs[framework])

    return setup_reqs, install_reqs, test_reqs
280

281

282
class CMakeExtension(setuptools.Extension):
    """CMake extension module

    A source-less setuptools ``Extension`` recording the location of a
    CMake project; the actual configure/build/install work is done in
    ``_build_cmake``, driven by ``CMakeBuildExtension``.
    """

    def __init__(
            self,
            name: str,
            cmake_path: Path,
            cmake_flags: Optional[List[str]] = None,
    ) -> None:
        """Initialize extension.

        Args:
            name: Extension module name.
            cmake_path: Directory containing the CMake project.
            cmake_flags: Extra arguments for the CMake configure step.
        """
        super().__init__(name, sources=[])  # No work for base class
        self.cmake_path: Path = cmake_path
        self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags

    def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
        """Configure, build, and install the CMake project.

        Raises:
            RuntimeError: If any CMake command fails.
        """

        # Make sure paths are str
        _cmake_bin = str(cmake_bin())
        cmake_path = str(self.cmake_path)
        build_dir = str(build_dir)
        install_dir = str(install_dir)

        # CMake configure command
        build_type = "Debug" if with_debug_build() else "Release"
        configure_command = [
            _cmake_bin,
            "-S",
            cmake_path,
            "-B",
            build_dir,
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
            f"-DCMAKE_BUILD_TYPE={build_type}",
            f"-DCMAKE_INSTALL_PREFIX={install_dir}",
        ]
        configure_command += self.cmake_flags
        if found_ninja():
            configure_command.append("-GNinja")
        # Point CMake at the config files bundled with the pybind11 Python
        # package, if it is installed
        try:
            import pybind11
        except ImportError:
            pass
        else:
            pybind11_dir = Path(pybind11.__file__).resolve().parent
            pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
            configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

        # CMake build and install commands
        build_command = [_cmake_bin, "--build", build_dir]
        install_command = [_cmake_bin, "--install", build_dir]

        # Run CMake commands
        for command in [configure_command, build_command, install_command]:
            print(f"Running command {' '.join(command)}")
            try:
                subprocess.run(command, cwd=build_dir, check=True)
            except (CalledProcessError, OSError) as e:
                # Chain the original exception so the root cause stays
                # visible in the traceback (previously dropped)
                raise RuntimeError(f"Error when running CMake: {e}") from e
339

340

341
# PyTorch extension modules require special handling: pick the build_ext
# base class from the framework being built for, falling back to plain
# setuptools when neither PyTorch nor Paddle is targeted.
if "pytorch" in frameworks():
    from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks():
    from paddle.utils.cpp_extension import BuildExtension
else:
    from setuptools.command.build_ext import build_ext as BuildExtension
348

349

350
class CMakeBuildExtension(BuildExtension):
    """Setuptools command with support for CMake extension modules

    Builds ``CMakeExtension`` instances itself and delegates all other
    extensions to the framework-specific ``BuildExtension`` base class.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def run(self) -> None:
        """Build all extensions, CMake-based ones first."""

        # Build CMake extensions
        for ext in self.extensions:
            if isinstance(ext, CMakeExtension):
                print(f"Building CMake extension {ext.name}")
                # Set up incremental builds for CMake extensions: reuse a
                # fixed build directory under the source tree
                setup_dir = Path(__file__).resolve().parent
                build_dir = setup_dir / "build" / "cmake"
                build_dir.mkdir(parents=True, exist_ok=True)  # Ensure the directory exists
                package_path = Path(self.get_ext_fullpath(ext.name))
                install_dir = package_path.resolve().parent
                ext._build_cmake(
                    build_dir=build_dir,
                    install_dir=install_dir,
                )

        # Paddle requires linker search path for libtransformer_engine.so
        paddle_ext = None
        if "paddle" in frameworks():
            for ext in self.extensions:
                if "paddle" in ext.name:
                    ext.library_dirs.append(self.build_lib)
                    paddle_ext = ext
                    break

        # Build non-CMake extensions as usual: temporarily hide the CMake
        # extensions from the base class, then restore the full list
        all_extensions = self.extensions
        self.extensions = [
            ext for ext in self.extensions
            if not isinstance(ext, CMakeExtension)
        ]
        super().run()
        self.extensions = all_extensions

        # Manually write stub file for Paddle extension
        if paddle_ext is not None:

            # Load libtransformer_engine.so to avoid linker errors
            for path in Path(self.build_lib).iterdir():
                if path.name.startswith("libtransformer_engine."):
                    ctypes.CDLL(str(path), mode=ctypes.RTLD_GLOBAL)

            # Figure out stub file path
            module_name = paddle_ext.name
            assert module_name.endswith("_pd_"), \
                "Expected Paddle extension module to end with '_pd_'"
            stub_name = module_name[:-4]  # remove '_pd_'
            stub_path = os.path.join(self.build_lib, stub_name + ".py")

            # Figure out library name
            # Note: This library doesn't actually exist. Paddle
            # internally reinserts the '_pd_' suffix.
            so_path = self.get_ext_fullpath(module_name)
            _, so_ext = os.path.splitext(so_path)
            lib_name = stub_name + so_ext

            # Write stub file
            print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
            from paddle.utils.cpp_extension.extension_utils import custom_write_stub
            custom_write_stub(lib_name, stub_path)
417

418

419
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Also builds JAX or userbuffers support if needed.

    """
    flag_conditions = [
        ("-DENABLE_JAX=ON", "jax" in frameworks()),
        ("-DNVTE_WITH_USERBUFFERS=ON", with_userbuffers()),
    ]
    cmake_flags = [flag for flag, enabled in flag_conditions if enabled]
    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / "transformer_engine",
        cmake_flags=cmake_flags,
    )
435

436
def _all_files_in_dir(path: Path) -> List[Path]:
    """List all entries (files and subdirectories) directly inside *path*."""
    return list(path.iterdir())
438

439
def setup_pytorch_extension() -> setuptools.Extension:
    """Setup CUDA extension for PyTorch support"""

    # Source files
    csrc_dir = root_path / "transformer_engine" / "pytorch" / "csrc"
    sources = [
        csrc_dir / "common.cu",
        csrc_dir / "ts_fp8_op.cpp",
        # We need to compile system.cpp because the pytorch extension uses
        # transformer_engine::getenv. This is a workaround to avoid direct
        # linking with libtransformer_engine.so, as the pre-built PyTorch
        # wheel from conda or PyPI was not built with CXX11_ABI, and will
        # cause undefined symbol issues.
        root_path / "transformer_engine" / "common" / "util" / "system.cpp",
    ]
    sources.extend(_all_files_in_dir(csrc_dir / "extensions"))

    # Header files
    include_dirs = [
        root_path / "transformer_engine" / "common" / "include",
        root_path / "transformer_engine" / "pytorch" / "csrc",
        root_path / "transformer_engine",
        root_path / "3rdparty" / "cudnn-frontend" / "include",
    ]

    # Compiler flags
    cxx_flags = ["-O3"]
    nvcc_flags = ["-O3", "-gencode", "arch=compute_70,code=sm_70"]
    # Expose half/bfloat16 operators and conversions in device code
    nvcc_flags += [
        f"-U__CUDA_NO_{dtype}_{category}__"
        for dtype in ("HALF", "BFLOAT16", "BFLOAT162")
        for category in ("OPERATORS", "CONVERSIONS")
    ]
    nvcc_flags += [
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        if version >= (11, 2):
            nvcc_flags += ["--threads", "4"]
        if version >= (11, 0):
            nvcc_flags += ["-gencode", "arch=compute_80,code=sm_80"]
        if version >= (11, 8):
            nvcc_flags += ["-gencode", "arch=compute_90,code=sm_90"]

    # userbuffers support
    if with_userbuffers():
        mpi_home = os.getenv("MPI_HOME")
        if mpi_home:
            include_dirs.append(Path(mpi_home) / "include")
        cxx_flags.append("-DNVTE_WITH_USERBUFFERS")
        nvcc_flags.append("-DNVTE_WITH_USERBUFFERS")

    # Construct PyTorch CUDA extension
    from torch.utils.cpp_extension import CUDAExtension
    return CUDAExtension(
        name="transformer_engine_extensions",
        sources=[str(src) for src in sources],
        include_dirs=[str(inc) for inc in include_dirs],
        # libraries=["transformer_engine"], ### TODO (tmoon) Debug linker errors
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
517

518

519
def setup_paddle_extension() -> setuptools.Extension:
    """Setup CUDA extension for Paddle support"""

    # Source files
    csrc_dir = root_path / "transformer_engine" / "paddle" / "csrc"
    sources = [
        csrc_dir / "extensions.cu",
        csrc_dir / "common.cpp",
        csrc_dir / "custom_ops.cu",
    ]

    # Header files
    include_dirs = [
        root_path / "transformer_engine" / "common" / "include",
        root_path / "transformer_engine" / "paddle" / "csrc",
        root_path / "transformer_engine",
    ]

    # Compiler flags
    cxx_flags = ["-O3"]
    nvcc_flags = ["-O3", "-gencode", "arch=compute_70,code=sm_70"]
    # Expose half/bfloat16 operators and conversions in device code
    nvcc_flags += [
        f"-U__CUDA_NO_{dtype}_{category}__"
        for dtype in ("HALF", "BFLOAT16", "BFLOAT162")
        for category in ("OPERATORS", "CONVERSIONS")
    ]
    nvcc_flags += [
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        if version >= (11, 2):
            nvcc_flags += ["--threads", "4"]
        if version >= (11, 0):
            nvcc_flags += ["-gencode", "arch=compute_80,code=sm_80"]
        if version >= (11, 8):
            nvcc_flags += ["-gencode", "arch=compute_90,code=sm_90"]

    # Construct Paddle CUDA extension
    from paddle.utils.cpp_extension import CUDAExtension
    ext = CUDAExtension(
        sources=[str(src) for src in sources],
        include_dirs=[str(inc) for inc in include_dirs],
        libraries=["transformer_engine"],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
    # Paddle strips and re-adds the '_pd_' suffix internally
    ext.name = "transformer_engine_paddle_pd_"
    return ext
582

583
def main():
    """Configure and run the setuptools installation."""

    # Submodules to install
    packages = setuptools.find_packages(
        include=["transformer_engine", "transformer_engine.*"],
    )

    # Dependencies
    setup_requires, install_requires, test_requires = setup_requirements()

    # Extension modules: common core first, then framework bindings
    ext_modules = [setup_common_extension()]
    for framework, make_extension in (
        ("pytorch", setup_pytorch_extension),
        ("paddle", setup_paddle_extension),
    ):
        if framework in frameworks():
            ext_modules.append(make_extension())

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=te_version(),
        packages=packages,
        description="Transformer acceleration library",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        setup_requires=setup_requires,
        install_requires=install_requires,
        extras_require={"test": test_requires},
        license_files=("LICENSE",),
    )
614

615

616
# Script entry point
if __name__ == "__main__":
    main()
618

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.