# pytorch / cpp_extension.py
# (source-viewer header: Fork 0 · 2425 lines · 102.6 KB)
# mypy: allow-untyped-defs
2
import copy
3
import glob
4
import importlib
5
import importlib.abc
6
import os
7
import re
8
import shlex
9
import shutil
10
import setuptools
11
import subprocess
12
import sys
13
import sysconfig
14
import warnings
15
import collections
16
from pathlib import Path
17
import errno
18

19
import torch
20
import torch._appdirs
21
from .file_baton import FileBaton
22
from ._cpp_extension_versioner import ExtensionVersioner
23
from .hipify import hipify_python
24
from .hipify.hipify_python import GeneratedFileCleaner
25
from typing import Dict, List, Optional, Union, Tuple
26
from torch.torch_version import TorchVersion, Version
27

28
from setuptools.command.build_ext import build_ext
29

30
# Host platform flags.
IS_WINDOWS = sys.platform == 'win32'
IS_MACOS = sys.platform.startswith('darwin')
IS_LINUX = sys.platform.startswith('linux')

# File-name and linker conventions that follow from the platform:
# extension-module suffix, executable suffix, shared-library prefix/suffix,
# and the flag that asks the linker for a shared library.
if IS_WINDOWS:
    LIB_EXT, EXEC_EXT, CLIB_PREFIX, CLIB_EXT, SHARED_FLAG = '.pyd', '.exe', '', '.dll', '/DLL'
else:
    LIB_EXT, EXEC_EXT, CLIB_PREFIX, CLIB_EXT, SHARED_FLAG = '.so', '', 'lib', '.so', '-shared'

# Absolute path of this file; the torch package root sits two directory
# levels above it, and TORCH_LIB_PATH is that root's ``lib`` subdirectory.
_HERE = os.path.abspath(__file__)
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib')


# Extra positional arguments for ``bytes.decode`` when reading subprocess
# output: the 'oem' codec on Windows, the default codec elsewhere.
SUBPROCESS_DECODE_ARGS = ('oem',) if IS_WINDOWS else ()

# Minimum compiler versions, as (major, minor, patch) tuples.
MINIMUM_GCC_VERSION = (5, 0, 0)
MINIMUM_MSVC_VERSION = (19, 0, 24215)
MINIMUM_CLANG_VERSION = (3, 3, 0)

VersionRange = Tuple[Tuple[int, ...], Tuple[int, ...]]
VersionMap = Dict[str, VersionRange]

# Supported host-compiler ranges per CUDA release, keyed by "major.minor".
# The values were taken from the GitHub gist that summarizes the minimum valid
# major versions of g++/clang++ for each supported CUDA version
# (https://gist.github.com/ax3l/9489132), or from include/crt/host_config.h in
# the CUDA SDK. Each value is (min, max) where max is the EXCLUSIVE upper
# bound, i.e. min <= version < max.
CUDA_GCC_VERSIONS: VersionMap = {
    '11.0': (MINIMUM_GCC_VERSION, (10, 0)),
    **dict.fromkeys(('11.1', '11.2', '11.3'), (MINIMUM_GCC_VERSION, (11, 0))),
    **dict.fromkeys(('11.4', '11.5', '11.6', '11.7'), ((6, 0, 0), (12, 0))),
}

# Every clang range starts at MINIMUM_CLANG_VERSION; only the exclusive upper
# bound changes between CUDA releases.
CUDA_CLANG_VERSIONS: VersionMap = {
    cuda_release: (MINIMUM_CLANG_VERSION, clang_upper_bound)
    for cuda_release, clang_upper_bound in (
        ('11.1', (11, 0)),
        ('11.2', (12, 0)),
        ('11.3', (12, 0)),
        ('11.4', (13, 0)),
        ('11.5', (13, 0)),
        ('11.6', (14, 0)),
        ('11.7', (14, 0)),
    )
}
76

77
# Public API of this module, in declaration order.
__all__ = [
    "get_default_build_root",
    "check_compiler_ok_for_platform",
    "get_compiler_abi_compatibility_and_version",
    "BuildExtension",
    "CppExtension",
    "CUDAExtension",
    "include_paths",
    "library_paths",
    "load",
    "load_inline",
    "is_ninja_available",
    "verify_ninja_availability",
    "remove_extension_h_precompiler_headers",
    "get_cxx_compiler",
    "check_compiler_is_gcc",
]
80
# Taken directly from python stdlib < 3.9
81
# See https://github.com/pytorch/pytorch/issues/48617
82
def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
83
    """Quote command-line arguments for DOS/Windows conventions.
84

85
    Just wraps every argument which contains blanks in double quotes, and
86
    returns a new argument list.
87
    """
88
    # Cover None-type
89
    if not args:
90
        return []
91
    return [f'"{arg}"' if ' ' in arg else arg for arg in args]
92

93
def _find_cuda_home() -> Optional[str]:
94
    """Find the CUDA install path."""
95
    # Guess #1
96
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
97
    if cuda_home is None:
98
        # Guess #2
99
        nvcc_path = shutil.which("nvcc")
100
        if nvcc_path is not None:
101
            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
102
        else:
103
            # Guess #3
104
            if IS_WINDOWS:
105
                cuda_homes = glob.glob(
106
                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
107
                if len(cuda_homes) == 0:
108
                    cuda_home = ''
109
                else:
110
                    cuda_home = cuda_homes[0]
111
            else:
112
                cuda_home = '/usr/local/cuda'
113
            if not os.path.exists(cuda_home):
114
                cuda_home = None
115
    if cuda_home and not torch.cuda.is_available():
116
        print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'",
117
              file=sys.stderr)
118
    return cuda_home
119

120
def _find_rocm_home() -> Optional[str]:
121
    """Find the ROCm install path."""
122
    # Guess #1
123
    rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
124
    if rocm_home is None:
125
        # Guess #2
126
        hipcc_path = shutil.which('hipcc')
127
        if hipcc_path is not None:
128
            rocm_home = os.path.dirname(os.path.dirname(
129
                os.path.realpath(hipcc_path)))
130
            # can be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
131
            if os.path.basename(rocm_home) == 'hip':
132
                rocm_home = os.path.dirname(rocm_home)
133
        else:
134
            # Guess #3
135
            fallback_path = '/opt/rocm'
136
            if os.path.exists(fallback_path):
137
                rocm_home = fallback_path
138
    if rocm_home and torch.version.hip is None:
139
        print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'",
140
              file=sys.stderr)
141
    return rocm_home
142

143

144
def _join_rocm_home(*paths) -> str:
    """
    Return ``os.path.join(ROCM_HOME, *paths)``.

    Raises OSError if ROCM_HOME is not set, or on Windows, where building
    PyTorch extensions with ROCm is not supported. Because the check happens
    here rather than at import time, a missing $ROCM_HOME is only reported
    once some ROCm-specific path is actually requested.
    """
    if ROCM_HOME is None:
        raise OSError('ROCM_HOME environment variable is not set. Please set it to your ROCm install root.')
    if IS_WINDOWS:
        raise OSError('Building PyTorch extensions using ROCm and Windows is not supported.')
    return os.path.join(ROCM_HOME, *paths)
158

159

160
# Emitted when the detected compiler version is below the required minimum.
# Formatted positionally with "<compiler> <version>".
ABI_INCOMPATIBILITY_WARNING = '''

                               !! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({}) may be ABI-incompatible with PyTorch!
Please use a compiler that is ABI-compatible with GCC 5.0 and above.
See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html.

See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6
for instructions on how to install GCC 5 or higher.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                              !! WARNING !!
'''
# Emitted when the compiler is not one of the accepted ones for this platform.
# Formatted with the keywords user_compiler, pytorch_compiler and platform.
WRONG_COMPILER_WARNING = '''

                               !! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({user_compiler}) is not compatible with the compiler PyTorch was
built with for this platform, which is {pytorch_compiler} on {platform}. Please
use {pytorch_compiler} to compile your extension. Alternatively, you may
compile PyTorch from source using {user_compiler}, and then you can also use
{user_compiler} to compile your extension.

See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                              !! WARNING !!
'''
# Raised when the local CUDA major version differs from PyTorch's.
# Positional fields: {0} = detected CUDA version, {1} = PyTorch's CUDA version.
CUDA_MISMATCH_MESSAGE = '''
The detected CUDA version ({0}) mismatches the version that was used to compile
PyTorch ({1}). Please make sure to use the same CUDA versions.
'''
# Warned (not raised) when only the minor CUDA version differs; same fields.
CUDA_MISMATCH_WARN = "The detected CUDA version ({0}) has a minor version mismatch with the version that was used to compile PyTorch ({1}). Most likely this shouldn't be a problem."
# Raised when CUDA_HOME could not be determined.
CUDA_NOT_FOUND_MESSAGE = '''
CUDA was not found on the system, please set the CUDA_HOME or the CUDA_PATH
environment variable or add NVCC to your system PATH. The extension compilation will fail.
'''
201
# ROCm install root (or None); HIP_HOME is only derived when a root was found.
ROCM_HOME = _find_rocm_home()
HIP_HOME = _join_rocm_home('hip') if ROCM_HOME else None
# True only when a ROCm install was found AND this PyTorch is a HIP build.
IS_HIP_EXTENSION = ROCM_HOME is not None and torch.version.hip is not None
# (major, minor) parsed from torch.version.hip; None for non-HIP builds.
ROCM_VERSION: Optional[Tuple[int, ...]] = None
if torch.version.hip is not None:
    ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
207

208
# The CUDA toolkit is only searched for when this PyTorch was compiled with CUDA.
if torch.cuda._is_compiled():
    CUDA_HOME = _find_cuda_home()
else:
    CUDA_HOME = None
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
# PyTorch releases have the version pattern major.minor.patch, whereas a
# from-source build appends extra characters and a '+<hash>' suffix; the
# pattern below therefore matches e.g. '2.1.0a0+git1234567' but not '2.1.0'.
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')
214

215
# Common MSVC flags: /MD and /EHsc plus a set of warnings disabled via /wd.
COMMON_MSVC_FLAGS = [
    '/MD',
    '/wd4819', '/wd4251', '/wd4244', '/wd4267', '/wd4275',
    '/wd4018', '/wd4190', '/wd4624', '/wd4067', '/wd4068',
    '/EHsc',
]

# cudafe diagnostic names; presumably suppressed when compiling CUDA with
# MSVC (the consuming code is not visible here).
MSVC_IGNORE_CUDAFE_WARNINGS = [
    'base_class_has_different_dll_interface',
    'field_without_dll_interface',
    'dll_interface_conflict_none_assumed',
    'dll_interface_conflict_dllexport_assumed',
]

# Flags placed in front of every nvcc command line on Unix.
COMMON_NVCC_FLAGS = [
    '-D__CUDA_NO_HALF_OPERATORS__',
    '-D__CUDA_NO_HALF_CONVERSIONS__',
    '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
    '-D__CUDA_NO_HALF2_OPERATORS__',
    '--expt-relaxed-constexpr',
]

# Flags prepended to both C++ and hipcc compilations of HIP extensions.
COMMON_HIP_FLAGS = [
    '-fPIC',
    '-D__HIP_PLATFORM_AMD__=1',
    '-DUSE_ROCM=1',
    '-DHIPBLAS_V2',
]

# Additional flags used only when hipcc compiles HIP/CUDA sources.
COMMON_HIPCC_FLAGS = [
    '-DCUDA_HAS_FP16=1',
    '-D__HIP_NO_HALF_OPERATORS__=1',
    '-D__HIP_NO_HALF_CONVERSIONS__=1',
]
244

245
# Module-wide ExtensionVersioner instance for JIT-built extensions.
JIT_EXTENSION_VERSIONER = ExtensionVersioner()

# Platform name -> vcvars architecture argument (presumably for MSVC's
# vcvars environment setup; the consumer is not visible here).
PLAT_TO_VCVARS = dict([
    ('win32', 'x86'),
    ('win-amd64', 'x86_amd64'),
])
251

252
def get_cxx_compiler():
    """Return the C++ compiler command.

    ``$CXX`` wins when it is set; otherwise the default is ``cl`` on Windows
    and ``c++`` on every other platform.
    """
    default_compiler = 'cl' if IS_WINDOWS else 'c++'
    return os.environ.get('CXX', default_compiler)
258

259
def _is_binary_build() -> bool:
    """Return True unless ``torch.version.__version__`` looks like a from-source build."""
    looks_source_built = BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__) is not None
    return not looks_source_built
261

262

263
def _accepted_compilers_for_platform() -> List[str]:
    """Return the compiler names accepted on this platform (a fresh list).

    macOS accepts only the clang family. Other platforms also accept the GNU
    compilers, listed first; gnu-c++ and gnu-cc are the conda gcc compilers.
    """
    clang_family = ['clang++', 'clang']
    if IS_MACOS:
        return clang_family
    return ['g++', 'gcc', 'gnu-c++', 'gnu-cc'] + clang_family
266

267
def _maybe_write(filename, new_content):
268
    r'''
269
    Equivalent to writing the content into the file but will not touch the file
270
    if it already had the right content (to avoid triggering recompile).
271
    '''
272
    if os.path.exists(filename):
273
        with open(filename) as f:
274
            content = f.read()
275

276
        if content == new_content:
277
            # The file already contains the right thing!
278
            return
279

280
    with open(filename, 'w') as source_file:
281
        source_file.write(new_content)
282

283
def get_default_build_root() -> str:
    """
    Return the path to the root folder under which extensions will be built.

    Each extension gets its own subfolder: if ``p`` is the returned path and
    ``ext`` is the extension name, the extension is built in ``p/ext``.

    The location is the per-user cache directory for the ``torch_extensions``
    application (resolved with ``os.path.realpath``), so different users on
    the same machine do not run into permission issues.
    """
    cache_dir = torch._appdirs.user_cache_dir(appname='torch_extensions')
    return os.path.realpath(cache_dir)
296

297

298
def check_compiler_ok_for_platform(compiler: str) -> bool:
    """
    Verify that the compiler is the expected one for the current platform.

    Args:
        compiler (str): The compiler executable to check.

    Returns:
        Always True on Windows. Elsewhere, False if ``compiler`` is not on
        PATH. It is True if its symlink-resolved path contains one of the
        accepted compiler names (gcc/g++/clang family on Linux, clang on
        macOS). Otherwise ``compiler -v`` is run to see through wrappers such
        as sccache:
          * on Linux, the single ``COLLECT_GCC=`` entry decides; if there is
            none, a "clang version" banner is accepted;
          * on macOS, the output must start with "Apple clang".

    Raises:
        subprocess.CalledProcessError: if ``compiler -v`` exits non-zero.
    """
    if IS_WINDOWS:
        return True

    located = shutil.which(compiler)
    if located is None:
        return False

    accepted_names = _accepted_compilers_for_platform()

    def _mentions_accepted(path: str) -> bool:
        return any(name in path for name in accepted_names)

    # realpath resolves symlinks, in particular 'c++' -> e.g. 'g++'.
    if _mentions_accepted(os.path.realpath(located)):
        return True

    # Possibly a compiler wrapper: ask the compiler itself via -v, with
    # localization disabled so the banner text is predictable.
    env = dict(os.environ, LC_ALL='C')
    version_string = subprocess.check_output(
        [compiler, '-v'], stderr=subprocess.STDOUT, env=env
    ).decode(*SUBPROCESS_DECODE_ARGS)

    if IS_LINUX:
        collect_gcc = re.findall("^COLLECT_GCC=(.*)$", version_string, re.MULTILINE)
        if len(collect_gcc) != 1:
            # Clang is also supported on Linux; Ubuntu prints
            # "Ubuntu clang version", hence the substring test.
            return 'clang version' in version_string
        wrapped = os.path.realpath(collect_gcc[0].strip())
        # On RHEL/CentOS, c++ is a gcc compiler wrapper.
        if os.path.basename(wrapped) == 'c++' and 'gcc version' in version_string:
            return True
        return _mentions_accepted(wrapped)

    if IS_MACOS:
        return version_string.startswith("Apple clang")
    return False
340

341

342
def get_compiler_abi_compatibility_and_version(compiler) -> Tuple[bool, TorchVersion]:
    """
    Determine if the given compiler is ABI-compatible with PyTorch alongside its version.

    Args:
        compiler (str): The compiler executable name to check (e.g. ``g++``).
            Must be executable in a shell process.

    Returns:
        A tuple ``(is_compatible, version)``. ``is_compatible`` is True when
        the compiler is (likely) ABI-compatible with PyTorch and False when it
        is (likely) ABI-incompatible. ``version`` is a `TorchVersion` holding
        the compiler version separated by dots. It is ``0.0.0`` when the check
        is skipped (source build, TORCH_DONT_CHECK_COMPILER_ABI, macOS), when
        the compiler is not an accepted one, or when its version could not be
        determined.
    """
    if not _is_binary_build():
        return (True, TorchVersion('0.0.0'))
    if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']:
        return (True, TorchVersion('0.0.0'))

    # First check if the compiler is one of the expected ones for the particular platform.
    if not check_compiler_ok_for_platform(compiler):
        warnings.warn(WRONG_COMPILER_WARNING.format(
            user_compiler=compiler,
            pytorch_compiler=_accepted_compilers_for_platform()[0],
            platform=sys.platform))
        return (False, TorchVersion('0.0.0'))

    if IS_MACOS:
        # There is no particular minimum version we need for clang, so we're good here.
        return (True, TorchVersion('0.0.0'))

    try:
        if IS_LINUX:
            minimum_required_version = MINIMUM_GCC_VERSION
            versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
            version = versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip().split('.')
        else:
            minimum_required_version = MINIMUM_MSVC_VERSION
            compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
            match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip())
            version = ['0', '0', '0'] if match is None else list(match.groups())
        # Parse inside the try: an unexpected, non-numeric version string is
        # reported as a warning like any other detection failure instead of
        # escaping as an uncaught ValueError.
        version_tuple = tuple(map(int, version))
    except Exception as error:
        warnings.warn(f'Error checking compiler version for {compiler}: {error}')
        return (False, TorchVersion('0.0.0'))

    version_str = '.'.join(version)
    if version_tuple >= minimum_required_version:
        return (True, TorchVersion(version_str))

    warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(f'{compiler} {version_str}'))
    return (False, TorchVersion(version_str))
392

393

394
def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> None:
    """
    Validate the local CUDA toolkit against PyTorch and the host compiler.

    * CUDA_HOME must be set, otherwise RuntimeError(CUDA_NOT_FOUND_MESSAGE).
    * ``nvcc --version`` is searched for "release X.Y". If that fails, or this
      PyTorch has no CUDA version, nothing else is checked.
    * A major-version mismatch with ``torch.version.cuda`` raises
      RuntimeError; a minor-version mismatch only warns.
    * Only on Linux binary builds without the TORCH_DONT_CHECK_COMPILER_ABI
      opt-out: ``compiler_version`` must satisfy ``min <= version < max``
      from CUDA_CLANG_VERSIONS (when ``compiler_name`` starts with "clang")
      or CUDA_GCC_VERSIONS, otherwise RuntimeError. A CUDA release missing
      from the table only warns.

    Raises:
        RuntimeError: on the conditions listed above.
        ValueError: if the Version type lacks ``major`` (setuptools < 49.4.0).
    """
    if not CUDA_HOME:
        raise RuntimeError(CUDA_NOT_FOUND_MESSAGE)

    nvcc_binary = os.path.join(CUDA_HOME, 'bin', 'nvcc')
    cuda_version_str = subprocess.check_output([nvcc_binary, '--version']).strip().decode(*SUBPROCESS_DECODE_ARGS)
    release_match = re.search(r'release (\d+[.]\d+)', cuda_version_str)
    if release_match is None:
        return

    cuda_str_version = release_match.group(1)
    cuda_ver = Version(cuda_str_version)
    if torch.version.cuda is None:
        return

    torch_cuda_version = Version(torch.version.cuda)
    if cuda_ver != torch_cuda_version:
        # major/minor attributes are only available in setuptools>=49.4.0
        if getattr(cuda_ver, "major", None) is None:
            raise ValueError("setuptools>=49.4.0 is required")
        if cuda_ver.major != torch_cuda_version.major:
            raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda))
        warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))

    abi_check_opted_out = os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']
    if not sys.platform.startswith('linux') or abi_check_opted_out or not _is_binary_build():
        return

    bounds_table: VersionMap = CUDA_CLANG_VERSIONS if compiler_name.startswith('clang') else CUDA_GCC_VERSIONS
    if cuda_str_version not in bounds_table:
        warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')
        return

    min_compiler_version, max_excl_compiler_version = bounds_table[cuda_str_version]
    # Special case for 11.4.0 (nvcc build V11.4.48), which has lower compiler
    # bounds than 11.4.1.
    if "V11.4.48" in cuda_version_str and bounds_table == CUDA_GCC_VERSIONS:
        max_excl_compiler_version = (11, 0)
    min_compiler_version_str = '.'.join(map(str, min_compiler_version))
    max_excl_compiler_version_str = '.'.join(map(str, max_excl_compiler_version))
    version_bound_str = f'>={min_compiler_version_str}, <{max_excl_compiler_version_str}'

    if compiler_version < TorchVersion(min_compiler_version_str):
        raise RuntimeError(
            f'The current installed version of {compiler_name} ({compiler_version}) is less '
            f'than the minimum required version by CUDA {cuda_str_version} ({min_compiler_version_str}). '
            f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).'
        )
    if compiler_version >= TorchVersion(max_excl_compiler_version_str):
        raise RuntimeError(
            f'The current installed version of {compiler_name} ({compiler_version}) is greater '
            f'than the maximum required version by CUDA {cuda_str_version}. '
            f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).'
        )
449

450

451
class BuildExtension(build_ext):
452
    """
453
    A custom :mod:`setuptools` build extension .
454

455
    This :class:`setuptools.build_ext` subclass takes care of passing the
456
    minimum required compiler flags (e.g. ``-std=c++17``) as well as mixed
457
    C++/CUDA compilation (and support for CUDA files in general).
458

459
    When using :class:`BuildExtension`, it is allowed to supply a dictionary
460
    for ``extra_compile_args`` (rather than the usual list) that maps from
461
    languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to
462
    supply to the compiler. This makes it possible to supply different flags to
463
    the C++ and CUDA compiler during mixed compilation.
464

465
    ``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we
466
    attempt to build using the Ninja backend. Ninja greatly speeds up
467
    compilation compared to the standard ``setuptools.build_ext``.
468
    Fallbacks to the standard distutils backend if Ninja is not available.
469

470
    .. note::
471
        By default, the Ninja backend uses #CPUS + 2 workers to build the
472
        extension. This may use up too many resources on some systems. One
473
        can control the number of workers by setting the `MAX_JOBS` environment
474
        variable to a non-negative number.
475
    """
476

477
    @classmethod
    def with_options(cls, **options):
        """Return a subclass of ``cls`` whose constructor also receives ``options``.

        The given ``options`` are merged into the keyword arguments of every
        construction; on a name clash the value given here wins.
        """
        class cls_with_options(cls):  # type: ignore[misc, valid-type]
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **{**kwargs, **options})

        return cls_with_options
486

487
    def __init__(self, *args, **kwargs) -> None:
        """Initialize the command; all arguments are forwarded to ``build_ext``.

        Recognized keyword arguments (read from ``kwargs``):
            no_python_abi_suffix (bool, default False)
            use_ninja (bool, default True): build with Ninja. When Ninja is
                not available, warn and fall back to the distutils backend.
        """
        super().__init__(*args, **kwargs)
        self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)

        self.use_ninja = kwargs.get('use_ninja', True)
        # Test if we can use ninja; fall back otherwise. The message used to be
        # assembled from a '{}.' template plus a reason ending in '.', which
        # printed "find ninja.. Falling back".
        if self.use_ninja and not is_ninja_available():
            warnings.warn('Attempted to use ninja as the BuildExtension backend but '
                          'we could not find ninja. Falling back to using the slow distutils backend.')
            self.use_ninja = False
499

500
    def finalize_options(self) -> None:
        """Run the base option finalization, then set ``force`` when using Ninja."""
        super().finalize_options()
        if not self.use_ninja:
            return
        # NOTE(review): presumably Ninja does its own up-to-date tracking, so
        # setuptools is told to always hand every extension over — confirm.
        self.force = True
504

505
    def build_extensions(self) -> None:
506
        compiler_name, compiler_version = self._check_abi()
507

508
        cuda_ext = False
509
        extension_iter = iter(self.extensions)
510
        extension = next(extension_iter, None)
511
        while not cuda_ext and extension:
512
            for source in extension.sources:
513
                _, ext = os.path.splitext(source)
514
                if ext == '.cu':
515
                    cuda_ext = True
516
                    break
517
            extension = next(extension_iter, None)
518

519
        if cuda_ext and not IS_HIP_EXTENSION:
520
            _check_cuda_version(compiler_name, compiler_version)
521

522
        for extension in self.extensions:
523
            # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
524
            # extra_compile_args is a dict. Otherwise, default torch flags do
525
            # not get passed. Necessary when only one of 'cxx' and 'nvcc' is
526
            # passed to extra_compile_args in CUDAExtension, i.e.
527
            #   CUDAExtension(..., extra_compile_args={'cxx': [...]})
528
            # or
529
            #   CUDAExtension(..., extra_compile_args={'nvcc': [...]})
530
            if isinstance(extension.extra_compile_args, dict):
531
                for ext in ['cxx', 'nvcc']:
532
                    if ext not in extension.extra_compile_args:
533
                        extension.extra_compile_args[ext] = []
534

535
            self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H')
536
            # See note [Pybind11 ABI constants]
537
            for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
538
                val = getattr(torch._C, f"_PYBIND11_{name}")
539
                if val is not None and not IS_WINDOWS:
540
                    self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"')
541
            self._define_torch_extension_name(extension)
542
            self._add_gnu_cpp_abi_flag(extension)
543

544
            if 'nvcc_dlink' in extension.extra_compile_args:
545
                assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."
546

547
        # Register .cu, .cuh, .hip, and .mm as valid source extensions.
548
        self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
549
        if torch.backends.mps.is_built():
550
            self.compiler.src_extensions += ['.mm']
551
        # Save the original _compile method for later.
552
        if self.compiler.compiler_type == 'msvc':
553
            self.compiler._cpp_extensions += ['.cu', '.cuh']
554
            original_compile = self.compiler.compile
555
            original_spawn = self.compiler.spawn
556
        else:
557
            original_compile = self.compiler._compile
558

559
        def append_std17_if_no_std_present(cflags) -> None:
560
            # NVCC does not allow multiple -std to be passed, so we avoid
561
            # overriding the option if the user explicitly passed it.
562
            cpp_format_prefix = '/{}:' if self.compiler.compiler_type == 'msvc' else '-{}='
563
            cpp_flag_prefix = cpp_format_prefix.format('std')
564
            cpp_flag = cpp_flag_prefix + 'c++17'
565
            if not any(flag.startswith(cpp_flag_prefix) for flag in cflags):
566
                cflags.append(cpp_flag)
567

568
        def unix_cuda_flags(cflags):
569
            cflags = (COMMON_NVCC_FLAGS +
570
                      ['--compiler-options', "'-fPIC'"] +
571
                      cflags + _get_cuda_arch_flags(cflags))
572

573
            # NVCC does not allow multiple -ccbin/--compiler-bindir to be passed, so we avoid
574
            # overriding the option if the user explicitly passed it.
575
            _ccbin = os.getenv("CC")
576
            if (
577
                _ccbin is not None
578
                and not any(flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags)
579
            ):
580
                cflags.extend(['-ccbin', _ccbin])
581

582
            return cflags
583

584
        def convert_to_absolute_paths_inplace(paths):
585
            # Helper function. See Note [Absolute include_dirs]
586
            if paths is not None:
587
                for i in range(len(paths)):
588
                    if not os.path.isabs(paths[i]):
589
                        paths[i] = os.path.abspath(paths[i])
590

591
        def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
592
            # Copy before we make any modifications.
593
            cflags = copy.deepcopy(extra_postargs)
594
            try:
595
                original_compiler = self.compiler.compiler_so
596
                if _is_cuda_file(src):
597
                    nvcc = [_join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc')]
598
                    self.compiler.set_executable('compiler_so', nvcc)
599
                    if isinstance(cflags, dict):
600
                        cflags = cflags['nvcc']
601
                    if IS_HIP_EXTENSION:
602
                        cflags = COMMON_HIPCC_FLAGS + cflags + _get_rocm_arch_flags(cflags)
603
                    else:
604
                        cflags = unix_cuda_flags(cflags)
605
                elif isinstance(cflags, dict):
606
                    cflags = cflags['cxx']
607
                if IS_HIP_EXTENSION:
608
                    cflags = COMMON_HIP_FLAGS + cflags
609
                append_std17_if_no_std_present(cflags)
610

611
                original_compile(obj, src, ext, cc_args, cflags, pp_opts)
612
            finally:
613
                # Put the original compiler back in place.
614
                self.compiler.set_executable('compiler_so', original_compiler)
615

616
        def unix_wrap_ninja_compile(sources,
617
                                    output_dir=None,
618
                                    macros=None,
619
                                    include_dirs=None,
620
                                    debug=0,
621
                                    extra_preargs=None,
622
                                    extra_postargs=None,
623
                                    depends=None):
624
            r"""Compiles sources by outputting a ninja file and running it."""
625
            # NB: I copied some lines from self.compiler (which is an instance
626
            # of distutils.UnixCCompiler). See the following link.
627
            # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567
628
            # This can be fragile, but a lot of other repos also do this
629
            # (see https://github.com/search?q=_setup_compile&type=Code)
630
            # so it is probably OK; we'll also get CI signal if/when
631
            # we update our python version (which is when distutils can be
632
            # upgraded)
633

634
            # Use absolute path for output_dir so that the object file paths
635
            # (`objects`) get generated with absolute paths.
636
            output_dir = os.path.abspath(output_dir)
637

638
            # See Note [Absolute include_dirs]
639
            convert_to_absolute_paths_inplace(self.compiler.include_dirs)
640

641
            _, objects, extra_postargs, pp_opts, _ = \
642
                self.compiler._setup_compile(output_dir, macros,
643
                                             include_dirs, sources,
644
                                             depends, extra_postargs)
645
            common_cflags = self.compiler._get_cc_args(pp_opts, debug, extra_preargs)
646
            extra_cc_cflags = self.compiler.compiler_so[1:]
647
            with_cuda = any(map(_is_cuda_file, sources))
648

649
            # extra_postargs can be either:
650
            # - a dict mapping cxx/nvcc to extra flags
651
            # - a list of extra flags.
652
            if isinstance(extra_postargs, dict):
653
                post_cflags = extra_postargs['cxx']
654
            else:
655
                post_cflags = list(extra_postargs)
656
            if IS_HIP_EXTENSION:
657
                post_cflags = COMMON_HIP_FLAGS + post_cflags
658
            append_std17_if_no_std_present(post_cflags)
659

660
            cuda_post_cflags = None
661
            cuda_cflags = None
662
            if with_cuda:
663
                cuda_cflags = common_cflags
664
                if isinstance(extra_postargs, dict):
665
                    cuda_post_cflags = extra_postargs['nvcc']
666
                else:
667
                    cuda_post_cflags = list(extra_postargs)
668
                if IS_HIP_EXTENSION:
669
                    cuda_post_cflags = cuda_post_cflags + _get_rocm_arch_flags(cuda_post_cflags)
670
                    cuda_post_cflags = COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS + cuda_post_cflags
671
                else:
672
                    cuda_post_cflags = unix_cuda_flags(cuda_post_cflags)
673
                append_std17_if_no_std_present(cuda_post_cflags)
674
                cuda_cflags = [shlex.quote(f) for f in cuda_cflags]
675
                cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags]
676

677
            if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
678
                cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink'])
679
            else:
680
                cuda_dlink_post_cflags = None
681
            _write_ninja_file_and_compile_objects(
682
                sources=sources,
683
                objects=objects,
684
                cflags=[shlex.quote(f) for f in extra_cc_cflags + common_cflags],
685
                post_cflags=[shlex.quote(f) for f in post_cflags],
686
                cuda_cflags=cuda_cflags,
687
                cuda_post_cflags=cuda_post_cflags,
688
                cuda_dlink_post_cflags=cuda_dlink_post_cflags,
689
                build_directory=output_dir,
690
                verbose=True,
691
                with_cuda=with_cuda)
692

693
            # Return *all* object filenames, not just the ones we just built.
694
            return objects
695

696
        def win_cuda_flags(cflags):
697
            return (COMMON_NVCC_FLAGS +
698
                    cflags + _get_cuda_arch_flags(cflags))
699

700
        def win_wrap_single_compile(sources,
701
                                    output_dir=None,
702
                                    macros=None,
703
                                    include_dirs=None,
704
                                    debug=0,
705
                                    extra_preargs=None,
706
                                    extra_postargs=None,
707
                                    depends=None):
708

709
            self.cflags = copy.deepcopy(extra_postargs)
710
            extra_postargs = None
711

712
            def spawn(cmd):
713
                # Using regex to match src, obj and include files
714
                src_regex = re.compile('/T(p|c)(.*)')
715
                src_list = [
716
                    m.group(2) for m in (src_regex.match(elem) for elem in cmd)
717
                    if m
718
                ]
719

720
                obj_regex = re.compile('/Fo(.*)')
721
                obj_list = [
722
                    m.group(1) for m in (obj_regex.match(elem) for elem in cmd)
723
                    if m
724
                ]
725

726
                include_regex = re.compile(r'((\-|\/)I.*)')
727
                include_list = [
728
                    m.group(1)
729
                    for m in (include_regex.match(elem) for elem in cmd) if m
730
                ]
731

732
                if len(src_list) >= 1 and len(obj_list) >= 1:
733
                    src = src_list[0]
734
                    obj = obj_list[0]
735
                    if _is_cuda_file(src):
736
                        nvcc = _join_cuda_home('bin', 'nvcc')
737
                        if isinstance(self.cflags, dict):
738
                            cflags = self.cflags['nvcc']
739
                        elif isinstance(self.cflags, list):
740
                            cflags = self.cflags
741
                        else:
742
                            cflags = []
743

744
                        cflags = win_cuda_flags(cflags) + ['-std=c++17', '--use-local-env']
745
                        for flag in COMMON_MSVC_FLAGS:
746
                            cflags = ['-Xcompiler', flag] + cflags
747
                        for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
748
                            cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags
749
                        cmd = [nvcc, '-c', src, '-o', obj] + include_list + cflags
750
                    elif isinstance(self.cflags, dict):
751
                        cflags = COMMON_MSVC_FLAGS + self.cflags['cxx']
752
                        append_std17_if_no_std_present(cflags)
753
                        cmd += cflags
754
                    elif isinstance(self.cflags, list):
755
                        cflags = COMMON_MSVC_FLAGS + self.cflags
756
                        append_std17_if_no_std_present(cflags)
757
                        cmd += cflags
758

759
                return original_spawn(cmd)
760

761
            try:
762
                self.compiler.spawn = spawn
763
                return original_compile(sources, output_dir, macros,
764
                                        include_dirs, debug, extra_preargs,
765
                                        extra_postargs, depends)
766
            finally:
767
                self.compiler.spawn = original_spawn
768

769
        def win_wrap_ninja_compile(sources,
770
                                   output_dir=None,
771
                                   macros=None,
772
                                   include_dirs=None,
773
                                   debug=0,
774
                                   extra_preargs=None,
775
                                   extra_postargs=None,
776
                                   depends=None):
777

778
            if not self.compiler.initialized:
779
                self.compiler.initialize()
780
            output_dir = os.path.abspath(output_dir)
781

782
            # Note [Absolute include_dirs]
783
            # Convert relative path in self.compiler.include_dirs to absolute path if any,
784
            # For ninja build, the build location is not local, the build happens
785
            # in a in script created build folder, relative path lost their correctness.
786
            # To be consistent with jit extension, we allow user to enter relative include_dirs
787
            # in setuptools.setup, and we convert the relative path to absolute path here
788
            convert_to_absolute_paths_inplace(self.compiler.include_dirs)
789

790
            _, objects, extra_postargs, pp_opts, _ = \
791
                self.compiler._setup_compile(output_dir, macros,
792
                                             include_dirs, sources,
793
                                             depends, extra_postargs)
794
            common_cflags = extra_preargs or []
795
            cflags = []
796
            if debug:
797
                cflags.extend(self.compiler.compile_options_debug)
798
            else:
799
                cflags.extend(self.compiler.compile_options)
800
            common_cflags.extend(COMMON_MSVC_FLAGS)
801
            cflags = cflags + common_cflags + pp_opts
802
            with_cuda = any(map(_is_cuda_file, sources))
803

804
            # extra_postargs can be either:
805
            # - a dict mapping cxx/nvcc to extra flags
806
            # - a list of extra flags.
807
            if isinstance(extra_postargs, dict):
808
                post_cflags = extra_postargs['cxx']
809
            else:
810
                post_cflags = list(extra_postargs)
811
            append_std17_if_no_std_present(post_cflags)
812

813
            cuda_post_cflags = None
814
            cuda_cflags = None
815
            if with_cuda:
816
                cuda_cflags = ['-std=c++17', '--use-local-env']
817
                for common_cflag in common_cflags:
818
                    cuda_cflags.append('-Xcompiler')
819
                    cuda_cflags.append(common_cflag)
820
                for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
821
                    cuda_cflags.append('-Xcudafe')
822
                    cuda_cflags.append('--diag_suppress=' + ignore_warning)
823
                cuda_cflags.extend(pp_opts)
824
                if isinstance(extra_postargs, dict):
825
                    cuda_post_cflags = extra_postargs['nvcc']
826
                else:
827
                    cuda_post_cflags = list(extra_postargs)
828
                cuda_post_cflags = win_cuda_flags(cuda_post_cflags)
829

830
            cflags = _nt_quote_args(cflags)
831
            post_cflags = _nt_quote_args(post_cflags)
832
            if with_cuda:
833
                cuda_cflags = _nt_quote_args(cuda_cflags)
834
                cuda_post_cflags = _nt_quote_args(cuda_post_cflags)
835
            if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
836
                cuda_dlink_post_cflags = win_cuda_flags(extra_postargs['nvcc_dlink'])
837
            else:
838
                cuda_dlink_post_cflags = None
839

840
            _write_ninja_file_and_compile_objects(
841
                sources=sources,
842
                objects=objects,
843
                cflags=cflags,
844
                post_cflags=post_cflags,
845
                cuda_cflags=cuda_cflags,
846
                cuda_post_cflags=cuda_post_cflags,
847
                cuda_dlink_post_cflags=cuda_dlink_post_cflags,
848
                build_directory=output_dir,
849
                verbose=True,
850
                with_cuda=with_cuda)
851

852
            # Return *all* object filenames, not just the ones we just built.
853
            return objects
854

855
        # Monkey-patch the _compile or compile method.
856
        # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511
857
        if self.compiler.compiler_type == 'msvc':
858
            if self.use_ninja:
859
                self.compiler.compile = win_wrap_ninja_compile
860
            else:
861
                self.compiler.compile = win_wrap_single_compile
862
        else:
863
            if self.use_ninja:
864
                self.compiler.compile = unix_wrap_ninja_compile
865
            else:
866
                self.compiler._compile = unix_wrap_single_compile
867

868
        build_ext.build_extensions(self)
869

870
    def get_ext_filename(self, ext_name):
871
        # Get the original shared library name. For Python 3, this name will be
872
        # suffixed with "<SOABI>.so", where <SOABI> will be something like
873
        # cpython-37m-x86_64-linux-gnu.
874
        ext_filename = super().get_ext_filename(ext_name)
875
        # If `no_python_abi_suffix` is `True`, we omit the Python 3 ABI
876
        # component. This makes building shared libraries with setuptools that
877
        # aren't Python modules nicer.
878
        if self.no_python_abi_suffix:
879
            # The parts will be e.g. ["my_extension", "cpython-37m-x86_64-linux-gnu", "so"].
880
            ext_filename_parts = ext_filename.split('.')
881
            # Omit the second to last element.
882
            without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
883
            ext_filename = '.'.join(without_abi)
884
        return ext_filename
885

886
    def _check_abi(self) -> Tuple[str, TorchVersion]:
887
        # On some platforms, like Windows, compiler_cxx is not available.
888
        if hasattr(self.compiler, 'compiler_cxx'):
889
            compiler = self.compiler.compiler_cxx[0]
890
        else:
891
            compiler = get_cxx_compiler()
892
        _, version = get_compiler_abi_compatibility_and_version(compiler)
893
        # Warn user if VC env is activated but `DISTUILS_USE_SDK` is not set.
894
        if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
895
            msg = ('It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set.'
896
                   'This may lead to multiple activations of the VC env.'
897
                   'Please set `DISTUTILS_USE_SDK=1` and try again.')
898
            raise UserWarning(msg)
899
        return compiler, version
900

901
    def _add_compile_flag(self, extension, flag):
902
        extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args)
903
        if isinstance(extension.extra_compile_args, dict):
904
            for args in extension.extra_compile_args.values():
905
                args.append(flag)
906
        else:
907
            extension.extra_compile_args.append(flag)
908

909
    def _define_torch_extension_name(self, extension):
910
        # pybind11 doesn't support dots in the names
911
        # so in order to support extensions in the packages
912
        # like torch._C, we take the last part of the string
913
        # as the library name
914
        names = extension.name.split('.')
915
        name = names[-1]
916
        define = f'-DTORCH_EXTENSION_NAME={name}'
917
        self._add_compile_flag(extension, define)
918

919
    def _add_gnu_cpp_abi_flag(self, extension):
920
        # use the same CXX ABI as what PyTorch was compiled with
921
        self._add_compile_flag(extension, '-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI)))
922

923

924
def CppExtension(name, sources, *args, **kwargs):
925
    """
926
    Create a :class:`setuptools.Extension` for C++.
927

928
    Convenience method that creates a :class:`setuptools.Extension` with the
929
    bare minimum (but often sufficient) arguments to build a C++ extension.
930

931
    All arguments are forwarded to the :class:`setuptools.Extension`
932
    constructor. Full list arguments can be found at
933
    https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference
934

935
    Example:
936
        >>> # xdoctest: +SKIP
937
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
938
        >>> from setuptools import setup
939
        >>> from torch.utils.cpp_extension import BuildExtension, CppExtension
940
        >>> setup(
941
        ...     name='extension',
942
        ...     ext_modules=[
943
        ...         CppExtension(
944
        ...             name='extension',
945
        ...             sources=['extension.cpp'],
946
        ...             extra_compile_args=['-g'],
947
        ...             extra_link_args=['-Wl,--no-as-needed', '-lm'])
948
        ...     ],
949
        ...     cmdclass={
950
        ...         'build_ext': BuildExtension
951
        ...     })
952
    """
953
    include_dirs = kwargs.get('include_dirs', [])
954
    include_dirs += include_paths()
955
    kwargs['include_dirs'] = include_dirs
956

957
    library_dirs = kwargs.get('library_dirs', [])
958
    library_dirs += library_paths()
959
    kwargs['library_dirs'] = library_dirs
960

961
    libraries = kwargs.get('libraries', [])
962
    libraries.append('c10')
963
    libraries.append('torch')
964
    libraries.append('torch_cpu')
965
    libraries.append('torch_python')
966
    if IS_WINDOWS:
967
        libraries.append("sleef")
968

969
    kwargs['libraries'] = libraries
970

971
    kwargs['language'] = 'c++'
972
    return setuptools.Extension(name, sources, *args, **kwargs)
973

974

975
def CUDAExtension(name, sources, *args, **kwargs):
976
    """
977
    Create a :class:`setuptools.Extension` for CUDA/C++.
978

979
    Convenience method that creates a :class:`setuptools.Extension` with the
980
    bare minimum (but often sufficient) arguments to build a CUDA/C++
981
    extension. This includes the CUDA include path, library path and runtime
982
    library.
983

984
    All arguments are forwarded to the :class:`setuptools.Extension`
985
    constructor. Full list arguments can be found at
986
    https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference
987

988
    Example:
989
        >>> # xdoctest: +SKIP
990
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
991
        >>> from setuptools import setup
992
        >>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
993
        >>> setup(
994
        ...     name='cuda_extension',
995
        ...     ext_modules=[
996
        ...         CUDAExtension(
997
        ...                 name='cuda_extension',
998
        ...                 sources=['extension.cpp', 'extension_kernel.cu'],
999
        ...                 extra_compile_args={'cxx': ['-g'],
1000
        ...                                     'nvcc': ['-O2']},
1001
        ...                 extra_link_args=['-Wl,--no-as-needed', '-lcuda'])
1002
        ...     ],
1003
        ...     cmdclass={
1004
        ...         'build_ext': BuildExtension
1005
        ...     })
1006

1007
    Compute capabilities:
1008

1009
    By default the extension will be compiled to run on all archs of the cards visible during the
1010
    building process of the extension, plus PTX. If down the road a new card is installed the
1011
    extension may need to be recompiled. If a visible card has a compute capability (CC) that's
1012
    newer than the newest version for which your nvcc can build fully-compiled binaries, Pytorch
1013
    will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
1014
    support (see below for details on PTX).
1015

1016
    You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which
1017
    CCs you want the extension to support:
1018

1019
    ``TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py``
1020
    ``TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py``
1021

1022
    The +PTX option causes extension kernel binaries to include PTX instructions for the specified
1023
    CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >=
1024
    the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with
1025
    CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to
1026
    provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on
1027
    those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better
1028
    off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6,
1029
    "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but
1030
    "8.0 8.6" would be better.
1031

1032
    Note that while it's possible to include all supported archs, the more archs get included the
1033
    slower the building process will be, as it will build a separate kernel image for each arch.
1034

1035
    Note that CUDA-11.5 nvcc will hit internal compiler error while parsing torch/extension.h on Windows.
1036
    To workaround the issue, move python binding logic to pure C++ file.
1037

1038
    Example use:
1039
        #include <ATen/ATen.h>
1040
        at::Tensor SigmoidAlphaBlendForwardCuda(....)
1041

1042
    Instead of:
1043
        #include <torch/extension.h>
1044
        torch::Tensor SigmoidAlphaBlendForwardCuda(...)
1045

1046
    Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460
1047
    Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48
1048

1049
    Relocatable device code linking:
1050

1051
    If you want to reference device symbols across compilation units (across object files),
1052
    the object files need to be built with `relocatable device code` (-rdc=true or -dc).
1053
    An exception to this rule is "dynamic parallelism" (nested kernel launches)  which is not used a lot anymore.
1054
    `Relocatable device code` is less optimized so it needs to be used only on object files that need it.
1055
    Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step
1056
    help reduce the protentional perf degradation of `-rdc`.
1057
    Note that it needs to be used at both steps to be useful.
1058

1059
    If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step.
1060
    There is also a case where `-dlink` is used without `-rdc`:
1061
    when an extension is linked against a static lib containing rdc-compiled objects
1062
    like the [NVSHMEM library](https://developer.nvidia.com/nvshmem).
1063

1064
    Note: Ninja is required to build a CUDA Extension with RDC linking.
1065

1066
    Example:
1067
        >>> # xdoctest: +SKIP
1068
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
1069
        >>> CUDAExtension(
1070
        ...        name='cuda_extension',
1071
        ...        sources=['extension.cpp', 'extension_kernel.cu'],
1072
        ...        dlink=True,
1073
        ...        dlink_libraries=["dlink_lib"],
1074
        ...        extra_compile_args={'cxx': ['-g'],
1075
        ...                            'nvcc': ['-O2', '-rdc=true']})
1076
    """
1077
    library_dirs = kwargs.get('library_dirs', [])
1078
    library_dirs += library_paths(cuda=True)
1079
    kwargs['library_dirs'] = library_dirs
1080

1081
    libraries = kwargs.get('libraries', [])
1082
    libraries.append('c10')
1083
    libraries.append('torch')
1084
    libraries.append('torch_cpu')
1085
    libraries.append('torch_python')
1086
    if IS_HIP_EXTENSION:
1087
        libraries.append('amdhip64')
1088
        libraries.append('c10_hip')
1089
        libraries.append('torch_hip')
1090
    else:
1091
        libraries.append('cudart')
1092
        libraries.append('c10_cuda')
1093
        libraries.append('torch_cuda')
1094
    kwargs['libraries'] = libraries
1095

1096
    include_dirs = kwargs.get('include_dirs', [])
1097

1098
    if IS_HIP_EXTENSION:
1099
        build_dir = os.getcwd()
1100
        hipify_result = hipify_python.hipify(
1101
            project_directory=build_dir,
1102
            output_directory=build_dir,
1103
            header_include_dirs=include_dirs,
1104
            includes=[os.path.join(build_dir, '*')],  # limit scope to build_dir only
1105
            extra_files=[os.path.abspath(s) for s in sources],
1106
            show_detailed=True,
1107
            is_pytorch_extension=True,
1108
            hipify_extra_files_only=True,  # don't hipify everything in includes path
1109
        )
1110

1111
        hipified_sources = set()
1112
        for source in sources:
1113
            s_abs = os.path.abspath(source)
1114
            hipified_s_abs = (hipify_result[s_abs].hipified_path if (s_abs in hipify_result and
1115
                              hipify_result[s_abs].hipified_path is not None) else s_abs)
1116
            # setup() arguments must *always* be /-separated paths relative to the setup.py directory,
1117
            # *never* absolute paths
1118
            hipified_sources.add(os.path.relpath(hipified_s_abs, build_dir))
1119

1120
        sources = list(hipified_sources)
1121

1122
    include_dirs += include_paths(cuda=True)
1123
    kwargs['include_dirs'] = include_dirs
1124

1125
    kwargs['language'] = 'c++'
1126

1127
    dlink_libraries = kwargs.get('dlink_libraries', [])
1128
    dlink = kwargs.get('dlink', False) or dlink_libraries
1129
    if dlink:
1130
        extra_compile_args = kwargs.get('extra_compile_args', {})
1131

1132
        extra_compile_args_dlink = extra_compile_args.get('nvcc_dlink', [])
1133
        extra_compile_args_dlink += ['-dlink']
1134
        extra_compile_args_dlink += [f'-L{x}' for x in library_dirs]
1135
        extra_compile_args_dlink += [f'-l{x}' for x in dlink_libraries]
1136

1137
        if (torch.version.cuda is not None) and TorchVersion(torch.version.cuda) >= '11.2':
1138
            extra_compile_args_dlink += ['-dlto']   # Device Link Time Optimization started from cuda 11.2
1139

1140
        extra_compile_args['nvcc_dlink'] = extra_compile_args_dlink
1141

1142
        kwargs['extra_compile_args'] = extra_compile_args
1143

1144
    return setuptools.Extension(name, sources, *args, **kwargs)
1145

1146

1147
def include_paths(cuda: bool = False) -> List[str]:
1148
    """
1149
    Get the include paths required to build a C++ or CUDA extension.
1150

1151
    Args:
1152
        cuda: If `True`, includes CUDA-specific include paths.
1153

1154
    Returns:
1155
        A list of include path strings.
1156
    """
1157
    lib_include = os.path.join(_TORCH_PATH, 'include')
1158
    paths = [
1159
        lib_include,
1160
        # Remove this once torch/torch.h is officially no longer supported for C++ extensions.
1161
        os.path.join(lib_include, 'torch', 'csrc', 'api', 'include'),
1162
        # Some internal (old) Torch headers don't properly prefix their includes,
1163
        # so we need to pass -Itorch/lib/include/TH as well.
1164
        os.path.join(lib_include, 'TH'),
1165
        os.path.join(lib_include, 'THC')
1166
    ]
1167
    if cuda and IS_HIP_EXTENSION:
1168
        paths.append(os.path.join(lib_include, 'THH'))
1169
        paths.append(_join_rocm_home('include'))
1170
    elif cuda:
1171
        cuda_home_include = _join_cuda_home('include')
1172
        # if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home.
1173
        # but gcc doesn't like having /usr/include passed explicitly
1174
        if cuda_home_include != '/usr/include':
1175
            paths.append(cuda_home_include)
1176

1177
        # Support CUDA_INC_PATH env variable supported by CMake files
1178
        if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
1179
                cuda_inc_path != '/usr/include':
1180
            paths.append(cuda_inc_path)
1181
        if CUDNN_HOME is not None:
1182
            paths.append(os.path.join(CUDNN_HOME, 'include'))
1183
    return paths
1184

1185

1186
def library_paths(cuda: bool = False) -> List[str]:
1187
    """
1188
    Get the library paths required to build a C++ or CUDA extension.
1189

1190
    Args:
1191
        cuda: If `True`, includes CUDA-specific library paths.
1192

1193
    Returns:
1194
        A list of library path strings.
1195
    """
1196
    # We need to link against libtorch.so
1197
    paths = [TORCH_LIB_PATH]
1198

1199
    if cuda and IS_HIP_EXTENSION:
1200
        lib_dir = 'lib'
1201
        paths.append(_join_rocm_home(lib_dir))
1202
        if HIP_HOME is not None:
1203
            paths.append(os.path.join(HIP_HOME, 'lib'))
1204
    elif cuda:
1205
        if IS_WINDOWS:
1206
            lib_dir = os.path.join('lib', 'x64')
1207
        else:
1208
            lib_dir = 'lib64'
1209
            if (not os.path.exists(_join_cuda_home(lib_dir)) and
1210
                    os.path.exists(_join_cuda_home('lib'))):
1211
                # 64-bit CUDA may be installed in 'lib' (see e.g. gh-16955)
1212
                # Note that it's also possible both don't exist (see
1213
                # _find_cuda_home) - in that case we stay with 'lib64'.
1214
                lib_dir = 'lib'
1215

1216
        paths.append(_join_cuda_home(lib_dir))
1217
        if CUDNN_HOME is not None:
1218
            paths.append(os.path.join(CUDNN_HOME, lib_dir))
1219
    return paths
1220

1221

1222
def load(name,
1223
         sources: Union[str, List[str]],
1224
         extra_cflags=None,
1225
         extra_cuda_cflags=None,
1226
         extra_ldflags=None,
1227
         extra_include_paths=None,
1228
         build_directory=None,
1229
         verbose=False,
1230
         with_cuda: Optional[bool] = None,
1231
         is_python_module=True,
1232
         is_standalone=False,
1233
         keep_intermediates=True):
1234
    """
1235
    Load a PyTorch C++ extension just-in-time (JIT).
1236

1237
    To load an extension, a Ninja build file is emitted, which is used to
1238
    compile the given sources into a dynamic library. This library is
1239
    subsequently loaded into the current Python process as a module and
1240
    returned from this function, ready for use.
1241

1242
    By default, the directory to which the build file is emitted and the
1243
    resulting library compiled to is ``<tmp>/torch_extensions/<name>``, where
1244
    ``<tmp>`` is the temporary folder on the current platform and ``<name>``
1245
    the name of the extension. This location can be overridden in two ways.
1246
    First, if the ``TORCH_EXTENSIONS_DIR`` environment variable is set, it
1247
    replaces ``<tmp>/torch_extensions`` and all extensions will be compiled
1248
    into subfolders of this directory. Second, if the ``build_directory``
1249
    argument to this function is supplied, it overrides the entire path, i.e.
1250
    the library will be compiled into that folder directly.
1251

1252
    To compile the sources, the default system compiler (``c++``) is used,
1253
    which can be overridden by setting the ``CXX`` environment variable. To pass
1254
    additional arguments to the compilation process, ``extra_cflags`` or
1255
    ``extra_ldflags`` can be provided. For example, to compile your extension
1256
    with optimizations, pass ``extra_cflags=['-O3']``. You can also use
1257
    ``extra_cflags`` to pass further include directories.
1258

1259
    CUDA support with mixed compilation is provided. Simply pass CUDA source
1260
    files (``.cu`` or ``.cuh``) along with other sources. Such files will be
1261
    detected and compiled with nvcc rather than the C++ compiler. This includes
1262
    passing the CUDA lib64 directory as a library directory, and linking
1263
    ``cudart``. You can pass additional flags to nvcc via
1264
    ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
1265
    heuristics for finding the CUDA install directory are used, which usually
1266
    work fine. If not, setting the ``CUDA_HOME`` environment variable is the
1267
    safest option.
1268

1269
    Args:
1270
        name: The name of the extension to build. This MUST be the same as the
1271
            name of the pybind11 module!
1272
        sources: A list of relative or absolute paths to C++ source files.
1273
        extra_cflags: optional list of compiler flags to forward to the build.
1274
        extra_cuda_cflags: optional list of compiler flags to forward to nvcc
1275
            when building CUDA sources.
1276
        extra_ldflags: optional list of linker flags to forward to the build.
1277
        extra_include_paths: optional list of include directories to forward
1278
            to the build.
1279
        build_directory: optional path to use as build workspace.
1280
        verbose: If ``True``, turns on verbose logging of load steps.
1281
        with_cuda: Determines whether CUDA headers and libraries are added to
1282
            the build. If set to ``None`` (default), this value is
1283
            automatically determined based on the existence of ``.cu`` or
1284
            ``.cuh`` in ``sources``. Set it to `True`` to force CUDA headers
1285
            and libraries to be included.
1286
        is_python_module: If ``True`` (default), imports the produced shared
1287
            library as a Python module. If ``False``, behavior depends on
1288
            ``is_standalone``.
1289
        is_standalone: If ``False`` (default) loads the constructed extension
1290
            into the process as a plain dynamic library. If ``True``, build a
1291
            standalone executable.
1292

1293
    Returns:
1294
        If ``is_python_module`` is ``True``:
1295
            Returns the loaded PyTorch extension as a Python module.
1296

1297
        If ``is_python_module`` is ``False`` and ``is_standalone`` is ``False``:
1298
            Returns nothing. (The shared library is loaded into the process as
1299
            a side effect.)
1300

1301
        If ``is_standalone`` is ``True``.
1302
            Return the path to the executable. (On Windows, TORCH_LIB_PATH is
1303
            added to the PATH environment variable as a side effect.)
1304

1305
    Example:
1306
        >>> # xdoctest: +SKIP
1307
        >>> from torch.utils.cpp_extension import load
1308
        >>> module = load(
1309
        ...     name='extension',
1310
        ...     sources=['extension.cpp', 'extension_kernel.cu'],
1311
        ...     extra_cflags=['-O2'],
1312
        ...     verbose=True)
1313
    """
1314
    return _jit_compile(
1315
        name,
1316
        [sources] if isinstance(sources, str) else sources,
1317
        extra_cflags,
1318
        extra_cuda_cflags,
1319
        extra_ldflags,
1320
        extra_include_paths,
1321
        build_directory or _get_build_directory(name, verbose),
1322
        verbose,
1323
        with_cuda,
1324
        is_python_module,
1325
        is_standalone,
1326
        keep_intermediates=keep_intermediates)
1327

1328
def _get_pybind11_abi_build_flags():
    """Return ``-DPYBIND11_*`` defines carrying PyTorch's own pybind11 ABI properties.

    The values are read from ``torch._C._PYBIND11_COMPILER_TYPE``,
    ``torch._C._PYBIND11_STDLIB`` and ``torch._C._PYBIND11_BUILD_ABI``. A property
    whose value is ``None`` is skipped, and on Windows no flags are produced.

    Returns:
        A list of ``-DPYBIND11_<NAME>=\\"<value>\\"`` strings (possibly empty).
    """
    # Note [Pybind11 ABI constants]
    #
    # Before 2.4, pybind11 built its ABI string using the following pattern:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__"
    # Since 2.4, the compiler type, stdlib and build ABI parameters are also encoded, like this:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__"
    #
    # This was done to further narrow down the chances of a compiler ABI incompatibility
    # that can cause hard-to-debug segfaults.
    # For PyTorch extensions we want to relax those restrictions and pass the compiler, stdlib
    # and ABI properties captured during PyTorch native library compilation in torch/csrc/Module.cpp
    abi_properties = (
        (pname, getattr(torch._C, f"_PYBIND11_{pname}"))
        for pname in ("COMPILER_TYPE", "STDLIB", "BUILD_ABI")
    )
    # The quotes are backslash-escaped because these flags end up on
    # shell-interpreted compile command lines.
    return [
        f'-DPYBIND11_{pname}=\\"{pval}\\"'
        for pname, pval in abi_properties
        if pval is not None and not IS_WINDOWS
    ]
1347

1348
def _get_glibcxx_abi_build_flags():
1349
    glibcxx_abi_cflags = ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
1350
    return glibcxx_abi_cflags
1351

1352
def check_compiler_is_gcc(compiler):
    """Return ``True`` if ``compiler`` is a GCC driver on Linux, ``False`` otherwise.

    The compiler is invoked with ``-v`` (falling back to ``--version``) under the
    ``C`` locale. It is considered GCC when the output contains exactly one
    ``COLLECT_GCC=<driver>`` line plus a ``gcc version`` banner, and the resolved
    driver name is ``c++`` or contains ``gcc`` / ``g++`` / ``gnu-c++`` / ``gnu-cc``
    (e.g. ``g++``, ``x86_64-linux-gnu-g++-11``).

    Failing to run the compiler at all is treated as "not GCC" (best effort).

    Args:
        compiler: compiler executable name or path, passed as a single argv entry.
    """
    if not IS_LINUX:
        return False

    env = os.environ.copy()
    env['LC_ALL'] = 'C'  # Don't localize output
    version_string = None
    for version_flag in ('-v', '--version'):
        try:
            version_string = subprocess.check_output(
                [compiler, version_flag], stderr=subprocess.STDOUT, env=env
            ).decode(*SUBPROCESS_DECODE_ARGS)
            break
        except Exception:
            # Deliberate best effort: a missing or failing compiler means "not GCC".
            continue
    if version_string is None:
        return False

    # COLLECT_GCC names the real GCC driver, which also looks through wrappers such as sccache.
    results = re.findall(r"^COLLECT_GCC=(.*)$", version_string, re.MULTILINE)
    if len(results) != 1 or 'gcc version' not in version_string:
        return False
    driver = os.path.basename(os.path.realpath(results[0].strip()))
    # On RHEL/CentOS c++ is a gcc compiler wrapper; elsewhere the driver is usually
    # gcc / g++, possibly with a target prefix or a version suffix.
    if driver == 'c++':
        return True
    return any(name in driver for name in ('gcc', 'g++', 'gnu-c++', 'gnu-cc'))
1375

1376
def _check_and_build_extension_h_precompiler_headers(
        extra_cflags,
        extra_include_paths,
        is_standalone=False):
    r'''
    Build ``torch/extension.h`` into a precompiled header (PCH) if it is missing or stale.

    Precompiled Headers (PCH) pre-build frequently included headers once and reduce the
    build time of PyTorch ``load_inline`` modules.
    GCC official manual: https://gcc.gnu.org/onlinedocs/gcc/Precompiled-Headers.html
    A PCH file (``extension.h.gch``) is only usable when it and the build target were
    compiled with the same build parameters, so the exact command used to build it is
    recorded in a signature file (``extension.h.sign``). When the command (signature)
    changes, the PCH file is rebuilt.

    Args:
        extra_cflags: optional list of extra compiler flags. They are inserted verbatim
            into a shell command line.
        extra_include_paths: optional list of extra include directories, added as ``-I``.
        is_standalone: if ``False`` (default), ``-DTORCH_API_INCLUDE_EXTENSION_H`` is
            defined while building the PCH.

    Raises:
        RuntimeError: if the signature directory cannot be created, or if compiling the
            precompiled header fails (the compiler output is included).

    Note:
    1. Windows and MacOS have different PCH mechanisms. Only Linux is supported currently.
    2. It only works with GCC/G++.
    '''
    if not IS_LINUX:
        return

    compiler = get_cxx_compiler()
    if not check_compiler_is_gcc(compiler):
        return

    torch_header_dir = os.path.join(_TORCH_PATH, 'include', 'torch')
    head_file = os.path.join(torch_header_dir, 'extension.h')
    head_file_pch = os.path.join(torch_header_dir, 'extension.h.gch')
    head_file_signature = os.path.join(torch_header_dir, 'extension.h.sign')

    def _join_flags(flags):
        # Space-join a list of flags; ``None`` yields an empty string.
        return "" if flags is None else " ".join(flags)

    def format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags, torch_include_dirs, extra_cflags, extra_include_paths):
        # Collapse runs of spaces/newlines so the command, and thus its signature, is canonical.
        return re.sub(
            r"[ \n]+",
            " ",
            f"{compiler} -x c++-header {head_file} -o {head_file_pch} "
            f"{torch_include_dirs} {extra_include_paths} {extra_cflags} {common_cflags}",
        ).strip()

    def command_to_signature(cmd):
        return cmd.replace(' ', '_')

    def check_pch_signature_in_file(file_path, signature):
        # True only if the signature file exists and holds exactly ``signature``.
        if not os.path.isfile(file_path):
            return False
        with open(file_path) as file:
            return file.read() == signature

    def _create_if_not_exist(path_dir):
        if not os.path.exists(path_dir):
            try:
                Path(path_dir).mkdir(parents=True, exist_ok=True)
            except OSError as exc:  # Guard against race condition
                if exc.errno != errno.EEXIST:
                    raise RuntimeError(f"Fail to create path {path_dir}") from exc

    def write_pch_signature_to_file(file_path, pch_sign):
        _create_if_not_exist(os.path.dirname(file_path))
        with open(file_path, "w") as f:
            f.write(pch_sign)

    def build_precompile_header(pch_cmd):
        try:
            # shell=True on purpose: flags such as -DPYBIND11_...=\"...\" rely on the
            # shell to unescape the quotes.
            subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            output = (e.output or b"").decode(errors="replace")
            raise RuntimeError(f"Compile PreCompile Header fail, command: {pch_cmd}\n{output}") from e

    extra_cflags_str = _join_flags(extra_cflags)
    # Paths are shell-quoted so directories containing spaces or shell metacharacters
    # survive the shell=True invocation; ordinary paths are left unchanged by quoting.
    extra_include_paths_str = " ".join(
        f"-I{shlex.quote(include)}" for include in (extra_include_paths or [])
    )

    lib_include = os.path.join(_TORCH_PATH, 'include')
    torch_include_dirs = [
        lib_include,
        # Python.h
        sysconfig.get_path("include"),
        # torch/all.h
        os.path.join(lib_include, 'torch', 'csrc', 'api', 'include'),
    ]
    torch_include_dirs_str = " ".join(f"-I {shlex.quote(include)}" for include in torch_include_dirs)

    common_cflags = [] if is_standalone else ['-DTORCH_API_INCLUDE_EXTENSION_H']
    common_cflags += ['-std=c++17', '-fPIC']
    common_cflags += [str(x) for x in _get_pybind11_abi_build_flags()]
    common_cflags += [str(x) for x in _get_glibcxx_abi_build_flags()]
    common_cflags_str = _join_flags(common_cflags)

    # The compiler is left unquoted, as before: it comes from get_cxx_compiler().
    pch_cmd = format_precompiler_header_cmd(
        compiler,
        shlex.quote(head_file),
        shlex.quote(head_file_pch),
        common_cflags_str,
        torch_include_dirs_str,
        extra_cflags_str,
        extra_include_paths_str)
    pch_sign = command_to_signature(pch_cmd)

    # Rebuild when the PCH is missing or was produced by a different command.
    if not os.path.isfile(head_file_pch) or not check_pch_signature_in_file(head_file_signature, pch_sign):
        build_precompile_header(pch_cmd)
        write_pch_signature_to_file(head_file_signature, pch_sign)
1496

1497
def remove_extension_h_precompiler_headers():
    """Delete the ``torch/extension.h`` precompiled header and its signature file, if present.

    Missing files are ignored. ``FileNotFoundError`` is caught instead of checking
    ``os.path.exists`` first, so concurrent callers cannot race between the check
    and the removal.
    """
    torch_header_dir = os.path.join(_TORCH_PATH, 'include', 'torch')
    for file_name in ('extension.h.gch', 'extension.h.sign'):
        try:
            os.remove(os.path.join(torch_header_dir, file_name))
        except FileNotFoundError:
            pass
1507

1508
def load_inline(name,
                cpp_sources,
                cuda_sources=None,
                functions=None,
                extra_cflags=None,
                extra_cuda_cflags=None,
                extra_ldflags=None,
                extra_include_paths=None,
                build_directory=None,
                verbose=False,
                with_cuda=None,
                is_python_module=True,
                with_pytorch_error_handling=True,
                keep_intermediates=True,
                use_pch=False):
    r'''
    Load a PyTorch C++ extension just-in-time (JIT) from string sources.

    This function behaves exactly like :func:`load`, but takes its sources as
    strings rather than filenames. These strings are stored to files in the
    build directory, after which the behavior of :func:`load_inline` is
    identical to :func:`load`.

    See `the
    tests <https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions_jit.py>`_
    for good examples of using this function.

    Sources may omit two required parts of a typical non-inline C++ extension:
    the necessary header includes, as well as the (pybind11) binding code. More
    precisely, strings passed to ``cpp_sources`` are first concatenated into a
    single ``.cpp`` file. This file is then prepended with ``#include
    <torch/extension.h>``.

    Furthermore, if the ``functions`` argument is supplied, bindings will be
    automatically generated for each function specified. ``functions`` can
    either be a list of function names, or a dictionary mapping from function
    names to docstrings. If a list is given, the name of each function is used
    as its docstring.

    The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
    file and prepended with ``torch/types.h``, ``cuda.h`` and
    ``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled
    separately, but ultimately linked into a single library. Note that no
    bindings are generated for functions in ``cuda_sources`` per se. To bind
    to a CUDA kernel, you must create a C++ function that calls it, and either
    declare or define this C++ function in one of the ``cpp_sources`` (and
    include its name in ``functions``).

    The lists passed as ``cpp_sources`` / ``cuda_sources`` are not modified.

    See :func:`load` for a description of arguments omitted below.

    Args:
        cpp_sources: A string, or list of strings, containing C++ source code.
        cuda_sources: A string, or list of strings, containing CUDA source code.
        functions: A function name, a list of function names, or a dictionary
            mapping function names to docstrings, for which to generate
            function bindings. Without a dictionary, each function's docstring
            is just its name.
        with_cuda: Determines whether CUDA headers and libraries are added to
            the build. If set to ``None`` (default), this value is
            automatically determined based on whether ``cuda_sources`` is
            provided. Set it to ``True`` to force CUDA headers
            and libraries to be included.
        with_pytorch_error_handling: Determines whether pytorch error and
            warning macros are handled by pytorch instead of pybind. To do
            this, each function ``foo`` is called via an intermediary ``_safe_foo``
            function. This redirection might cause issues in obscure cases
            of cpp. This flag should be set to ``False`` when this redirect
            causes issues.
        use_pch: If ``True``, build (or reuse) a precompiled header for
            ``torch/extension.h`` before compiling; this only takes effect on
            Linux with GCC. If ``False`` (default), any previously built
            precompiled header is removed.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
        >>> from torch.utils.cpp_extension import load_inline
        >>> source = """
        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
          return x.sin() + y.sin();
        }
        """
        >>> module = load_inline(name='inline_extension',
        ...                      cpp_sources=[source],
        ...                      functions=['sin_add'])

    .. note::
        By default, the Ninja backend uses #CPUS + 2 workers to build the
        extension. This may use up too many resources on some systems. One
        can control the number of workers by setting the `MAX_JOBS` environment
        variable to a non-negative number.
    '''
    build_directory = build_directory or _get_build_directory(name, verbose)

    # Work on copies: the include and binding lines added below must not leak
    # into the caller's lists, where repeated calls would duplicate them.
    cpp_sources = [cpp_sources] if isinstance(cpp_sources, str) else list(cpp_sources)
    if not cuda_sources:
        cuda_sources = []
    elif isinstance(cuda_sources, str):
        cuda_sources = [cuda_sources]
    else:
        cuda_sources = list(cuda_sources)

    cpp_sources.insert(0, '#include <torch/extension.h>')

    if use_pch is True:
        # Using PreCompile Header('torch/extension.h') to reduce compile time.
        _check_and_build_extension_h_precompiler_headers(extra_cflags, extra_include_paths)
    else:
        remove_extension_h_precompiler_headers()

    # If `functions` is supplied, we create the pybind11 bindings for the user.
    # Here, `functions` is (or becomes, after some processing) a map from
    # function names to function docstrings.
    if functions is not None:
        module_def = []
        module_def.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
        if isinstance(functions, str):
            functions = [functions]
        if isinstance(functions, list):
            # Make the function docstring the same as the function name.
            functions = {f: f for f in functions}
        elif not isinstance(functions, dict):
            raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}")
        for function_name, docstring in functions.items():
            if with_pytorch_error_handling:
                module_def.append(f'm.def("{function_name}", torch::wrap_pybind_function({function_name}), "{docstring}");')
            else:
                module_def.append(f'm.def("{function_name}", {function_name}, "{docstring}");')
        module_def.append('}')
        cpp_sources += module_def

    cpp_source_path = os.path.join(build_directory, 'main.cpp')
    _maybe_write(cpp_source_path, "\n".join(cpp_sources))

    sources = [cpp_source_path]

    if cuda_sources:
        cuda_sources = [
            '#include <torch/types.h>',
            '#include <cuda.h>',
            '#include <cuda_runtime.h>',
        ] + cuda_sources

        cuda_source_path = os.path.join(build_directory, 'cuda.cu')
        _maybe_write(cuda_source_path, "\n".join(cuda_sources))

        sources.append(cuda_source_path)

    return _jit_compile(
        name,
        sources,
        extra_cflags,
        extra_cuda_cflags,
        extra_ldflags,
        extra_include_paths,
        build_directory,
        verbose,
        with_cuda,
        is_python_module,
        is_standalone=False,
        keep_intermediates=keep_intermediates)
1659

1660

1661
def _jit_compile(name,
                 sources,
                 extra_cflags,
                 extra_cuda_cflags,
                 extra_ldflags,
                 extra_include_paths,
                 build_directory: str,
                 verbose: bool,
                 with_cuda: Optional[bool],
                 is_python_module,
                 is_standalone,
                 keep_intermediates=True):
    """Build (if needed) and load a JIT-compiled extension.

    ``JIT_EXTENSION_VERSIONER`` tracks the sources and build arguments. When its
    version is > 0, the extension is built and loaded as ``{name}_v{version}``.
    A ``FileBaton`` lock file in ``build_directory`` guards the build: the process
    that acquires it builds (only if the version changed), and the others wait.

    Returns:
        The path from ``_get_exec_path`` if ``is_standalone``; otherwise whatever
        ``_import_module_from_library`` returns (the Python module when
        ``is_python_module`` is ``True``).

    Raises:
        ValueError: if both ``is_python_module`` and ``is_standalone`` are set.
    """
    if is_python_module and is_standalone:
        raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.")

    if with_cuda is None:
        with_cuda = any(map(_is_cuda_file, sources))
    with_cudnn = any('cudnn' in f for f in extra_ldflags or [])
    old_version = JIT_EXTENSION_VERSIONER.get_version(name)
    version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
        name,
        sources,
        build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
        build_directory=build_directory,
        with_cuda=with_cuda,
        is_python_module=is_python_module,
        is_standalone=is_standalone,
    )
    if version > 0:
        if version != old_version and verbose:
            print(f'The input conditions for extension module {name} have changed. ' +
                  f'Bumping to version {version} and re-building as {name}_v{version}...',
                  file=sys.stderr)
        name = f'{name}_v{version}'

    baton = FileBaton(os.path.join(build_directory, 'lock'))
    if baton.try_acquire():
        try:
            if version != old_version:
                with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
                    if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
                        hipify_result = hipify_python.hipify(
                            project_directory=build_directory,
                            output_directory=build_directory,
                            header_include_dirs=(extra_include_paths if extra_include_paths is not None else []),
                            extra_files=[os.path.abspath(s) for s in sources],
                            ignores=[_join_rocm_home('*'), os.path.join(_TORCH_PATH, '*')],  # no need to hipify ROCm or PyTorch headers
                            show_detailed=verbose,
                            show_progress=verbose,
                            is_pytorch_extension=True,
                            clean_ctx=clean_ctx
                        )

                        # Replace each source by its hipified counterpart (if any).
                        # dict.fromkeys de-duplicates like a set but keeps the original
                        # source order, so the generated build file does not depend on
                        # string-hash randomization.
                        sources = list(dict.fromkeys(
                            hipify_result[s_abs].hipified_path if s_abs in hipify_result else s_abs
                            for s_abs in map(os.path.abspath, sources)
                        ))

                    _write_ninja_file_and_build_library(
                        name=name,
                        sources=sources,
                        extra_cflags=extra_cflags or [],
                        extra_cuda_cflags=extra_cuda_cflags or [],
                        extra_ldflags=extra_ldflags or [],
                        extra_include_paths=extra_include_paths or [],
                        build_directory=build_directory,
                        verbose=verbose,
                        with_cuda=with_cuda,
                        is_standalone=is_standalone)
            elif verbose:
                print('No modifications detected for re-loaded extension '
                      f'module {name}, skipping build step...', file=sys.stderr)
        finally:
            baton.release()
    else:
        baton.wait()

    if verbose:
        print(f'Loading extension module {name}...', file=sys.stderr)

    if is_standalone:
        return _get_exec_path(name, build_directory)

    return _import_module_from_library(name, build_directory, is_python_module)
1747

1748

1749
def _write_ninja_file_and_compile_objects(
        sources: List[str],
        objects,
        cflags,
        post_cflags,
        cuda_cflags,
        cuda_post_cflags,
        cuda_dlink_post_cflags,
        build_directory: str,
        verbose: bool,
        with_cuda: Optional[bool]) -> None:
    """Write ``build.ninja`` into ``build_directory`` and run it to compile ``sources`` into ``objects``.

    Only object files are produced: the ninja file is written with
    ``ldflags=None`` and ``library_target=None``. When ``with_cuda`` is ``None``
    it is inferred by applying ``_is_cuda_file`` to every source.
    """
    # Fail fast if ninja is missing, then check the C++ compiler (result unused here).
    verify_ninja_availability()
    get_compiler_abi_compatibility_and_version(get_cxx_compiler())

    if with_cuda is None:
        with_cuda = any(_is_cuda_file(source) for source in sources)

    ninja_file = os.path.join(build_directory, 'build.ninja')
    if verbose:
        print(f'Emitting ninja build file {ninja_file}...', file=sys.stderr)
    _write_ninja_file(
        path=ninja_file,
        cflags=cflags,
        post_cflags=post_cflags,
        cuda_cflags=cuda_cflags,
        cuda_post_cflags=cuda_post_cflags,
        cuda_dlink_post_cflags=cuda_dlink_post_cflags,
        sources=sources,
        objects=objects,
        ldflags=None,
        library_target=None,
        with_cuda=with_cuda)

    if verbose:
        print('Compiling objects...', file=sys.stderr)
    # It would be better if we could tell users the name of the extension
    # that failed to build but there isn't a good way to get it here.
    _run_ninja_build(
        build_directory,
        verbose,
        error_prefix='Error compiling objects for extension')
1790

1791

1792
def _write_ninja_file_and_build_library(
        name,
        sources: List[str],
        extra_cflags,
        extra_cuda_cflags,
        extra_ldflags,
        extra_include_paths,
        build_directory: str,
        verbose: bool,
        with_cuda: Optional[bool],
        is_standalone: bool = False) -> None:
    """Write ``build.ninja`` for extension ``name`` into ``build_directory`` and build it.

    The libtorch (and, with CUDA, the GPU runtime) linker flags are added through
    ``_prepare_ldflags`` on a *copy* of ``extra_ldflags``, so the caller's list is
    left untouched. When ``with_cuda`` is ``None`` it is inferred by applying
    ``_is_cuda_file`` to every source.
    """
    verify_ninja_availability()

    compiler = get_cxx_compiler()

    get_compiler_abi_compatibility_and_version(compiler)
    if with_cuda is None:
        with_cuda = any(map(_is_cuda_file, sources))
    # _prepare_ldflags appends to the list it receives. Passing the caller's list
    # (ultimately the user's own ``extra_ldflags``) would mutate it, which changes
    # the build arguments the JIT versioner sees on the next load (spurious
    # rebuild) and duplicates the torch libraries.
    extra_ldflags = _prepare_ldflags(
        list(extra_ldflags or []),
        with_cuda,
        verbose,
        is_standalone)
    build_file_path = os.path.join(build_directory, 'build.ninja')
    if verbose:
        print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr)
    # NOTE: Emitting a new ninja build file does not cause re-compilation if
    # the sources did not change, so it's ok to re-emit (and it's fast).
    _write_ninja_file_to_build_library(
        path=build_file_path,
        name=name,
        sources=sources,
        extra_cflags=extra_cflags or [],
        extra_cuda_cflags=extra_cuda_cflags or [],
        extra_ldflags=extra_ldflags,
        extra_include_paths=extra_include_paths or [],
        with_cuda=with_cuda,
        is_standalone=is_standalone)

    if verbose:
        print(f'Building extension module {name}...', file=sys.stderr)
    _run_ninja_build(
        build_directory,
        verbose,
        error_prefix=f"Error building extension '{name}'")
1837

1838

1839
def is_ninja_available():
    """Report whether the `ninja <https://ninja-build.org/>`_ build system can be run.

    Runs ``ninja --version``; returns ``True`` if it succeeds and ``False`` if it
    raises for any reason (not installed, not executable, non-zero exit).
    """
    try:
        subprocess.check_output(['ninja', '--version'])
    except Exception:
        return False
    return True
1847

1848

1849
def verify_ninja_availability():
    """Ensure the `ninja <https://ninja-build.org/>`_ build system is usable.

    Raises:
        RuntimeError: if ``is_ninja_available()`` returns ``False``. Otherwise returns ``None``.
    """
    if is_ninja_available():
        return
    raise RuntimeError("Ninja is required to load C++ extensions")
1853

1854

1855
def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
    """Append the linker flags needed to link against libtorch to ``extra_ldflags``.

    ``extra_ldflags`` is extended in place and also returned. The order of the
    appended flags is significant and is kept stable.

    Args:
        extra_ldflags: list of linker flags; extended in place.
        with_cuda: also link the GPU variants of c10/torch (``*_hip`` for HIP
            extensions, ``*_cuda`` otherwise) plus ``cudart`` / ``amdhip64``.
        verbose: print a note to stderr when GPU linker flags are added.
        is_standalone: building an executable; skips ``torch_python`` and, on
            non-Windows platforms, adds an rpath pointing at ``TORCH_LIB_PATH``.
    """
    if IS_WINDOWS:
        extra_ldflags.append('c10.lib')
        if with_cuda:
            extra_ldflags.append('c10_cuda.lib')
        extra_ldflags.append('torch_cpu.lib')
        if with_cuda:
            # /INCLUDE is used to ensure torch_cuda is linked against in a project that relies on it.
            # Related issue: https://github.com/pytorch/pytorch/issues/31611
            extra_ldflags += ['torch_cuda.lib', '-INCLUDE:?warp_size@cuda@at@@YAHXZ']
        extra_ldflags += ['torch.lib', f'/LIBPATH:{TORCH_LIB_PATH}']
        if not is_standalone:
            python_lib_path = os.path.join(sys.base_exec_prefix, 'libs')
            extra_ldflags += ['torch_python.lib', f'/LIBPATH:{python_lib_path}']
    else:
        extra_ldflags += [f'-L{TORCH_LIB_PATH}', '-lc10']
        if with_cuda:
            extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda')
        extra_ldflags.append('-ltorch_cpu')
        if with_cuda:
            extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda')
        extra_ldflags.append('-ltorch')
        if is_standalone:
            extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}")
        else:
            extra_ldflags.append('-ltorch_python')

    if not with_cuda:
        return extra_ldflags

    if verbose:
        print('Detected CUDA files, patching ldflags', file=sys.stderr)
    if IS_WINDOWS:
        extra_ldflags += [f'/LIBPATH:{_join_cuda_home("lib", "x64")}', 'cudart.lib']
        if CUDNN_HOME is not None:
            extra_ldflags.append(f'/LIBPATH:{os.path.join(CUDNN_HOME, "lib", "x64")}')
    elif IS_HIP_EXTENSION:
        extra_ldflags += [f'-L{_join_rocm_home("lib")}', '-lamdhip64']
    else:
        # 64-bit CUDA may be installed in "lib" instead of "lib64". It is also possible
        # that neither exists (see _find_cuda_home); in that case we stay with "lib64".
        cuda_lib_dir = "lib64"
        if not os.path.exists(_join_cuda_home(cuda_lib_dir)) and os.path.exists(_join_cuda_home("lib")):
            cuda_lib_dir = "lib"
        extra_ldflags += [f'-L{_join_cuda_home(cuda_lib_dir)}', '-lcudart']
        if CUDNN_HOME is not None:
            extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}')
    return extra_ldflags
1912

1913

1914
def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
1915
    """
1916
    Determine CUDA arch flags to use.
1917

1918
    For an arch, say "6.1", the added compile flag will be
1919
    ``-gencode=arch=compute_61,code=sm_61``.
1920
    For an added "+PTX", an additional
1921
    ``-gencode=arch=compute_xx,code=compute_xx`` is added.
1922

1923
    See select_compute_arch.cmake for corresponding named and supported arches
1924
    when building with CMake.
1925
    """
1926
    # If cflags is given, there may already be user-provided arch flags in it
1927
    # (from `extra_compile_args`)
1928
    if cflags is not None:
1929
        for flag in cflags:
1930
            if 'TORCH_EXTENSION_NAME' in flag:
1931
                continue
1932
            if 'arch' in flag:
1933
                return []
1934

1935
    # Note: keep combined names ("arch1+arch2") above single names, otherwise
1936
    # string replacement may not do the right thing
1937
    named_arches = collections.OrderedDict([
1938
        ('Kepler+Tesla', '3.7'),
1939
        ('Kepler', '3.5+PTX'),
1940
        ('Maxwell+Tegra', '5.3'),
1941
        ('Maxwell', '5.0;5.2+PTX'),
1942
        ('Pascal', '6.0;6.1+PTX'),
1943
        ('Volta+Tegra', '7.2'),
1944
        ('Volta', '7.0+PTX'),
1945
        ('Turing', '7.5+PTX'),
1946
        ('Ampere+Tegra', '8.7'),
1947
        ('Ampere', '8.0;8.6+PTX'),
1948
        ('Ada', '8.9+PTX'),
1949
        ('Hopper', '9.0+PTX'),
1950
    ])
1951

1952
    supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
1953
                        '7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a']
1954
    valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
1955

1956
    # The default is sm_30 for CUDA 9.x and 10.x
1957
    # First check for an env var (same as used by the main setup.py)
1958
    # Can be one or more architectures, e.g. "6.1" or "3.5;5.2;6.0;6.1;7.0+PTX"
1959
    # See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
1960
    _arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
1961

1962
    # If not given, determine what's best for the GPU / CUDA version that can be found
1963
    if not _arch_list:
1964
        warnings.warn(
1965
            "TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n"
1966
            "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].")
1967
        arch_list = []
1968
        # the assumption is that the extension should run on any of the currently visible cards,
1969
        # which could be of different types - therefore all archs for visible cards should be included
1970
        for i in range(torch.cuda.device_count()):
1971
            capability = torch.cuda.get_device_capability(i)
1972
            supported_sm = [int(arch.split('_')[1])
1973
                            for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
1974
            max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
1975
            # Capability of the device may be higher than what's supported by the user's
1976
            # NVCC, causing compilation error. User's NVCC is expected to match the one
1977
            # used to build pytorch, so we use the maximum supported capability of pytorch
1978
            # to clamp the capability.
1979
            capability = min(max_supported_sm, capability)
1980
            arch = f'{capability[0]}.{capability[1]}'
1981
            if arch not in arch_list:
1982
                arch_list.append(arch)
1983
        arch_list = sorted(arch_list)
1984
        arch_list[-1] += '+PTX'
1985
    else:
1986
        # Deal with lists that are ' ' separated (only deal with ';' after)
1987
        _arch_list = _arch_list.replace(' ', ';')
1988
        # Expand named arches
1989
        for named_arch, archval in named_arches.items():
1990
            _arch_list = _arch_list.replace(named_arch, archval)
1991

1992
        arch_list = _arch_list.split(';')
1993

1994
    flags = []
1995
    for arch in arch_list:
1996
        if arch not in valid_arch_strings:
1997
            raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported")
1998
        else:
1999
            num = arch[0] + arch[2:].split("+")[0]
2000
            flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
2001
            if arch.endswith('+PTX'):
2002
                flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
2003

2004
    return sorted(set(flags))
2005

2006

2007
def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
2008
    # If cflags is given, there may already be user-provided arch flags in it
2009
    # (from `extra_compile_args`)
2010
    if cflags is not None:
2011
        for flag in cflags:
2012
            if 'amdgpu-target' in flag or 'offload-arch' in flag:
2013
                return ['-fno-gpu-rdc']
2014
    # Use same defaults as used for building PyTorch
2015
    # Allow env var to override, just like during initial cmake build.
2016
    _archs = os.environ.get('PYTORCH_ROCM_ARCH', None)
2017
    if not _archs:
2018
        archFlags = torch._C._cuda_getArchFlags()
2019
        if archFlags:
2020
            archs = archFlags.split()
2021
        else:
2022
            archs = []
2023
    else:
2024
        archs = _archs.replace(' ', ';').split(';')
2025
    flags = [f'--offload-arch={arch}' for arch in archs]
2026
    flags += ['-fno-gpu-rdc']
2027
    return flags
2028

2029
def _get_build_directory(name: str, verbose: bool) -> str:
2030
    root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
2031
    if root_extensions_directory is None:
2032
        root_extensions_directory = get_default_build_root()
2033
        cu_str = ('cpu' if torch.version.cuda is None else
2034
                  f'cu{torch.version.cuda.replace(".", "")}')  # type: ignore[attr-defined]
2035
        python_version = f'py{sys.version_info.major}{sys.version_info.minor}'
2036
        build_folder = f'{python_version}_{cu_str}'
2037

2038
        root_extensions_directory = os.path.join(
2039
            root_extensions_directory, build_folder)
2040

2041
    if verbose:
2042
        print(f'Using {root_extensions_directory} as PyTorch extensions root...', file=sys.stderr)
2043

2044
    build_directory = os.path.join(root_extensions_directory, name)
2045
    if not os.path.exists(build_directory):
2046
        if verbose:
2047
            print(f'Creating extension directory {build_directory}...', file=sys.stderr)
2048
        # This is like mkdir -p, i.e. will also create parent directories.
2049
        os.makedirs(build_directory, exist_ok=True)
2050

2051
    return build_directory
2052

2053

2054
def _get_num_workers(verbose: bool) -> Optional[int]:
2055
    max_jobs = os.environ.get('MAX_JOBS')
2056
    if max_jobs is not None and max_jobs.isdigit():
2057
        if verbose:
2058
            print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...',
2059
                  file=sys.stderr)
2060
        return int(max_jobs)
2061
    if verbose:
2062
        print('Allowing ninja to set a default number of workers... '
2063
              '(overridable by setting the environment variable MAX_JOBS=N)',
2064
              file=sys.stderr)
2065
    return None
2066

2067

2068
def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> None:
    """Run ``ninja -v`` inside ``build_directory``.

    Args:
        build_directory: Directory in which ninja is invoked (its working dir).
        verbose: If True, ninja's output is written straight to file
            descriptor 1. Otherwise it is captured and appended to the error
            message on failure.
        error_prefix: Leading text of the ``RuntimeError`` raised on failure.

    Raises:
        RuntimeError: If ninja exits with a non-zero status.
    """
    command = ['ninja', '-v']
    num_workers = _get_num_workers(verbose)
    if num_workers is not None:
        command.extend(['-j', str(num_workers)])
    env = os.environ.copy()
    # Try to activate the vc env for the users
    if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' not in env:
        from setuptools import distutils

        plat_name = distutils.util.get_platform()
        plat_spec = PLAT_TO_VCVARS[plat_name]

        vc_env = distutils._msvccompiler._get_vc_env(plat_spec)
        vc_env = {k.upper(): v for k, v in vc_env.items()}
        # Keys are upper-cased; the VC environment wins, and the current
        # environment only fills in variables it does not define.
        for k, v in env.items():
            vc_env.setdefault(k.upper(), v)
        env = vc_env
    try:
        sys.stdout.flush()
        sys.stderr.flush()
        # Warning: don't pass stdout=None to subprocess.run to get output.
        # subprocess.run assumes that sys.__stdout__ has not been modified and
        # attempts to write to it by default.  However, when we call _run_ninja_build
        # from ahead-of-time cpp extensions, the following happens:
        # 1) If the stdout encoding is not utf-8, setuptools detachs __stdout__.
        #    https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110
        #    (it probably shouldn't do this)
        # 2) subprocess.run (on POSIX, with no stdout override) relies on
        #    __stdout__ not being detached:
        #    https://github.com/python/cpython/blob/c352e6c7446c894b13643f538db312092b351789/Lib/subprocess.py#L1214
        # To work around this, we pass in the fileno directly and hope that
        # it is valid.
        stdout_fileno = 1
        subprocess.run(
            command,
            stdout=stdout_fileno if verbose else subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=build_directory,
            check=True,
            env=env)
    except subprocess.CalledProcessError as e:
        # When output was captured (verbose=False), e.output holds ninja's
        # combined stdout/stderr; it is None when output was not captured.
        message = error_prefix
        if e.output:
            # errors='replace' stops undecodable compiler output (e.g. from a
            # non-UTF-8 locale) from raising UnicodeDecodeError here and hiding
            # the actual build failure.
            message += f": {e.output.decode(*SUBPROCESS_DECODE_ARGS, errors='replace')}"
        raise RuntimeError(message) from e
2121

2122

2123
def _get_exec_path(module_name, path):
    """Return the path of the standalone executable ``module_name`` inside ``path``.

    On Windows, if no ``PATH`` entry refers to TORCH_LIB_PATH (compared by
    string and then by ``os.path.samefile``), TORCH_LIB_PATH is prepended to
    ``os.environ['PATH']`` as a side effect.
    """
    if IS_WINDOWS:
        path_entries = os.getenv('PATH', '').split(';')
        if TORCH_LIB_PATH not in path_entries:
            already_present = any(
                os.path.exists(entry) and os.path.samefile(entry, TORCH_LIB_PATH)
                for entry in path_entries)
            if not already_present:
                os.environ['PATH'] = f"{TORCH_LIB_PATH};{os.getenv('PATH', '')}"
    return os.path.join(path, f'{module_name}{EXEC_EXT}')
2132

2133

2134
def _import_module_from_library(module_name, path, is_python_module):
    """Load the compiled library ``<module_name><LIB_EXT>`` from directory ``path``.

    Args:
        module_name: Base name of the built library (without extension).
        path: Directory containing the library file.
        is_python_module: If True, import the library as a Python module and
            return the module object. Otherwise pass it to
            ``torch.ops.load_library`` and return the library's file path.

    Raises:
        ImportError: If no module spec with a usable loader can be created
            for the library file.
    """
    filepath = os.path.join(path, f"{module_name}{LIB_EXT}")
    if not is_python_module:
        torch.ops.load_library(filepath)
        return filepath
    # Import the `util` submodule explicitly rather than relying on some other
    # import having already bound `importlib.util`.
    import importlib.util
    # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
    spec = importlib.util.spec_from_file_location(module_name, filepath)
    # Raise rather than assert so the check also runs under `python -O`.
    if spec is None or not isinstance(spec.loader, importlib.abc.Loader):
        raise ImportError(
            f"Could not create a module spec for {module_name!r} from {filepath}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
2147

2148

2149
def _write_ninja_file_to_build_library(path,
                                       name,
                                       sources,
                                       extra_cflags,
                                       extra_cuda_cflags,
                                       extra_ldflags,
                                       extra_include_paths,
                                       with_cuda,
                                       is_standalone) -> None:
    """Assemble compile/link flags for a JIT extension and write its ninja file.

    Args:
        path: Where to write the ninja build file.
        name: Extension name. Used for ``-DTORCH_EXTENSION_NAME`` (non-standalone
            builds) and as the base name of the output target.
        sources: Source files; one object file is produced per source.
        extra_cflags: Extra flags for the C++ compiler.
        extra_cuda_cflags: Extra flags for nvcc (or hipcc for HIP extensions).
        extra_ldflags: Extra linker flags.
        extra_include_paths: Extra include directories (made absolute).
        with_cuda: Whether CUDA/HIP sources are compiled with nvcc/hipcc.
        is_standalone: Build an executable (``EXEC_EXT``) instead of a library
            (``LIB_EXT``). This skips the extension-name defines and the
            shared-library flag.
    """
    extra_cflags = [flag.strip() for flag in extra_cflags]
    extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
    extra_ldflags = [flag.strip() for flag in extra_ldflags]
    extra_include_paths = [flag.strip() for flag in extra_include_paths]

    # Turn into absolute paths so we can emit them into the ninja build
    # file wherever it is.
    user_includes = [os.path.abspath(file) for file in extra_include_paths]

    # include_paths() gives us the location of torch/extension.h
    system_includes = include_paths(with_cuda)
    # sysconfig.get_path('include') gives us the location of Python.h
    # Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS
    # installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder
    python_include_path = sysconfig.get_path('include', scheme='nt' if IS_WINDOWS else 'posix_prefix')
    if python_include_path is not None:
        system_includes.append(python_include_path)

    common_cflags = []
    if not is_standalone:
        common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}')
        common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')

    common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]

    # Windows does not understand `-isystem` and quotes flags later.
    if IS_WINDOWS:
        common_cflags += [f'-I{include}' for include in user_includes + system_includes]
    else:
        common_cflags += [f'-I{shlex.quote(include)}' for include in user_includes]
        common_cflags += [f'-isystem {shlex.quote(include)}' for include in system_includes]

    common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]

    if IS_WINDOWS:
        cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17'] + extra_cflags
        cflags = _nt_quote_args(cflags)
    else:
        cflags = common_cflags + ['-fPIC', '-std=c++17'] + extra_cflags

    if with_cuda and IS_HIP_EXTENSION:
        cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
        cuda_flags += extra_cuda_cflags
        cuda_flags += _get_rocm_arch_flags(cuda_flags)
    elif with_cuda:
        cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
        if IS_WINDOWS:
            # Each iteration prepends, so these groups end up in reverse order
            # ahead of the common flags.
            for flag in COMMON_MSVC_FLAGS:
                cuda_flags = ['-Xcompiler', flag] + cuda_flags
            for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
                cuda_flags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cuda_flags
            cuda_flags = cuda_flags + ['-std=c++17']
            cuda_flags = _nt_quote_args(cuda_flags)
            cuda_flags += _nt_quote_args(extra_cuda_cflags)
        else:
            cuda_flags += ['--compiler-options', "'-fPIC'"]
            cuda_flags += extra_cuda_cflags
            if not any(flag.startswith('-std=') for flag in cuda_flags):
                cuda_flags.append('-std=c++17')
            cc_env = os.getenv("CC")
            # Skip an empty CC (e.g. `CC=`): emitting `-ccbin ''` would make
            # nvcc consume the next flag as the host compiler.
            if cc_env:
                cuda_flags = ['-ccbin', cc_env] + cuda_flags
    else:
        cuda_flags = None

    def object_file_path(source_file: str) -> str:
        # '/path/to/file.cpp' -> 'file'
        file_name = os.path.splitext(os.path.basename(source_file))[0]
        if _is_cuda_file(source_file) and with_cuda:
            # Use a different object filename in case a C++ and CUDA file have
            # the same filename but different extension (.cpp vs. .cu).
            target = f'{file_name}.cuda.o'
        else:
            target = f'{file_name}.o'
        return target

    objects = [object_file_path(src) for src in sources]
    ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags

    # The darwin linker needs explicit consent to ignore unresolved symbols.
    if IS_MACOS:
        ldflags.append('-undefined dynamic_lookup')
    elif IS_WINDOWS:
        ldflags = _nt_quote_args(ldflags)

    ext = EXEC_EXT if is_standalone else LIB_EXT
    library_target = f'{name}{ext}'

    _write_ninja_file(
        path=path,
        cflags=cflags,
        post_cflags=None,
        cuda_cflags=cuda_flags,
        cuda_post_cflags=None,
        cuda_dlink_post_cflags=None,
        sources=sources,
        objects=objects,
        ldflags=ldflags,
        library_target=library_target,
        with_cuda=with_cuda)
2258

2259

2260
def _write_ninja_file(path,
                      cflags,
                      post_cflags,
                      cuda_cflags,
                      cuda_post_cflags,
                      cuda_dlink_post_cflags,
                      sources,
                      objects,
                      ldflags,
                      library_target,
                      with_cuda) -> None:
    r"""Write a ninja file that does the desired compiling and linking.

    `path`: Where to write this file
    `cflags`: list of flags to pass to $cxx. Can be None.
    `post_cflags`: list of flags to append to the $cxx invocation. Can be None.
    `cuda_cflags`: list of flags to pass to $nvcc. Can be None.
    `cuda_post_cflags`: list of flags to append to the $nvcc invocation. Can be None.
    `cuda_dlink_post_cflags`: list of flags to append to the $nvcc device-link
                              invocation. Can be None. When non-empty, a
                              `dlink.o` is device-linked from `objects` and
                              appended to `objects` in place.
    `sources`: list of paths to source files
    `objects`: list of desired paths to objects, one per source.
    `ldflags`: list of flags to pass to linker. Can be None.
    `library_target`: Name of the output library. Can be None; in that case,
                      we do no linking.
    `with_cuda`: If we should be compiling with CUDA.

    Raises:
        RuntimeError: On Windows, if linking is requested but `where cl` does
                      not locate the MSVC compiler.
    """
    def sanitize_flags(flags):
        if flags is None:
            return []
        else:
            return [flag.strip() for flag in flags]

    def escape_path(file_path):
        # Escape characters that are significant in ninja `build`/`default`
        # statements: '$' (variable prefix) first, then ':' on Windows (drive
        # letters would otherwise end the output list), then ' ' (separator).
        file_path = file_path.replace('$', '$$')
        if IS_WINDOWS:
            file_path = file_path.replace(':', '$:')
        return file_path.replace(' ', '$ ')

    cflags = sanitize_flags(cflags)
    post_cflags = sanitize_flags(post_cflags)
    cuda_cflags = sanitize_flags(cuda_cflags)
    cuda_post_cflags = sanitize_flags(cuda_post_cflags)
    cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)
    ldflags = sanitize_flags(ldflags)

    # Sanity checks...
    assert len(sources) == len(objects)
    assert len(sources) > 0

    compiler = get_cxx_compiler()

    # Version 1.3 is required for the `deps` directive.
    config = ['ninja_required_version = 1.3']
    config.append(f'cxx = {compiler}')
    if with_cuda or cuda_dlink_post_cflags:
        if "PYTORCH_NVCC" in os.environ:
            nvcc = os.getenv("PYTORCH_NVCC")    # user can set nvcc compiler with ccache using the environment variable here
        else:
            if IS_HIP_EXTENSION:
                nvcc = _join_rocm_home('bin', 'hipcc')
            else:
                nvcc = _join_cuda_home('bin', 'nvcc')
        config.append(f'nvcc = {nvcc}')

    if IS_HIP_EXTENSION:
        post_cflags = COMMON_HIP_FLAGS + post_cflags
    flags = [f'cflags = {" ".join(cflags)}']
    flags.append(f'post_cflags = {" ".join(post_cflags)}')
    if with_cuda:
        flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
        flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
    flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
    flags.append(f'ldflags = {" ".join(ldflags)}')

    # Turn into absolute paths so we can emit them into the ninja build
    # file wherever it is.
    sources = [os.path.abspath(file) for file in sources]

    # See https://ninja-build.org/build.ninja.html for reference.
    compile_rule = ['rule compile']
    if IS_WINDOWS:
        compile_rule.append(
            '  command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags')
        compile_rule.append('  deps = msvc')
    else:
        compile_rule.append(
            '  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags')
        compile_rule.append('  depfile = $out.d')
        compile_rule.append('  deps = gcc')

    cuda_compile_rule: List[str] = []
    if with_cuda:
        cuda_compile_rule.append('rule cuda_compile')
        nvcc_gendeps = ''
        # --generate-dependencies-with-compile is not supported by ROCm
        # Nvcc flag `--generate-dependencies-with-compile` is not supported by sccache, which may increase build time.
        if torch.version.cuda is not None and os.getenv('TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES', '0') != '1':
            cuda_compile_rule.append('  depfile = $out.d')
            cuda_compile_rule.append('  deps = gcc')
            # Note: non-system deps with nvcc are only supported
            # on Linux so use --generate-dependencies-with-compile
            # to make this work on Windows too.
            nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d'
        cuda_compile_rule.append(
            f'  command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags')

    # Emit one build rule per source to enable incremental build.
    build = []
    for source_file, object_file in zip(sources, objects):
        is_cuda_source = _is_cuda_file(source_file) and with_cuda
        rule = 'cuda_compile' if is_cuda_source else 'compile'
        build.append(f'build {escape_path(object_file)}: {rule} {escape_path(source_file)}')

    if cuda_dlink_post_cflags:
        devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')
        devlink_rule = ['rule cuda_devlink']
        devlink_rule.append('  command = $nvcc $in -o $out $cuda_dlink_post_cflags')
        devlink_inputs = " ".join(escape_path(obj) for obj in objects)
        devlink = [f'build {escape_path(devlink_out)}: cuda_devlink {devlink_inputs}']
        # Extends the caller's list in place so dlink.o is also a link input.
        # NOTE(review): callers may rely on seeing dlink.o in the list they
        # passed in; keep this an in-place `+=`.
        objects += [devlink_out]
    else:
        devlink_rule, devlink = [], []

    if library_target is not None:
        link_rule = ['rule link']
        if IS_WINDOWS:
            # `where` exits non-zero when cl is not found; surface that as the
            # intended error instead of a bare CalledProcessError.
            try:
                where_output = subprocess.check_output(
                    ['where', 'cl']).decode(*SUBPROCESS_DECODE_ARGS)
            except (OSError, subprocess.CalledProcessError) as err:
                raise RuntimeError("MSVC is required to load C++ extensions") from err
            cl_paths = [line for line in where_output.splitlines() if line.strip()]
            if not cl_paths:
                raise RuntimeError("MSVC is required to load C++ extensions")
            cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')
            link_rule.append(f'  command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out')
        else:
            link_rule.append('  command = $cxx $in $ldflags -o $out')

        link_inputs = " ".join(escape_path(obj) for obj in objects)
        link = [f'build {escape_path(library_target)}: link {link_inputs}']

        default = [f'default {escape_path(library_target)}']
    else:
        link_rule, link, default = [], [], []

    # 'Blocks' should be separated by newlines, for visual benefit.
    blocks = [config, flags, compile_rule]
    if with_cuda:
        blocks.append(cuda_compile_rule)
    blocks += [devlink_rule, link_rule, build, devlink, link, default]
    content = "\n\n".join("\n".join(b) for b in blocks)
    # Ninja requires a new lines at the end of the .ninja file
    content += "\n"
    _maybe_write(path, content)
2407

2408
def _join_cuda_home(*paths) -> str:
    """Return ``os.path.join(CUDA_HOME, *paths)``.

    Raises ``OSError`` when the module-level ``CUDA_HOME`` is None. Doing the
    check here defers the error about a missing CUDA install until a
    CUDA-specific path is actually requested.
    """
    if CUDA_HOME is not None:
        return os.path.join(CUDA_HOME, *paths)
    raise OSError('CUDA_HOME environment variable is not set. '
                  'Please set it to your CUDA install root.')
2419

2420

2421
def _is_cuda_file(path: str) -> bool:
    """Return True if ``path`` has an extension that is compiled as CUDA.

    ``.cu`` and ``.cuh`` always qualify. ``.hip`` qualifies only when
    ``IS_HIP_EXTENSION`` is set.
    """
    suffix = os.path.splitext(path)[1]
    cuda_suffixes = ('.cu', '.cuh', '.hip') if IS_HIP_EXTENSION else ('.cu', '.cuh')
    return suffix in cuda_suffixes
2426

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.