1
# mypy: allow-untyped-defs
16
from pathlib import Path
21
from .file_baton import FileBaton
22
from ._cpp_extension_versioner import ExtensionVersioner
23
from .hipify import hipify_python
24
from .hipify.hipify_python import GeneratedFileCleaner
25
from typing import Dict, List, Optional, Union, Tuple
26
from torch.torch_version import TorchVersion, Version
28
from setuptools.command.build_ext import build_ext
30
IS_WINDOWS = sys.platform == 'win32'
31
IS_MACOS = sys.platform.startswith('darwin')
32
IS_LINUX = sys.platform.startswith('linux')
33
LIB_EXT = '.pyd' if IS_WINDOWS else '.so'
34
EXEC_EXT = '.exe' if IS_WINDOWS else ''
35
CLIB_PREFIX = '' if IS_WINDOWS else 'lib'
36
CLIB_EXT = '.dll' if IS_WINDOWS else '.so'
37
SHARED_FLAG = '/DLL' if IS_WINDOWS else '-shared'
39
_HERE = os.path.abspath(__file__)
40
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
41
TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib')
44
SUBPROCESS_DECODE_ARGS = ('oem',) if IS_WINDOWS else ()
45
MINIMUM_GCC_VERSION = (5, 0, 0)
46
MINIMUM_MSVC_VERSION = (19, 0, 24215)
48
VersionRange = Tuple[Tuple[int, ...], Tuple[int, ...]]
49
VersionMap = Dict[str, VersionRange]
50
# The following values were taken from the following GitHub gist that
51
# summarizes the minimum valid major versions of g++/clang++ for each supported
52
# CUDA version: https://gist.github.com/ax3l/9489132
53
# Or from include/crt/host_config.h in the CUDA SDK
54
# The second value is the exclusive(!) upper bound, i.e. min <= version < max
55
CUDA_GCC_VERSIONS: VersionMap = {
56
'11.0': (MINIMUM_GCC_VERSION, (10, 0)),
57
'11.1': (MINIMUM_GCC_VERSION, (11, 0)),
58
'11.2': (MINIMUM_GCC_VERSION, (11, 0)),
59
'11.3': (MINIMUM_GCC_VERSION, (11, 0)),
60
'11.4': ((6, 0, 0), (12, 0)),
61
'11.5': ((6, 0, 0), (12, 0)),
62
'11.6': ((6, 0, 0), (12, 0)),
63
'11.7': ((6, 0, 0), (12, 0)),
66
MINIMUM_CLANG_VERSION = (3, 3, 0)
67
CUDA_CLANG_VERSIONS: VersionMap = {
68
'11.1': (MINIMUM_CLANG_VERSION, (11, 0)),
69
'11.2': (MINIMUM_CLANG_VERSION, (12, 0)),
70
'11.3': (MINIMUM_CLANG_VERSION, (12, 0)),
71
'11.4': (MINIMUM_CLANG_VERSION, (13, 0)),
72
'11.5': (MINIMUM_CLANG_VERSION, (13, 0)),
73
'11.6': (MINIMUM_CLANG_VERSION, (14, 0)),
74
'11.7': (MINIMUM_CLANG_VERSION, (14, 0)),
77
__all__ = ["get_default_build_root", "check_compiler_ok_for_platform", "get_compiler_abi_compatibility_and_version", "BuildExtension",
78
"CppExtension", "CUDAExtension", "include_paths", "library_paths", "load", "load_inline", "is_ninja_available",
79
"verify_ninja_availability", "remove_extension_h_precompiler_headers", "get_cxx_compiler", "check_compiler_is_gcc"]
80
# Taken directly from python stdlib < 3.9
81
# See https://github.com/pytorch/pytorch/issues/48617
82
def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
83
"""Quote command-line arguments for DOS/Windows conventions.
85
Just wraps every argument which contains blanks in double quotes, and
86
returns a new argument list.
91
return [f'"{arg}"' if ' ' in arg else arg for arg in args]
93
def _find_cuda_home() -> Optional[str]:
94
"""Find the CUDA install path."""
96
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
99
nvcc_path = shutil.which("nvcc")
100
if nvcc_path is not None:
101
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
105
cuda_homes = glob.glob(
106
'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
107
if len(cuda_homes) == 0:
110
cuda_home = cuda_homes[0]
112
cuda_home = '/usr/local/cuda'
113
if not os.path.exists(cuda_home):
115
if cuda_home and not torch.cuda.is_available():
116
print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'",
120
def _find_rocm_home() -> Optional[str]:
121
"""Find the ROCm install path."""
123
rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
124
if rocm_home is None:
126
hipcc_path = shutil.which('hipcc')
127
if hipcc_path is not None:
128
rocm_home = os.path.dirname(os.path.dirname(
129
os.path.realpath(hipcc_path)))
130
# can be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
131
if os.path.basename(rocm_home) == 'hip':
132
rocm_home = os.path.dirname(rocm_home)
135
fallback_path = '/opt/rocm'
136
if os.path.exists(fallback_path):
137
rocm_home = fallback_path
138
if rocm_home and torch.version.hip is None:
139
print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'",
144
def _join_rocm_home(*paths) -> str:
146
Join paths with ROCM_HOME, or raises an error if it ROCM_HOME is not set.
148
This is basically a lazy way of raising an error for missing $ROCM_HOME
149
only once we need to get any ROCm-specific path.
151
if ROCM_HOME is None:
152
raise OSError('ROCM_HOME environment variable is not set. '
153
'Please set it to your ROCm install root.')
155
raise OSError('Building PyTorch extensions using '
156
'ROCm and Windows is not supported.')
157
return os.path.join(ROCM_HOME, *paths)
160
ABI_INCOMPATIBILITY_WARNING = '''
164
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
165
Your compiler ({}) may be ABI-incompatible with PyTorch!
166
Please use a compiler that is ABI-compatible with GCC 5.0 and above.
167
See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html.
169
See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6
170
for instructions on how to install GCC 5 or higher.
171
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
175
WRONG_COMPILER_WARNING = '''
179
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
180
Your compiler ({user_compiler}) is not compatible with the compiler Pytorch was
181
built with for this platform, which is {pytorch_compiler} on {platform}. Please
182
use {pytorch_compiler} to to compile your extension. Alternatively, you may
183
compile PyTorch from source using {user_compiler}, and then you can also use
184
{user_compiler} to compile your extension.
186
See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
187
with compiling PyTorch from source.
188
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
192
CUDA_MISMATCH_MESSAGE = '''
193
The detected CUDA version ({0}) mismatches the version that was used to compile
194
PyTorch ({1}). Please make sure to use the same CUDA versions.
196
CUDA_MISMATCH_WARN = "The detected CUDA version ({0}) has a minor version mismatch with the version that was used to compile PyTorch ({1}). Most likely this shouldn't be a problem."
197
CUDA_NOT_FOUND_MESSAGE = '''
198
CUDA was not found on the system, please set the CUDA_HOME or the CUDA_PATH
199
environment variable or add NVCC to your system PATH. The extension compilation will fail.
201
ROCM_HOME = _find_rocm_home()
202
HIP_HOME = _join_rocm_home('hip') if ROCM_HOME else None
203
IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False
205
if torch.version.hip is not None:
206
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
208
CUDA_HOME = _find_cuda_home() if torch.cuda._is_compiled() else None
209
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
210
# PyTorch releases have the version pattern major.minor.patch, whereas when
211
# PyTorch is built from source, we append the git commit hash, which gives
212
# it the below pattern.
213
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')
215
COMMON_MSVC_FLAGS = ['/MD', '/wd4819', '/wd4251', '/wd4244', '/wd4267', '/wd4275', '/wd4018', '/wd4190', '/wd4624', '/wd4067', '/wd4068', '/EHsc']
217
MSVC_IGNORE_CUDAFE_WARNINGS = [
218
'base_class_has_different_dll_interface',
219
'field_without_dll_interface',
220
'dll_interface_conflict_none_assumed',
221
'dll_interface_conflict_dllexport_assumed'
225
'-D__CUDA_NO_HALF_OPERATORS__',
226
'-D__CUDA_NO_HALF_CONVERSIONS__',
227
'-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
228
'-D__CUDA_NO_HALF2_OPERATORS__',
229
'--expt-relaxed-constexpr'
234
'-D__HIP_PLATFORM_AMD__=1',
239
COMMON_HIPCC_FLAGS = [
241
'-D__HIP_NO_HALF_OPERATORS__=1',
242
'-D__HIP_NO_HALF_CONVERSIONS__=1',
245
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
249
'win-amd64' : 'x86_amd64',
252
def get_cxx_compiler():
254
compiler = os.environ.get('CXX', 'cl')
256
compiler = os.environ.get('CXX', 'c++')
259
def _is_binary_build() -> bool:
260
return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
263
def _accepted_compilers_for_platform() -> List[str]:
264
# gnu-c++ and gnu-cc are the conda gcc compilers
265
return ['clang++', 'clang'] if IS_MACOS else ['g++', 'gcc', 'gnu-c++', 'gnu-cc', 'clang++', 'clang']
267
def _maybe_write(filename, new_content):
269
Equivalent to writing the content into the file but will not touch the file
270
if it already had the right content (to avoid triggering recompile).
272
if os.path.exists(filename):
273
with open(filename) as f:
276
if content == new_content:
277
# The file already contains the right thing!
280
with open(filename, 'w') as source_file:
281
source_file.write(new_content)
283
def get_default_build_root() -> str:
285
Return the path to the root folder under which extensions will built.
287
For each extension module built, there will be one folder underneath the
288
folder returned by this function. For example, if ``p`` is the path
289
returned by this function and ``ext`` the name of an extension, the build
290
folder for the extension will be ``p/ext``.
292
This directory is **user-specific** so that multiple users on the same
293
machine won't meet permission issues.
295
return os.path.realpath(torch._appdirs.user_cache_dir(appname='torch_extensions'))
298
def check_compiler_ok_for_platform(compiler: str) -> bool:
300
Verify that the compiler is the expected one for the current platform.
303
compiler (str): The compiler executable to check.
306
True if the compiler is gcc/g++ on Linux or clang/clang++ on macOS,
307
and always True for Windows.
311
compiler_path = shutil.which(compiler)
312
if compiler_path is None:
314
# Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'.
315
compiler_path = os.path.realpath(compiler_path)
316
# Check the compiler name
317
if any(name in compiler_path for name in _accepted_compilers_for_platform()):
319
# If compiler wrapper is used try to infer the actual compiler by invoking it with -v flag
320
env = os.environ.copy()
321
env['LC_ALL'] = 'C' # Don't localize output
322
version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS)
324
# Check for 'gcc' or 'g++' for sccache wrapper
325
pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
326
results = re.findall(pattern, version_string)
327
if len(results) != 1:
328
# Clang is also a supported compiler on Linux
329
# Though on Ubuntu it's sometimes called "Ubuntu clang version"
330
return 'clang version' in version_string
331
compiler_path = os.path.realpath(results[0].strip())
332
# On RHEL/CentOS c++ is a gcc compiler wrapper
333
if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string:
335
return any(name in compiler_path for name in _accepted_compilers_for_platform())
337
# Check for 'clang' or 'clang++'
338
return version_string.startswith("Apple clang")
342
def get_compiler_abi_compatibility_and_version(compiler) -> Tuple[bool, TorchVersion]:
344
Determine if the given compiler is ABI-compatible with PyTorch alongside its version.
347
compiler (str): The compiler executable name to check (e.g. ``g++``).
348
Must be executable in a shell process.
351
A tuple that contains a boolean that defines if the compiler is (likely) ABI-incompatible with PyTorch,
352
followed by a `TorchVersion` string that contains the compiler version separated by dots.
354
if not _is_binary_build():
355
return (True, TorchVersion('0.0.0'))
356
if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']:
357
return (True, TorchVersion('0.0.0'))
359
# First check if the compiler is one of the expected ones for the particular platform.
360
if not check_compiler_ok_for_platform(compiler):
361
warnings.warn(WRONG_COMPILER_WARNING.format(
362
user_compiler=compiler,
363
pytorch_compiler=_accepted_compilers_for_platform()[0],
364
platform=sys.platform))
365
return (False, TorchVersion('0.0.0'))
368
# There is no particular minimum version we need for clang, so we're good here.
369
return (True, TorchVersion('0.0.0'))
372
minimum_required_version = MINIMUM_GCC_VERSION
373
versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
374
version = versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip().split('.')
376
minimum_required_version = MINIMUM_MSVC_VERSION
377
compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
378
match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip())
379
version = ['0', '0', '0'] if match is None else list(match.groups())
381
_, error, _ = sys.exc_info()
382
warnings.warn(f'Error checking compiler version for {compiler}: {error}')
383
return (False, TorchVersion('0.0.0'))
385
if tuple(map(int, version)) >= minimum_required_version:
386
return (True, TorchVersion('.'.join(version)))
388
compiler = f'{compiler} {".".join(version)}'
389
warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler))
391
return (False, TorchVersion('.'.join(version)))
394
def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> None:
396
raise RuntimeError(CUDA_NOT_FOUND_MESSAGE)
398
nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc')
399
cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode(*SUBPROCESS_DECODE_ARGS)
400
cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str)
401
if cuda_version is None:
404
cuda_str_version = cuda_version.group(1)
405
cuda_ver = Version(cuda_str_version)
406
if torch.version.cuda is None:
409
torch_cuda_version = Version(torch.version.cuda)
410
if cuda_ver != torch_cuda_version:
411
# major/minor attributes are only available in setuptools>=49.4.0
412
if getattr(cuda_ver, "major", None) is None:
413
raise ValueError("setuptools>=49.4.0 is required")
414
if cuda_ver.major != torch_cuda_version.major:
415
raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda))
416
warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))
418
if not (sys.platform.startswith('linux') and
419
os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') not in ['ON', '1', 'YES', 'TRUE', 'Y'] and
423
cuda_compiler_bounds: VersionMap = CUDA_CLANG_VERSIONS if compiler_name.startswith('clang') else CUDA_GCC_VERSIONS
425
if cuda_str_version not in cuda_compiler_bounds:
426
warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')
428
min_compiler_version, max_excl_compiler_version = cuda_compiler_bounds[cuda_str_version]
429
# Special case for 11.4.0, which has lower compiler bounds than 11.4.1
430
if "V11.4.48" in cuda_version_str and cuda_compiler_bounds == CUDA_GCC_VERSIONS:
431
max_excl_compiler_version = (11, 0)
432
min_compiler_version_str = '.'.join(map(str, min_compiler_version))
433
max_excl_compiler_version_str = '.'.join(map(str, max_excl_compiler_version))
435
version_bound_str = f'>={min_compiler_version_str}, <{max_excl_compiler_version_str}'
437
if compiler_version < TorchVersion(min_compiler_version_str):
439
f'The current installed version of {compiler_name} ({compiler_version}) is less '
440
f'than the minimum required version by CUDA {cuda_str_version} ({min_compiler_version_str}). '
441
f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).'
443
if compiler_version >= TorchVersion(max_excl_compiler_version_str):
445
f'The current installed version of {compiler_name} ({compiler_version}) is greater '
446
f'than the maximum required version by CUDA {cuda_str_version}. '
447
f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).'
451
class BuildExtension(build_ext):
453
A custom :mod:`setuptools` build extension .
455
This :class:`setuptools.build_ext` subclass takes care of passing the
456
minimum required compiler flags (e.g. ``-std=c++17``) as well as mixed
457
C++/CUDA compilation (and support for CUDA files in general).
459
When using :class:`BuildExtension`, it is allowed to supply a dictionary
460
for ``extra_compile_args`` (rather than the usual list) that maps from
461
languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to
462
supply to the compiler. This makes it possible to supply different flags to
463
the C++ and CUDA compiler during mixed compilation.
465
``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we
466
attempt to build using the Ninja backend. Ninja greatly speeds up
467
compilation compared to the standard ``setuptools.build_ext``.
468
Fallbacks to the standard distutils backend if Ninja is not available.
471
By default, the Ninja backend uses #CPUS + 2 workers to build the
472
extension. This may use up too many resources on some systems. One
473
can control the number of workers by setting the `MAX_JOBS` environment
474
variable to a non-negative number.
478
def with_options(cls, **options):
479
"""Return a subclass with alternative constructor that extends any original keyword arguments to the original constructor with the given options."""
480
class cls_with_options(cls): # type: ignore[misc, valid-type]
481
def __init__(self, *args, **kwargs):
482
kwargs.update(options)
483
super().__init__(*args, **kwargs)
485
return cls_with_options
487
def __init__(self, *args, **kwargs) -> None:
488
super().__init__(*args, **kwargs)
489
self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)
491
self.use_ninja = kwargs.get('use_ninja', True)
493
# Test if we can use ninja. Fallback otherwise.
494
msg = ('Attempted to use ninja as the BuildExtension backend but '
495
'{}. Falling back to using the slow distutils backend.')
496
if not is_ninja_available():
497
warnings.warn(msg.format('we could not find ninja.'))
498
self.use_ninja = False
500
def finalize_options(self) -> None:
501
super().finalize_options()
505
def build_extensions(self) -> None:
506
compiler_name, compiler_version = self._check_abi()
509
extension_iter = iter(self.extensions)
510
extension = next(extension_iter, None)
511
while not cuda_ext and extension:
512
for source in extension.sources:
513
_, ext = os.path.splitext(source)
517
extension = next(extension_iter, None)
519
if cuda_ext and not IS_HIP_EXTENSION:
520
_check_cuda_version(compiler_name, compiler_version)
522
for extension in self.extensions:
523
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
524
# extra_compile_args is a dict. Otherwise, default torch flags do
525
# not get passed. Necessary when only one of 'cxx' and 'nvcc' is
526
# passed to extra_compile_args in CUDAExtension, i.e.
527
# CUDAExtension(..., extra_compile_args={'cxx': [...]})
529
# CUDAExtension(..., extra_compile_args={'nvcc': [...]})
530
if isinstance(extension.extra_compile_args, dict):
531
for ext in ['cxx', 'nvcc']:
532
if ext not in extension.extra_compile_args:
533
extension.extra_compile_args[ext] = []
535
self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H')
536
# See note [Pybind11 ABI constants]
537
for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
538
val = getattr(torch._C, f"_PYBIND11_{name}")
539
if val is not None and not IS_WINDOWS:
540
self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"')
541
self._define_torch_extension_name(extension)
542
self._add_gnu_cpp_abi_flag(extension)
544
if 'nvcc_dlink' in extension.extra_compile_args:
545
assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."
547
# Register .cu, .cuh, .hip, and .mm as valid source extensions.
548
self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
549
if torch.backends.mps.is_built():
550
self.compiler.src_extensions += ['.mm']
551
# Save the original _compile method for later.
552
if self.compiler.compiler_type == 'msvc':
553
self.compiler._cpp_extensions += ['.cu', '.cuh']
554
original_compile = self.compiler.compile
555
original_spawn = self.compiler.spawn
557
original_compile = self.compiler._compile
559
def append_std17_if_no_std_present(cflags) -> None:
560
# NVCC does not allow multiple -std to be passed, so we avoid
561
# overriding the option if the user explicitly passed it.
562
cpp_format_prefix = '/{}:' if self.compiler.compiler_type == 'msvc' else '-{}='
563
cpp_flag_prefix = cpp_format_prefix.format('std')
564
cpp_flag = cpp_flag_prefix + 'c++17'
565
if not any(flag.startswith(cpp_flag_prefix) for flag in cflags):
566
cflags.append(cpp_flag)
568
def unix_cuda_flags(cflags):
569
cflags = (COMMON_NVCC_FLAGS +
570
['--compiler-options', "'-fPIC'"] +
571
cflags + _get_cuda_arch_flags(cflags))
573
# NVCC does not allow multiple -ccbin/--compiler-bindir to be passed, so we avoid
574
# overriding the option if the user explicitly passed it.
575
_ccbin = os.getenv("CC")
578
and not any(flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags)
580
cflags.extend(['-ccbin', _ccbin])
584
def convert_to_absolute_paths_inplace(paths):
585
# Helper function. See Note [Absolute include_dirs]
586
if paths is not None:
587
for i in range(len(paths)):
588
if not os.path.isabs(paths[i]):
589
paths[i] = os.path.abspath(paths[i])
591
def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
592
# Copy before we make any modifications.
593
cflags = copy.deepcopy(extra_postargs)
595
original_compiler = self.compiler.compiler_so
596
if _is_cuda_file(src):
597
nvcc = [_join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc')]
598
self.compiler.set_executable('compiler_so', nvcc)
599
if isinstance(cflags, dict):
600
cflags = cflags['nvcc']
602
cflags = COMMON_HIPCC_FLAGS + cflags + _get_rocm_arch_flags(cflags)
604
cflags = unix_cuda_flags(cflags)
605
elif isinstance(cflags, dict):
606
cflags = cflags['cxx']
608
cflags = COMMON_HIP_FLAGS + cflags
609
append_std17_if_no_std_present(cflags)
611
original_compile(obj, src, ext, cc_args, cflags, pp_opts)
613
# Put the original compiler back in place.
614
self.compiler.set_executable('compiler_so', original_compiler)
616
def unix_wrap_ninja_compile(sources,
624
r"""Compiles sources by outputting a ninja file and running it."""
625
# NB: I copied some lines from self.compiler (which is an instance
626
# of distutils.UnixCCompiler). See the following link.
627
# https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567
628
# This can be fragile, but a lot of other repos also do this
629
# (see https://github.com/search?q=_setup_compile&type=Code)
630
# so it is probably OK; we'll also get CI signal if/when
631
# we update our python version (which is when distutils can be
634
# Use absolute path for output_dir so that the object file paths
635
# (`objects`) get generated with absolute paths.
636
output_dir = os.path.abspath(output_dir)
638
# See Note [Absolute include_dirs]
639
convert_to_absolute_paths_inplace(self.compiler.include_dirs)
641
_, objects, extra_postargs, pp_opts, _ = \
642
self.compiler._setup_compile(output_dir, macros,
643
include_dirs, sources,
644
depends, extra_postargs)
645
common_cflags = self.compiler._get_cc_args(pp_opts, debug, extra_preargs)
646
extra_cc_cflags = self.compiler.compiler_so[1:]
647
with_cuda = any(map(_is_cuda_file, sources))
649
# extra_postargs can be either:
650
# - a dict mapping cxx/nvcc to extra flags
651
# - a list of extra flags.
652
if isinstance(extra_postargs, dict):
653
post_cflags = extra_postargs['cxx']
655
post_cflags = list(extra_postargs)
657
post_cflags = COMMON_HIP_FLAGS + post_cflags
658
append_std17_if_no_std_present(post_cflags)
660
cuda_post_cflags = None
663
cuda_cflags = common_cflags
664
if isinstance(extra_postargs, dict):
665
cuda_post_cflags = extra_postargs['nvcc']
667
cuda_post_cflags = list(extra_postargs)
669
cuda_post_cflags = cuda_post_cflags + _get_rocm_arch_flags(cuda_post_cflags)
670
cuda_post_cflags = COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS + cuda_post_cflags
672
cuda_post_cflags = unix_cuda_flags(cuda_post_cflags)
673
append_std17_if_no_std_present(cuda_post_cflags)
674
cuda_cflags = [shlex.quote(f) for f in cuda_cflags]
675
cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags]
677
if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
678
cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink'])
680
cuda_dlink_post_cflags = None
681
_write_ninja_file_and_compile_objects(
684
cflags=[shlex.quote(f) for f in extra_cc_cflags + common_cflags],
685
post_cflags=[shlex.quote(f) for f in post_cflags],
686
cuda_cflags=cuda_cflags,
687
cuda_post_cflags=cuda_post_cflags,
688
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
689
build_directory=output_dir,
693
# Return *all* object filenames, not just the ones we just built.
696
def win_cuda_flags(cflags):
697
return (COMMON_NVCC_FLAGS +
698
cflags + _get_cuda_arch_flags(cflags))
700
def win_wrap_single_compile(sources,
709
self.cflags = copy.deepcopy(extra_postargs)
710
extra_postargs = None
713
# Using regex to match src, obj and include files
714
src_regex = re.compile('/T(p|c)(.*)')
716
m.group(2) for m in (src_regex.match(elem) for elem in cmd)
720
obj_regex = re.compile('/Fo(.*)')
722
m.group(1) for m in (obj_regex.match(elem) for elem in cmd)
726
include_regex = re.compile(r'((\-|\/)I.*)')
729
for m in (include_regex.match(elem) for elem in cmd) if m
732
if len(src_list) >= 1 and len(obj_list) >= 1:
735
if _is_cuda_file(src):
736
nvcc = _join_cuda_home('bin', 'nvcc')
737
if isinstance(self.cflags, dict):
738
cflags = self.cflags['nvcc']
739
elif isinstance(self.cflags, list):
744
cflags = win_cuda_flags(cflags) + ['-std=c++17', '--use-local-env']
745
for flag in COMMON_MSVC_FLAGS:
746
cflags = ['-Xcompiler', flag] + cflags
747
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
748
cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags
749
cmd = [nvcc, '-c', src, '-o', obj] + include_list + cflags
750
elif isinstance(self.cflags, dict):
751
cflags = COMMON_MSVC_FLAGS + self.cflags['cxx']
752
append_std17_if_no_std_present(cflags)
754
elif isinstance(self.cflags, list):
755
cflags = COMMON_MSVC_FLAGS + self.cflags
756
append_std17_if_no_std_present(cflags)
759
return original_spawn(cmd)
762
self.compiler.spawn = spawn
763
return original_compile(sources, output_dir, macros,
764
include_dirs, debug, extra_preargs,
765
extra_postargs, depends)
767
self.compiler.spawn = original_spawn
769
def win_wrap_ninja_compile(sources,
778
if not self.compiler.initialized:
779
self.compiler.initialize()
780
output_dir = os.path.abspath(output_dir)
782
# Note [Absolute include_dirs]
783
# Convert relative path in self.compiler.include_dirs to absolute path if any,
784
# For ninja build, the build location is not local, the build happens
785
# in a in script created build folder, relative path lost their correctness.
786
# To be consistent with jit extension, we allow user to enter relative include_dirs
787
# in setuptools.setup, and we convert the relative path to absolute path here
788
convert_to_absolute_paths_inplace(self.compiler.include_dirs)
790
_, objects, extra_postargs, pp_opts, _ = \
791
self.compiler._setup_compile(output_dir, macros,
792
include_dirs, sources,
793
depends, extra_postargs)
794
common_cflags = extra_preargs or []
797
cflags.extend(self.compiler.compile_options_debug)
799
cflags.extend(self.compiler.compile_options)
800
common_cflags.extend(COMMON_MSVC_FLAGS)
801
cflags = cflags + common_cflags + pp_opts
802
with_cuda = any(map(_is_cuda_file, sources))
804
# extra_postargs can be either:
805
# - a dict mapping cxx/nvcc to extra flags
806
# - a list of extra flags.
807
if isinstance(extra_postargs, dict):
808
post_cflags = extra_postargs['cxx']
810
post_cflags = list(extra_postargs)
811
append_std17_if_no_std_present(post_cflags)
813
cuda_post_cflags = None
816
cuda_cflags = ['-std=c++17', '--use-local-env']
817
for common_cflag in common_cflags:
818
cuda_cflags.append('-Xcompiler')
819
cuda_cflags.append(common_cflag)
820
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
821
cuda_cflags.append('-Xcudafe')
822
cuda_cflags.append('--diag_suppress=' + ignore_warning)
823
cuda_cflags.extend(pp_opts)
824
if isinstance(extra_postargs, dict):
825
cuda_post_cflags = extra_postargs['nvcc']
827
cuda_post_cflags = list(extra_postargs)
828
cuda_post_cflags = win_cuda_flags(cuda_post_cflags)
830
cflags = _nt_quote_args(cflags)
831
post_cflags = _nt_quote_args(post_cflags)
833
cuda_cflags = _nt_quote_args(cuda_cflags)
834
cuda_post_cflags = _nt_quote_args(cuda_post_cflags)
835
if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
836
cuda_dlink_post_cflags = win_cuda_flags(extra_postargs['nvcc_dlink'])
838
cuda_dlink_post_cflags = None
840
_write_ninja_file_and_compile_objects(
844
post_cflags=post_cflags,
845
cuda_cflags=cuda_cflags,
846
cuda_post_cflags=cuda_post_cflags,
847
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
848
build_directory=output_dir,
852
# Return *all* object filenames, not just the ones we just built.
855
# Monkey-patch the _compile or compile method.
856
# https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511
857
if self.compiler.compiler_type == 'msvc':
859
self.compiler.compile = win_wrap_ninja_compile
861
self.compiler.compile = win_wrap_single_compile
864
self.compiler.compile = unix_wrap_ninja_compile
866
self.compiler._compile = unix_wrap_single_compile
868
build_ext.build_extensions(self)
870
def get_ext_filename(self, ext_name):
871
# Get the original shared library name. For Python 3, this name will be
872
# suffixed with "<SOABI>.so", where <SOABI> will be something like
873
# cpython-37m-x86_64-linux-gnu.
874
ext_filename = super().get_ext_filename(ext_name)
875
# If `no_python_abi_suffix` is `True`, we omit the Python 3 ABI
876
# component. This makes building shared libraries with setuptools that
877
# aren't Python modules nicer.
878
if self.no_python_abi_suffix:
879
# The parts will be e.g. ["my_extension", "cpython-37m-x86_64-linux-gnu", "so"].
880
ext_filename_parts = ext_filename.split('.')
881
# Omit the second to last element.
882
without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
883
ext_filename = '.'.join(without_abi)
886
def _check_abi(self) -> Tuple[str, TorchVersion]:
887
# On some platforms, like Windows, compiler_cxx is not available.
888
if hasattr(self.compiler, 'compiler_cxx'):
889
compiler = self.compiler.compiler_cxx[0]
891
compiler = get_cxx_compiler()
892
_, version = get_compiler_abi_compatibility_and_version(compiler)
893
# Warn user if VC env is activated but `DISTUILS_USE_SDK` is not set.
894
if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
895
msg = ('It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set.'
896
'This may lead to multiple activations of the VC env.'
897
'Please set `DISTUTILS_USE_SDK=1` and try again.')
898
raise UserWarning(msg)
899
return compiler, version
901
def _add_compile_flag(self, extension, flag):
902
extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args)
903
if isinstance(extension.extra_compile_args, dict):
904
for args in extension.extra_compile_args.values():
907
extension.extra_compile_args.append(flag)
909
def _define_torch_extension_name(self, extension):
910
# pybind11 doesn't support dots in the names
911
# so in order to support extensions in the packages
912
# like torch._C, we take the last part of the string
913
# as the library name
914
names = extension.name.split('.')
916
define = f'-DTORCH_EXTENSION_NAME={name}'
917
self._add_compile_flag(extension, define)
919
def _add_gnu_cpp_abi_flag(self, extension):
920
# use the same CXX ABI as what PyTorch was compiled with
921
self._add_compile_flag(extension, '-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI)))
924
def CppExtension(name, sources, *args, **kwargs):
926
Create a :class:`setuptools.Extension` for C++.
928
Convenience method that creates a :class:`setuptools.Extension` with the
929
bare minimum (but often sufficient) arguments to build a C++ extension.
931
All arguments are forwarded to the :class:`setuptools.Extension`
932
constructor. The full list of arguments can be found at
933
https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference
936
>>> # xdoctest: +SKIP
937
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
938
>>> from setuptools import setup
939
>>> from torch.utils.cpp_extension import BuildExtension, CppExtension
941
... name='extension',
944
... name='extension',
945
... sources=['extension.cpp'],
946
... extra_compile_args=['-g'],
947
... extra_link_args=['-Wl,--no-as-needed', '-lm'])
950
... 'build_ext': BuildExtension
953
include_dirs = kwargs.get('include_dirs', [])
954
include_dirs += include_paths()
955
kwargs['include_dirs'] = include_dirs
957
library_dirs = kwargs.get('library_dirs', [])
958
library_dirs += library_paths()
959
kwargs['library_dirs'] = library_dirs
961
libraries = kwargs.get('libraries', [])
962
libraries.append('c10')
963
libraries.append('torch')
964
libraries.append('torch_cpu')
965
libraries.append('torch_python')
967
libraries.append("sleef")
969
kwargs['libraries'] = libraries
971
kwargs['language'] = 'c++'
972
return setuptools.Extension(name, sources, *args, **kwargs)
975
def CUDAExtension(name, sources, *args, **kwargs):
977
Create a :class:`setuptools.Extension` for CUDA/C++.
979
Convenience method that creates a :class:`setuptools.Extension` with the
980
bare minimum (but often sufficient) arguments to build a CUDA/C++
981
extension. This includes the CUDA include path, library path and runtime
984
All arguments are forwarded to the :class:`setuptools.Extension`
985
constructor. The full list of arguments can be found at
986
https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference
989
>>> # xdoctest: +SKIP
990
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
991
>>> from setuptools import setup
992
>>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
994
... name='cuda_extension',
997
... name='cuda_extension',
998
... sources=['extension.cpp', 'extension_kernel.cu'],
999
... extra_compile_args={'cxx': ['-g'],
1000
... 'nvcc': ['-O2']},
1001
... extra_link_args=['-Wl,--no-as-needed', '-lcuda'])
1004
... 'build_ext': BuildExtension
1007
Compute capabilities:
1009
By default the extension will be compiled to run on all archs of the cards visible during the
1010
building process of the extension, plus PTX. If down the road a new card is installed the
1011
extension may need to be recompiled. If a visible card has a compute capability (CC) that's
1012
newer than the newest version for which your nvcc can build fully-compiled binaries, PyTorch
1013
will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
1014
support (see below for details on PTX).
1016
You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which
1017
CCs you want the extension to support:
1019
``TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py``
1020
``TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py``
1022
The +PTX option causes extension kernel binaries to include PTX instructions for the specified
1023
CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >=
1024
the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with
1025
CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to
1026
provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on
1027
those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better
1028
off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6,
1029
"8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but
1030
"8.0 8.6" would be better.
1032
Note that while it's possible to include all supported archs, the more archs get included the
1033
slower the building process will be, as it will build a separate kernel image for each arch.
1035
Note that CUDA-11.5 nvcc will hit an internal compiler error while parsing torch/extension.h on Windows.
1036
To work around the issue, move the Python binding logic to a pure C++ file.
1039
#include <ATen/ATen.h>
1040
at::Tensor SigmoidAlphaBlendForwardCuda(....)
1043
#include <torch/extension.h>
1044
torch::Tensor SigmoidAlphaBlendForwardCuda(...)
1046
Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460
1047
Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48
1049
Relocatable device code linking:
1051
If you want to reference device symbols across compilation units (across object files),
1052
the object files need to be built with `relocatable device code` (-rdc=true or -dc).
1053
An exception to this rule is "dynamic parallelism" (nested kernel launches) which is not used a lot anymore.
1054
`Relocatable device code` is less optimized so it needs to be used only on object files that need it.
1055
Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step
1056
helps reduce the potential perf degradation of `-rdc`.
1057
Note that it needs to be used at both steps to be useful.
1059
If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step.
1060
There is also a case where `-dlink` is used without `-rdc`:
1061
when an extension is linked against a static lib containing rdc-compiled objects
1062
like the [NVSHMEM library](https://developer.nvidia.com/nvshmem).
1064
Note: Ninja is required to build a CUDA Extension with RDC linking.
1067
>>> # xdoctest: +SKIP
1068
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
1070
... name='cuda_extension',
1071
... sources=['extension.cpp', 'extension_kernel.cu'],
1073
... dlink_libraries=["dlink_lib"],
1074
... extra_compile_args={'cxx': ['-g'],
1075
... 'nvcc': ['-O2', '-rdc=true']})
1077
library_dirs = kwargs.get('library_dirs', [])
1078
library_dirs += library_paths(cuda=True)
1079
kwargs['library_dirs'] = library_dirs
1081
libraries = kwargs.get('libraries', [])
1082
libraries.append('c10')
1083
libraries.append('torch')
1084
libraries.append('torch_cpu')
1085
libraries.append('torch_python')
1086
if IS_HIP_EXTENSION:
1087
libraries.append('amdhip64')
1088
libraries.append('c10_hip')
1089
libraries.append('torch_hip')
1091
libraries.append('cudart')
1092
libraries.append('c10_cuda')
1093
libraries.append('torch_cuda')
1094
kwargs['libraries'] = libraries
1096
include_dirs = kwargs.get('include_dirs', [])
1098
if IS_HIP_EXTENSION:
1099
build_dir = os.getcwd()
1100
hipify_result = hipify_python.hipify(
1101
project_directory=build_dir,
1102
output_directory=build_dir,
1103
header_include_dirs=include_dirs,
1104
includes=[os.path.join(build_dir, '*')], # limit scope to build_dir only
1105
extra_files=[os.path.abspath(s) for s in sources],
1107
is_pytorch_extension=True,
1108
hipify_extra_files_only=True, # don't hipify everything in includes path
1111
hipified_sources = set()
1112
for source in sources:
1113
s_abs = os.path.abspath(source)
1114
hipified_s_abs = (hipify_result[s_abs].hipified_path if (s_abs in hipify_result and
1115
hipify_result[s_abs].hipified_path is not None) else s_abs)
1116
# setup() arguments must *always* be /-separated paths relative to the setup.py directory,
1117
# *never* absolute paths
1118
hipified_sources.add(os.path.relpath(hipified_s_abs, build_dir))
1120
sources = list(hipified_sources)
1122
include_dirs += include_paths(cuda=True)
1123
kwargs['include_dirs'] = include_dirs
1125
kwargs['language'] = 'c++'
1127
dlink_libraries = kwargs.get('dlink_libraries', [])
1128
dlink = kwargs.get('dlink', False) or dlink_libraries
1130
extra_compile_args = kwargs.get('extra_compile_args', {})
1132
extra_compile_args_dlink = extra_compile_args.get('nvcc_dlink', [])
1133
extra_compile_args_dlink += ['-dlink']
1134
extra_compile_args_dlink += [f'-L{x}' for x in library_dirs]
1135
extra_compile_args_dlink += [f'-l{x}' for x in dlink_libraries]
1137
if (torch.version.cuda is not None) and TorchVersion(torch.version.cuda) >= '11.2':
1138
extra_compile_args_dlink += ['-dlto'] # Device Link Time Optimization started from cuda 11.2
1140
extra_compile_args['nvcc_dlink'] = extra_compile_args_dlink
1142
kwargs['extra_compile_args'] = extra_compile_args
1144
return setuptools.Extension(name, sources, *args, **kwargs)
1147
def include_paths(cuda: bool = False) -> List[str]:
1149
Get the include paths required to build a C++ or CUDA extension.
1152
cuda: If `True`, includes CUDA-specific include paths.
1155
A list of include path strings.
1157
lib_include = os.path.join(_TORCH_PATH, 'include')
1160
# Remove this once torch/torch.h is officially no longer supported for C++ extensions.
1161
os.path.join(lib_include, 'torch', 'csrc', 'api', 'include'),
1162
# Some internal (old) Torch headers don't properly prefix their includes,
1163
# so we need to pass -Itorch/lib/include/TH as well.
1164
os.path.join(lib_include, 'TH'),
1165
os.path.join(lib_include, 'THC')
1167
if cuda and IS_HIP_EXTENSION:
1168
paths.append(os.path.join(lib_include, 'THH'))
1169
paths.append(_join_rocm_home('include'))
1171
cuda_home_include = _join_cuda_home('include')
1172
# if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home.
1173
# but gcc doesn't like having /usr/include passed explicitly
1174
if cuda_home_include != '/usr/include':
1175
paths.append(cuda_home_include)
1177
# Support CUDA_INC_PATH env variable supported by CMake files
1178
if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
1179
cuda_inc_path != '/usr/include':
1180
paths.append(cuda_inc_path)
1181
if CUDNN_HOME is not None:
1182
paths.append(os.path.join(CUDNN_HOME, 'include'))
1186
def library_paths(cuda: bool = False) -> List[str]:
1188
Get the library paths required to build a C++ or CUDA extension.
1191
cuda: If `True`, includes CUDA-specific library paths.
1194
A list of library path strings.
1196
# We need to link against libtorch.so
1197
paths = [TORCH_LIB_PATH]
1199
if cuda and IS_HIP_EXTENSION:
1201
paths.append(_join_rocm_home(lib_dir))
1202
if HIP_HOME is not None:
1203
paths.append(os.path.join(HIP_HOME, 'lib'))
1206
lib_dir = os.path.join('lib', 'x64')
1209
if (not os.path.exists(_join_cuda_home(lib_dir)) and
1210
os.path.exists(_join_cuda_home('lib'))):
1211
# 64-bit CUDA may be installed in 'lib' (see e.g. gh-16955)
1212
# Note that it's also possible both don't exist (see
1213
# _find_cuda_home) - in that case we stay with 'lib64'.
1216
paths.append(_join_cuda_home(lib_dir))
1217
if CUDNN_HOME is not None:
1218
paths.append(os.path.join(CUDNN_HOME, lib_dir))
1223
sources: Union[str, List[str]],
1225
extra_cuda_cflags=None,
1227
extra_include_paths=None,
1228
build_directory=None,
1230
with_cuda: Optional[bool] = None,
1231
is_python_module=True,
1232
is_standalone=False,
1233
keep_intermediates=True):
1235
Load a PyTorch C++ extension just-in-time (JIT).
1237
To load an extension, a Ninja build file is emitted, which is used to
1238
compile the given sources into a dynamic library. This library is
1239
subsequently loaded into the current Python process as a module and
1240
returned from this function, ready for use.
1242
By default, the directory to which the build file is emitted and the
1243
resulting library compiled to is ``<tmp>/torch_extensions/<name>``, where
1244
``<tmp>`` is the temporary folder on the current platform and ``<name>``
1245
the name of the extension. This location can be overridden in two ways.
1246
First, if the ``TORCH_EXTENSIONS_DIR`` environment variable is set, it
1247
replaces ``<tmp>/torch_extensions`` and all extensions will be compiled
1248
into subfolders of this directory. Second, if the ``build_directory``
1249
argument to this function is supplied, it overrides the entire path, i.e.
1250
the library will be compiled into that folder directly.
1252
To compile the sources, the default system compiler (``c++``) is used,
1253
which can be overridden by setting the ``CXX`` environment variable. To pass
1254
additional arguments to the compilation process, ``extra_cflags`` or
1255
``extra_ldflags`` can be provided. For example, to compile your extension
1256
with optimizations, pass ``extra_cflags=['-O3']``. You can also use
1257
``extra_cflags`` to pass further include directories.
1259
CUDA support with mixed compilation is provided. Simply pass CUDA source
1260
files (``.cu`` or ``.cuh``) along with other sources. Such files will be
1261
detected and compiled with nvcc rather than the C++ compiler. This includes
1262
passing the CUDA lib64 directory as a library directory, and linking
1263
``cudart``. You can pass additional flags to nvcc via
1264
``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
1265
heuristics for finding the CUDA install directory are used, which usually
1266
work fine. If not, setting the ``CUDA_HOME`` environment variable is the
1270
name: The name of the extension to build. This MUST be the same as the
1271
name of the pybind11 module!
1272
sources: A list of relative or absolute paths to C++ source files.
1273
extra_cflags: optional list of compiler flags to forward to the build.
1274
extra_cuda_cflags: optional list of compiler flags to forward to nvcc
1275
when building CUDA sources.
1276
extra_ldflags: optional list of linker flags to forward to the build.
1277
extra_include_paths: optional list of include directories to forward
1279
build_directory: optional path to use as build workspace.
1280
verbose: If ``True``, turns on verbose logging of load steps.
1281
with_cuda: Determines whether CUDA headers and libraries are added to
1282
the build. If set to ``None`` (default), this value is
1283
automatically determined based on the existence of ``.cu`` or
1284
``.cuh`` in ``sources``. Set it to ``True`` to force CUDA headers
1285
and libraries to be included.
1286
is_python_module: If ``True`` (default), imports the produced shared
1287
library as a Python module. If ``False``, behavior depends on
1289
is_standalone: If ``False`` (default) loads the constructed extension
1290
into the process as a plain dynamic library. If ``True``, build a
1291
standalone executable.
1294
If ``is_python_module`` is ``True``:
1295
Returns the loaded PyTorch extension as a Python module.
1297
If ``is_python_module`` is ``False`` and ``is_standalone`` is ``False``:
1298
Returns nothing. (The shared library is loaded into the process as
1301
If ``is_standalone`` is ``True``:
1302
Return the path to the executable. (On Windows, TORCH_LIB_PATH is
1303
added to the PATH environment variable as a side effect.)
1306
>>> # xdoctest: +SKIP
1307
>>> from torch.utils.cpp_extension import load
1309
... name='extension',
1310
... sources=['extension.cpp', 'extension_kernel.cu'],
1311
... extra_cflags=['-O2'],
1314
return _jit_compile(
1316
[sources] if isinstance(sources, str) else sources,
1320
extra_include_paths,
1321
build_directory or _get_build_directory(name, verbose),
1326
keep_intermediates=keep_intermediates)
1328
def _get_pybind11_abi_build_flags():
    """Return ``-DPYBIND11_<NAME>=\\"<value>\\"`` defines for PyTorch's recorded pybind11 ABI properties.

    The values are read from ``torch._C._PYBIND11_COMPILER_TYPE``,
    ``torch._C._PYBIND11_STDLIB`` and ``torch._C._PYBIND11_BUILD_ABI``, in that
    order. Entries whose value is ``None`` are skipped, and no defines are
    produced on Windows.
    """
    # Note [Pybind11 ABI constants]
    #
    # Pybind11 before 2.4 used to build ABI strings using the following pattern:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__"
    # Since 2.4, compiler type, stdlib and build ABI parameters are also encoded like this:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__"
    #
    # This was done in order to further narrow down the chances of compiler ABI incompatibility
    # that can cause hard-to-debug segfaults.
    # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and ABI properties
    # captured during PyTorch native library compilation in torch/csrc/Module.cpp
    abi_properties = [
        (pname, getattr(torch._C, f"_PYBIND11_{pname}"))
        for pname in ("COMPILER_TYPE", "STDLIB", "BUILD_ABI")
    ]
    # NOTE(review): the backslash-escaped quotes are presumably there so the value
    # reaches the compiler as a C string literal after shell processing (the PCH
    # command in this file is run with shell=True) — confirm for the ninja path.
    return [
        f'-DPYBIND11_{pname}=\\"{pval}\\"'
        for pname, pval in abi_properties
        if pval is not None and not IS_WINDOWS
    ]
1348
def _get_glibcxx_abi_build_flags():
1349
glibcxx_abi_cflags = ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
1350
return glibcxx_abi_cflags
1352
def check_compiler_is_gcc(compiler):
1356
env = os.environ.copy()
1357
env['LC_ALL'] = 'C' # Don't localize output
1359
version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS)
1360
except Exception as e:
1362
version_string = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS)
1363
except Exception as e:
1365
# Check for 'gcc' or 'g++' for sccache wrapper
1366
pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
1367
results = re.findall(pattern, version_string)
1368
if len(results) != 1:
1370
compiler_path = os.path.realpath(results[0].strip())
1371
# On RHEL/CentOS c++ is a gcc compiler wrapper
1372
if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string:
1376
def _check_and_build_extension_h_precompiler_headers(
1378
extra_include_paths,
1379
is_standalone=False):
1381
Precompiled Headers(PCH) can pre-build the same headers and reduce build time for pytorch load_inline modules.
1382
GCC official manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html
1383
PCH only works when the built PCH file (header.h.gch) and the build target have the same build parameters. So, we need to
1384
add a signature file to record PCH file parameters. If the build parameters(signature) changed, it should rebuild
1388
1. Windows and MacOS have different PCH mechanism. We only support Linux currently.
1389
2. It only works on GCC/G++.
1394
compiler = get_cxx_compiler()
1396
b_is_gcc = check_compiler_is_gcc(compiler)
1397
if b_is_gcc is False:
1400
head_file = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h')
1401
head_file_pch = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.gch')
1402
head_file_signature = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.sign')
1404
def listToString(s):
1405
# initialize an empty string
1410
# traverse in the string
1412
string += (element + ' ')
1416
def format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags, torch_include_dirs, extra_cflags, extra_include_paths):
1421
{compiler} -x c++-header {head_file} -o {head_file_pch} {torch_include_dirs} {extra_include_paths} {extra_cflags} {common_cflags}
1425
def command_to_signature(cmd):
1426
signature = cmd.replace(' ', '_')
1429
def check_pch_signature_in_file(file_path, signature):
1430
b_exist = os.path.isfile(file_path)
1431
if b_exist is False:
1434
with open(file_path) as file:
1435
# read all content of a file
1436
content = file.read()
1437
# check if string present in a file
1438
return signature == content
1440
def _create_if_not_exist(path_dir):
1441
if not os.path.exists(path_dir):
1443
Path(path_dir).mkdir(parents=True, exist_ok=True)
1444
except OSError as exc: # Guard against race condition
1445
if exc.errno != errno.EEXIST:
1446
raise RuntimeError(f"Fail to create path {path_dir}") from exc
1448
def write_pch_signature_to_file(file_path, pch_sign):
1449
_create_if_not_exist(os.path.dirname(file_path))
1450
with open(file_path, "w") as f:
1454
def build_precompile_header(pch_cmd):
1456
subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT)
1457
except subprocess.CalledProcessError as e:
1458
raise RuntimeError(f"Compile PreCompile Header fail, command: {pch_cmd}") from e
1460
extra_cflags_str = listToString(extra_cflags)
1461
extra_include_paths_str = " ".join(
1462
[f"-I{include}" for include in extra_include_paths] if extra_include_paths else []
1465
lib_include = os.path.join(_TORCH_PATH, 'include')
1466
torch_include_dirs = [
1467
f"-I {lib_include}",
1469
"-I {}".format(sysconfig.get_path("include")),
1471
"-I {}".format(os.path.join(lib_include, 'torch', 'csrc', 'api', 'include')),
1474
torch_include_dirs_str = listToString(torch_include_dirs)
1477
if not is_standalone:
1478
common_cflags += ['-DTORCH_API_INCLUDE_EXTENSION_H']
1480
common_cflags += ['-std=c++17', '-fPIC']
1481
common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]
1482
common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]
1483
common_cflags_str = listToString(common_cflags)
1485
pch_cmd = format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags_str, torch_include_dirs_str, extra_cflags_str, extra_include_paths_str)
1486
pch_sign = command_to_signature(pch_cmd)
1488
if os.path.isfile(head_file_pch) is not True:
1489
build_precompile_header(pch_cmd)
1490
write_pch_signature_to_file(head_file_signature, pch_sign)
1492
b_same_sign = check_pch_signature_in_file(head_file_signature, pch_sign)
1493
if b_same_sign is False:
1494
build_precompile_header(pch_cmd)
1495
write_pch_signature_to_file(head_file_signature, pch_sign)
1497
def remove_extension_h_precompiler_headers():
    """Delete the cached ``extension.h`` precompiled header and its signature file.

    Both live under ``_TORCH_PATH/include/torch``: the ``.gch`` file is removed
    first, then the ``.sign`` file. A file that does not exist is skipped.
    """
    torch_header_dir = os.path.join(_TORCH_PATH, 'include', 'torch')
    for leftover in ('extension.h.gch', 'extension.h.sign'):
        leftover_path = os.path.join(torch_header_dir, leftover)
        if os.path.exists(leftover_path):
            os.remove(leftover_path)
1508
def load_inline(name,
1513
extra_cuda_cflags=None,
1515
extra_include_paths=None,
1516
build_directory=None,
1519
is_python_module=True,
1520
with_pytorch_error_handling=True,
1521
keep_intermediates=True,
1524
Load a PyTorch C++ extension just-in-time (JIT) from string sources.
1526
This function behaves exactly like :func:`load`, but takes its sources as
1527
strings rather than filenames. These strings are stored to files in the
1528
build directory, after which the behavior of :func:`load_inline` is
1529
identical to :func:`load`.
1532
tests <https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions_jit.py>`_
1533
for good examples of using this function.
1535
Sources may omit two required parts of a typical non-inline C++ extension:
1536
the necessary header includes, as well as the (pybind11) binding code. More
1537
precisely, strings passed to ``cpp_sources`` are first concatenated into a
1538
single ``.cpp`` file. This file is then prepended with ``#include
1539
<torch/extension.h>``.
1541
Furthermore, if the ``functions`` argument is supplied, bindings will be
1542
automatically generated for each function specified. ``functions`` can
1543
either be a list of function names, or a dictionary mapping from function
1544
names to docstrings. If a list is given, the name of each function is used
1547
The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
1548
file and prepended with ``torch/types.h``, ``cuda.h`` and
1549
``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled
1550
separately, but ultimately linked into a single library. Note that no
1551
bindings are generated for functions in ``cuda_sources`` per se. To bind
1552
to a CUDA kernel, you must create a C++ function that calls it, and either
1553
declare or define this C++ function in one of the ``cpp_sources`` (and
1554
include its name in ``functions``).
1556
See :func:`load` for a description of arguments omitted below.
1559
cpp_sources: A string, or list of strings, containing C++ source code.
1560
cuda_sources: A string, or list of strings, containing CUDA source code.
1561
functions: A list of function names for which to generate function
1562
bindings. If a dictionary is given, it should map function names to
1563
docstrings (which are otherwise just the function names).
1564
with_cuda: Determines whether CUDA headers and libraries are added to
1565
the build. If set to ``None`` (default), this value is
1566
automatically determined based on whether ``cuda_sources`` is
1567
provided. Set it to ``True`` to force CUDA headers
1568
and libraries to be included.
1569
with_pytorch_error_handling: Determines whether pytorch error and
1570
warning macros are handled by pytorch instead of pybind. To do
1571
this, each function ``foo`` is called via an intermediary ``_safe_foo``
1572
function. This redirection might cause issues in obscure cases
1573
of cpp. This flag should be set to ``False`` when this redirect
1577
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
1578
>>> from torch.utils.cpp_extension import load_inline
1580
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
1581
return x.sin() + y.sin();
1584
>>> module = load_inline(name='inline_extension',
1585
... cpp_sources=[source],
1586
... functions=['sin_add'])
1589
By default, the Ninja backend uses #CPUS + 2 workers to build the
1590
extension. This may use up too many resources on some systems. One
1591
can control the number of workers by setting the `MAX_JOBS` environment
1592
variable to a non-negative number.
1594
build_directory = build_directory or _get_build_directory(name, verbose)
1596
if isinstance(cpp_sources, str):
1597
cpp_sources = [cpp_sources]
1598
cuda_sources = cuda_sources or []
1599
if isinstance(cuda_sources, str):
1600
cuda_sources = [cuda_sources]
1602
cpp_sources.insert(0, '#include <torch/extension.h>')
1605
# Using PreCompile Header('torch/extension.h') to reduce compile time.
1606
_check_and_build_extension_h_precompiler_headers(extra_cflags, extra_include_paths)
1608
remove_extension_h_precompiler_headers()
1610
# If `functions` is supplied, we create the pybind11 bindings for the user.
1611
# Here, `functions` is (or becomes, after some processing) a map from
1612
# function names to function docstrings.
1613
if functions is not None:
1615
module_def.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
1616
if isinstance(functions, str):
1617
functions = [functions]
1618
if isinstance(functions, list):
1619
# Make the function docstring the same as the function name.
1620
functions = {f: f for f in functions}
1621
elif not isinstance(functions, dict):
1622
raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}")
1623
for function_name, docstring in functions.items():
1624
if with_pytorch_error_handling:
1625
module_def.append(f'm.def("{function_name}", torch::wrap_pybind_function({function_name}), "{docstring}");')
1627
module_def.append(f'm.def("{function_name}", {function_name}, "{docstring}");')
1628
module_def.append('}')
1629
cpp_sources += module_def
1631
cpp_source_path = os.path.join(build_directory, 'main.cpp')
1632
_maybe_write(cpp_source_path, "\n".join(cpp_sources))
1634
sources = [cpp_source_path]
1637
cuda_sources.insert(0, '#include <torch/types.h>')
1638
cuda_sources.insert(1, '#include <cuda.h>')
1639
cuda_sources.insert(2, '#include <cuda_runtime.h>')
1641
cuda_source_path = os.path.join(build_directory, 'cuda.cu')
1642
_maybe_write(cuda_source_path, "\n".join(cuda_sources))
1644
sources.append(cuda_source_path)
1646
return _jit_compile(
1652
extra_include_paths,
1657
is_standalone=False,
1658
keep_intermediates=keep_intermediates)
1661
def _jit_compile(name,
1666
extra_include_paths,
1667
build_directory: str,
1669
with_cuda: Optional[bool],
1672
keep_intermediates=True) -> None:
1673
if is_python_module and is_standalone:
1674
raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.")
1676
if with_cuda is None:
1677
with_cuda = any(map(_is_cuda_file, sources))
1678
with_cudnn = any('cudnn' in f for f in extra_ldflags or [])
1679
old_version = JIT_EXTENSION_VERSIONER.get_version(name)
1680
version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
1683
build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
1684
build_directory=build_directory,
1685
with_cuda=with_cuda,
1686
is_python_module=is_python_module,
1687
is_standalone=is_standalone,
1690
if version != old_version and verbose:
1691
print(f'The input conditions for extension module {name} have changed. ' +
1692
f'Bumping to version {version} and re-building as {name}_v{version}...',
1694
name = f'{name}_v{version}'
1696
baton = FileBaton(os.path.join(build_directory, 'lock'))
1697
if baton.try_acquire():
1699
if version != old_version:
1700
with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
1701
if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
1702
hipify_result = hipify_python.hipify(
1703
project_directory=build_directory,
1704
output_directory=build_directory,
1705
header_include_dirs=(extra_include_paths if extra_include_paths is not None else []),
1706
extra_files=[os.path.abspath(s) for s in sources],
1707
ignores=[_join_rocm_home('*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers
1708
show_detailed=verbose,
1709
show_progress=verbose,
1710
is_pytorch_extension=True,
1714
hipified_sources = set()
1715
for source in sources:
1716
s_abs = os.path.abspath(source)
1717
hipified_sources.add(hipify_result[s_abs].hipified_path if s_abs in hipify_result else s_abs)
1719
sources = list(hipified_sources)
1721
_write_ninja_file_and_build_library(
1724
extra_cflags=extra_cflags or [],
1725
extra_cuda_cflags=extra_cuda_cflags or [],
1726
extra_ldflags=extra_ldflags or [],
1727
extra_include_paths=extra_include_paths or [],
1728
build_directory=build_directory,
1730
with_cuda=with_cuda,
1731
is_standalone=is_standalone)
1733
print('No modifications detected for re-loaded extension '
1734
f'module {name}, skipping build step...', file=sys.stderr)
1741
print(f'Loading extension module {name}...', file=sys.stderr)
1744
return _get_exec_path(name, build_directory)
1746
return _import_module_from_library(name, build_directory, is_python_module)
1749
def _write_ninja_file_and_compile_objects(
1756
cuda_dlink_post_cflags,
1757
build_directory: str,
1759
with_cuda: Optional[bool]) -> None:
1760
verify_ninja_availability()
1762
compiler = get_cxx_compiler()
1764
get_compiler_abi_compatibility_and_version(compiler)
1765
if with_cuda is None:
1766
with_cuda = any(map(_is_cuda_file, sources))
1767
build_file_path = os.path.join(build_directory, 'build.ninja')
1769
print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr)
1771
path=build_file_path,
1773
post_cflags=post_cflags,
1774
cuda_cflags=cuda_cflags,
1775
cuda_post_cflags=cuda_post_cflags,
1776
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
1780
library_target=None,
1781
with_cuda=with_cuda)
1783
print('Compiling objects...', file=sys.stderr)
1787
# It would be better if we could tell users the name of the extension
1788
# that failed to build but there isn't a good way to get it here.
1789
error_prefix='Error compiling objects for extension')
1792
def _write_ninja_file_and_build_library(
1798
extra_include_paths,
1799
build_directory: str,
1801
with_cuda: Optional[bool],
1802
is_standalone: bool = False) -> None:
1803
verify_ninja_availability()
1805
compiler = get_cxx_compiler()
1807
get_compiler_abi_compatibility_and_version(compiler)
1808
if with_cuda is None:
1809
with_cuda = any(map(_is_cuda_file, sources))
1810
extra_ldflags = _prepare_ldflags(
1811
extra_ldflags or [],
1815
build_file_path = os.path.join(build_directory, 'build.ninja')
1817
print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr)
1818
# NOTE: Emitting a new ninja build file does not cause re-compilation if
1819
# the sources did not change, so it's ok to re-emit (and it's fast).
1820
_write_ninja_file_to_build_library(
1821
path=build_file_path,
1824
extra_cflags=extra_cflags or [],
1825
extra_cuda_cflags=extra_cuda_cflags or [],
1826
extra_ldflags=extra_ldflags or [],
1827
extra_include_paths=extra_include_paths or [],
1828
with_cuda=with_cuda,
1829
is_standalone=is_standalone)
1832
print(f'Building extension module {name}...', file=sys.stderr)
1836
error_prefix=f"Error building extension '{name}'")
1839
def is_ninja_available():
1840
"""Return ``True`` if the `ninja <https://ninja-build.org/>`_ build system is available on the system, ``False`` otherwise."""
# Probe for ninja by running `ninja --version`.
# NOTE(review): the handling of a failed probe and the return statements are
# not visible in this excerpt.
1842
subprocess.check_output('ninja --version'.split())
1849
def verify_ninja_availability():
    """Ensure the `ninja <https://ninja-build.org/>`_ build system can be used.

    Raises ``RuntimeError`` when ninja is missing; does nothing otherwise.
    """
    if is_ninja_available():
        return
    raise RuntimeError("Ninja is required to load C++ extensions")
1855
def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
# Append the torch libraries (and, for CUDA/ROCm builds, the GPU runtime
# libraries) to extra_ldflags. The list is mutated in place and also returned.
# NOTE(review): the platform/CUDA conditionals that select between the groups
# below (MSVC-style vs. GNU-style flags) are not visible in this excerpt.
1857
python_lib_path = os.path.join(sys.base_exec_prefix, 'libs')
# MSVC-style linker inputs: `<name>.lib` files plus `/LIBPATH:` directories.
1859
extra_ldflags.append('c10.lib')
1861
extra_ldflags.append('c10_cuda.lib')
1862
extra_ldflags.append('torch_cpu.lib')
1864
extra_ldflags.append('torch_cuda.lib')
1865
# /INCLUDE is used to ensure torch_cuda is linked against in a project that relies on it.
1866
# Related issue: https://github.com/pytorch/pytorch/issues/31611
1867
extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ')
1868
extra_ldflags.append('torch.lib')
1869
extra_ldflags.append(f'/LIBPATH:{TORCH_LIB_PATH}')
# torch_python is not linked when is_standalone is True.
1870
if not is_standalone:
1871
extra_ldflags.append('torch_python.lib')
1872
extra_ldflags.append(f'/LIBPATH:{python_lib_path}')
# GCC/Clang-style flags: `-L` search directories and `-l` libraries.
1875
extra_ldflags.append(f'-L{TORCH_LIB_PATH}')
1876
extra_ldflags.append('-lc10')
1878
extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda')
1879
extra_ldflags.append('-ltorch_cpu')
1881
extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda')
1882
extra_ldflags.append('-ltorch')
1883
if not is_standalone:
1884
extra_ldflags.append('-ltorch_python')
# Record TORCH_LIB_PATH as a runtime library search path (rpath).
1887
extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}")
# GPU runtime libraries: cudart (MSVC or GNU style) or ROCm's amdhip64.
1891
print('Detected CUDA files, patching ldflags', file=sys.stderr)
1893
extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib", "x64")}')
1894
extra_ldflags.append('cudart.lib')
1895
if CUDNN_HOME is not None:
1896
extra_ldflags.append(f'/LIBPATH:{os.path.join(CUDNN_HOME, "lib", "x64")}')
1897
elif not IS_HIP_EXTENSION:
1898
extra_lib_dir = "lib64"
1899
if (not os.path.exists(_join_cuda_home(extra_lib_dir)) and
1900
os.path.exists(_join_cuda_home("lib"))):
1901
# 64-bit CUDA may be installed in "lib"
1902
# Note that it's also possible both don't exist (see _find_cuda_home) - in that case we stay with "lib64"
1903
extra_lib_dir = "lib"
1904
extra_ldflags.append(f'-L{_join_cuda_home(extra_lib_dir)}')
1905
extra_ldflags.append('-lcudart')
1906
if CUDNN_HOME is not None:
1907
extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}')
1908
elif IS_HIP_EXTENSION:
1909
extra_ldflags.append(f'-L{_join_rocm_home("lib")}')
1910
extra_ldflags.append('-lamdhip64')
1911
return extra_ldflags
1914
def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
# NOTE(review): this excerpt omits some lines of the definition (e.g. the
# docstring delimiters, the loop binding `flag` over `cflags`, and the
# initialisation of `arch_list`/`flags`); comments describe only visible code.
1916
Determine CUDA arch flags to use.
1918
For an arch, say "6.1", the added compile flag will be
1919
``-gencode=arch=compute_61,code=sm_61``.
1920
For an added "+PTX", an additional
1921
``-gencode=arch=compute_xx,code=compute_xx`` is added.
1923
See select_compute_arch.cmake for corresponding named and supported arches
1924
when building with CMake.
1926
# If cflags is given, there may already be user-provided arch flags in it
1927
# (from `extra_compile_args`)
1928
if cflags is not None:
1930
if 'TORCH_EXTENSION_NAME' in flag:
1935
# Note: keep combined names ("arch1+arch2") above single names, otherwise
1936
# string replacement may not do the right thing
1937
named_arches = collections.OrderedDict([
1938
('Kepler+Tesla', '3.7'),
1939
('Kepler', '3.5+PTX'),
1940
('Maxwell+Tegra', '5.3'),
1941
('Maxwell', '5.0;5.2+PTX'),
1942
('Pascal', '6.0;6.1+PTX'),
1943
('Volta+Tegra', '7.2'),
1944
('Volta', '7.0+PTX'),
1945
('Turing', '7.5+PTX'),
1946
('Ampere+Tegra', '8.7'),
1947
('Ampere', '8.0;8.6+PTX'),
1949
('Hopper', '9.0+PTX'),
1952
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
1953
'7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a']
1954
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
1956
# NOTE(review): an older comment here said the default is sm_30 for CUDA 9.x and 10.x; when the env var is unset, the code below instead collects the archs of the visible devices.
1957
# First check for an env var (same as used by the main setup.py)
1958
# Can be one or more architectures, e.g. "6.1" or "3.5;5.2;6.0;6.1;7.0+PTX"
1959
# See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
1960
_arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
1962
# If not given, determine what's best for the GPU / CUDA version that can be found
1965
"TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n"
1966
"If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].")
1968
# the assumption is that the extension should run on any of the currently visible cards,
1969
# which could be of different types - therefore all archs for visible cards should be included
1970
for i in range(torch.cuda.device_count()):
1971
capability = torch.cuda.get_device_capability(i)
# NOTE(review): int() raises ValueError for arch names with a letter suffix
# (e.g. 'sm_90a', cf. '9.0a' in supported_arches); confirm get_arch_list()
# cannot return such entries. This list also does not depend on `i` but is
# recomputed on every iteration.
1972
supported_sm = [int(arch.split('_')[1])
1973
for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
1974
max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
1975
# Capability of the device may be higher than what's supported by the user's
1976
# NVCC, causing compilation error. User's NVCC is expected to match the one
1977
# used to build pytorch, so we use the maximum supported capability of pytorch
1978
# to clamp the capability.
1979
capability = min(max_supported_sm, capability)
1980
arch = f'{capability[0]}.{capability[1]}'
1981
if arch not in arch_list:
1982
arch_list.append(arch)
# Also emit PTX for the highest arch. sorted() compares strings, which orders
# correctly only while every major version is a single digit (as in supported_arches).
1983
arch_list = sorted(arch_list)
1984
arch_list[-1] += '+PTX'
1986
# Deal with lists that are ' ' separated (only deal with ';' after)
1987
_arch_list = _arch_list.replace(' ', ';')
1988
# Expand named arches
1989
for named_arch, archval in named_arches.items():
1990
_arch_list = _arch_list.replace(named_arch, archval)
1992
arch_list = _arch_list.split(';')
1995
for arch in arch_list:
1996
if arch not in valid_arch_strings:
1997
raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported")
# Drop the dot and any '+PTX': '8.6+PTX' -> '86', '9.0a' -> '90a'
# (assumes a single-digit major version).
1999
num = arch[0] + arch[2:].split("+")[0]
2000
flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
2001
if arch.endswith('+PTX'):
2002
flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
# Deduplicate and return the flags in a deterministic (sorted) order.
2004
return sorted(set(flags))
2007
def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
# Return the hipcc offload-arch flags for ROCm builds; both visible paths
# include '-fno-gpu-rdc'.
# NOTE(review): the loop binding `flag` over `cflags`, the branch choosing
# between PYTORCH_ROCM_ARCH and torch's built-in arch flags, and the final
# return are not visible in this excerpt.
2008
# If cflags is given, there may already be user-provided arch flags in it
2009
# (from `extra_compile_args`)
2010
if cflags is not None:
2012
if 'amdgpu-target' in flag or 'offload-arch' in flag:
2013
return ['-fno-gpu-rdc']
2014
# Use same defaults as used for building PyTorch
2015
# Allow env var to override, just like during initial cmake build.
2016
_archs = os.environ.get('PYTORCH_ROCM_ARCH', None)
2018
archFlags = torch._C._cuda_getArchFlags()
2020
archs = archFlags.split()
# PYTORCH_ROCM_ARCH may be separated by spaces or semicolons.
2024
archs = _archs.replace(' ', ';').split(';')
2025
flags = [f'--offload-arch={arch}' for arch in archs]
2026
flags += ['-fno-gpu-rdc']
2029
def _get_build_directory(name: str, verbose: bool) -> str:
# Return (creating it if needed) the build directory for extension `name`.
# The root is $TORCH_EXTENSIONS_DIR or get_default_build_root(); the visible
# code joins a `py<major><minor>_<cpu|cu<version without dots>>` folder onto it
# (e.g. 'cu' + '12.1'.replace('.', '') -> 'cu121').
# NOTE(review): the lines guarding the print() calls (presumably on `verbose`)
# are not visible in this excerpt.
2030
root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
2031
if root_extensions_directory is None:
2032
root_extensions_directory = get_default_build_root()
2033
cu_str = ('cpu' if torch.version.cuda is None else
2034
f'cu{torch.version.cuda.replace(".", "")}') # type: ignore[attr-defined]
2035
python_version = f'py{sys.version_info.major}{sys.version_info.minor}'
2036
build_folder = f'{python_version}_{cu_str}'
2038
root_extensions_directory = os.path.join(
2039
root_extensions_directory, build_folder)
2042
print(f'Using {root_extensions_directory} as PyTorch extensions root...', file=sys.stderr)
2044
build_directory = os.path.join(root_extensions_directory, name)
2045
if not os.path.exists(build_directory):
2047
print(f'Creating extension directory {build_directory}...', file=sys.stderr)
2048
# This is like mkdir -p, i.e. will also create parent directories.
2049
os.makedirs(build_directory, exist_ok=True)
2051
return build_directory
2054
def _get_num_workers(verbose: bool) -> Optional[int]:
# Return MAX_JOBS as an int when it is set to a string of digits; otherwise
# (per the message below) ninja is left to choose its default parallelism.
# NOTE(review): str.isdigit() also accepts non-ASCII digits such as '²', which
# int() rejects with ValueError, and it accepts '0'; confirm both are intended.
# The tail of this function is not visible in this excerpt.
2055
max_jobs = os.environ.get('MAX_JOBS')
2056
if max_jobs is not None and max_jobs.isdigit():
2058
print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...',
2060
return int(max_jobs)
2062
print('Allowing ninja to set a default number of workers... '
2063
'(overridable by setting the environment variable MAX_JOBS=N)',
2068
def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> None:
# Run `ninja -v` (plus `-j N` when _get_num_workers returns a value) in
# build_directory; a failed build is re-raised as RuntimeError whose message is
# error_prefix followed by the captured build output, when there is any.
# NOTE(review): several lines are not visible in this excerpt (e.g. the rest
# of the environment merge, the `try:`/subprocess.run call head and some of
# its arguments).
2069
command = ['ninja', '-v']
2070
num_workers = _get_num_workers(verbose)
2071
if num_workers is not None:
2072
command.extend(['-j', str(num_workers)])
2073
env = os.environ.copy()
2074
# Try to activate the vc env for the users
2075
if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' not in env:
2076
from setuptools import distutils
2078
plat_name = distutils.util.get_platform()
# NOTE(review): a platform name missing from PLAT_TO_VCVARS raises KeyError
# here; confirm the mapping covers every supported Windows platform.
2079
plat_spec = PLAT_TO_VCVARS[plat_name]
2081
vc_env = distutils._msvccompiler._get_vc_env(plat_spec)
2082
vc_env = {k.upper(): v for k, v in vc_env.items()}
2083
for k, v in env.items():
2085
if uk not in vc_env:
2091
# Warning: don't pass stdout=None to subprocess.run to get output.
2092
# subprocess.run assumes that sys.__stdout__ has not been modified and
2093
# attempts to write to it by default. However, when we call _run_ninja_build
2094
# from ahead-of-time cpp extensions, the following happens:
2095
# 1) If the stdout encoding is not utf-8, setuptools detaches __stdout__.
2096
# https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110
2097
# (it probably shouldn't do this)
2098
# 2) subprocess.run (on POSIX, with no stdout override) relies on
2099
# __stdout__ not being detached:
2100
# https://github.com/python/cpython/blob/c352e6c7446c894b13643f538db312092b351789/Lib/subprocess.py#L1214
2101
# To work around this, we pass in the fileno directly and hope that
2106
stdout=stdout_fileno if verbose else subprocess.PIPE,
2107
stderr=subprocess.STDOUT,
2108
cwd=build_directory,
2111
except subprocess.CalledProcessError as e:
2112
# sys.exc_info()[1] is the exception being handled, i.e. the same object as `e`; this form is a leftover from Python 2/3-compatible code.
2113
_, error, _ = sys.exc_info()
2114
# error.output contains the stdout and stderr of the build attempt.
2115
message = error_prefix
2116
# `error` is a CalledProcessError (which has an `output`) attribute, but
2117
# mypy thinks it's Optional[BaseException] and doesn't narrow
2118
if hasattr(error, 'output') and error.output: # type: ignore[union-attr]
2119
message += f": {error.output.decode(*SUBPROCESS_DECODE_ARGS)}" # type: ignore[union-attr]
2120
raise RuntimeError(message) from e
2123
def _get_exec_path(module_name, path):
# Return `path`/`module_name` + EXEC_EXT. On Windows this also prepends
# TORCH_LIB_PATH to os.environ['PATH'] (a process-wide side effect). It skips
# that step if an equivalent entry is already present: first by exact string
# match, then by os.path.samefile. Presumably this lets the executable locate
# the torch DLLs.
2124
if IS_WINDOWS and TORCH_LIB_PATH not in os.getenv('PATH', '').split(';'):
2125
torch_lib_in_path = any(
2126
os.path.exists(p) and os.path.samefile(p, TORCH_LIB_PATH)
2127
for p in os.getenv('PATH', '').split(';')
2129
if not torch_lib_in_path:
2130
os.environ['PATH'] = f"{TORCH_LIB_PATH};{os.getenv('PATH', '')}"
2131
return os.path.join(path, f'{module_name}{EXEC_EXT}')
2134
def _import_module_from_library(module_name, path, is_python_module):
# Load the built library `path`/`module_name` + LIB_EXT: as a Python module
# when is_python_module is True, otherwise via torch.ops.load_library.
# NOTE(review): the `else:` and the return statements are not visible in this excerpt.
2135
filepath = os.path.join(path, f"{module_name}{LIB_EXT}")
2136
if is_python_module:
2137
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
2138
spec = importlib.util.spec_from_file_location(module_name, filepath)
# NOTE(review): these asserts are skipped under `python -O`; a None spec would
# then fail later with a less descriptive AttributeError.
2139
assert spec is not None
2140
module = importlib.util.module_from_spec(spec)
2141
assert isinstance(spec.loader, importlib.abc.Loader)
2142
spec.loader.exec_module(module)
2145
torch.ops.load_library(filepath)
2149
def _write_ninja_file_to_build_library(path,
2155
extra_include_paths,
2157
is_standalone) -> None:
# Build the host-compiler, CUDA/HIP and linker flag lists plus the object and
# target names for extension `name`, then hand them to the ninja-file writer
# (presumably _write_ninja_file, given the keyword names at the end).
# NOTE(review): this excerpt omits several lines (e.g. most parameters, the
# initialisation of `common_cflags`, platform `if`/`else` lines, the return
# of object_file_path and the head of the final call).
2158
extra_cflags = [flag.strip() for flag in extra_cflags]
2159
extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
2160
extra_ldflags = [flag.strip() for flag in extra_ldflags]
2161
extra_include_paths = [flag.strip() for flag in extra_include_paths]
2163
# Turn into absolute paths so we can emit them into the ninja build
2164
# file wherever it is.
2165
user_includes = [os.path.abspath(file) for file in extra_include_paths]
2167
# include_paths() gives us the location of torch/extension.h
2168
system_includes = include_paths(with_cuda)
2169
# sysconfig.get_path('include') gives us the location of Python.h
2170
# Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS
2171
# installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder
2172
python_include_path = sysconfig.get_path('include', scheme='nt' if IS_WINDOWS else 'posix_prefix')
2173
if python_include_path is not None:
2174
system_includes.append(python_include_path)
2177
if not is_standalone:
2178
common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}')
2179
common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')
2181
common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]
2183
# Windows does not understand `-isystem` and quotes flags later.
2185
common_cflags += [f'-I{include}' for include in user_includes + system_includes]
2187
common_cflags += [f'-I{shlex.quote(include)}' for include in user_includes]
2188
common_cflags += [f'-isystem {shlex.quote(include)}' for include in system_includes]
2190
common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]
2193
cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17'] + extra_cflags
2194
cflags = _nt_quote_args(cflags)
2196
cflags = common_cflags + ['-fPIC', '-std=c++17'] + extra_cflags
# GPU compiler flags: hipcc flags for ROCm builds; otherwise nvcc flags, with
# MSVC host options forwarded via -Xcompiler and cudafe warnings suppressed.
2198
if with_cuda and IS_HIP_EXTENSION:
2199
cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
2200
cuda_flags += extra_cuda_cflags
2201
cuda_flags += _get_rocm_arch_flags(cuda_flags)
2203
cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
2205
for flag in COMMON_MSVC_FLAGS:
2206
cuda_flags = ['-Xcompiler', flag] + cuda_flags
2207
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
2208
cuda_flags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cuda_flags
2209
cuda_flags = cuda_flags + ['-std=c++17']
2210
cuda_flags = _nt_quote_args(cuda_flags)
2211
cuda_flags += _nt_quote_args(extra_cuda_cflags)
2213
cuda_flags += ['--compiler-options', "'-fPIC'"]
2214
cuda_flags += extra_cuda_cflags
2215
if not any(flag.startswith('-std=') for flag in cuda_flags):
2216
cuda_flags.append('-std=c++17')
# When $CC is set, pass it to the GPU compiler as the host compiler via -ccbin.
2217
cc_env = os.getenv("CC")
2218
if cc_env is not None:
2219
cuda_flags = ['-ccbin', cc_env] + cuda_flags
2223
def object_file_path(source_file: str) -> str:
2224
# '/path/to/file.cpp' -> 'file'
2225
file_name = os.path.splitext(os.path.basename(source_file))[0]
2226
if _is_cuda_file(source_file) and with_cuda:
2227
# Use a different object filename in case a C++ and CUDA file have
2228
# the same filename but different extension (.cpp vs. .cu).
2229
target = f'{file_name}.cuda.o'
2231
target = f'{file_name}.o'
2234
objects = [object_file_path(src) for src in sources]
# SHARED_FLAG is only added when building a library (is_standalone is False).
2235
ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags
2237
# The darwin linker needs explicit consent to ignore unresolved symbols.
2239
ldflags.append('-undefined dynamic_lookup')
2241
ldflags = _nt_quote_args(ldflags)
# Output name: `name` + EXEC_EXT for standalone builds, otherwise + LIB_EXT.
2243
ext = EXEC_EXT if is_standalone else LIB_EXT
2244
library_target = f'{name}{ext}'
2250
cuda_cflags=cuda_flags,
2251
cuda_post_cflags=None,
2252
cuda_dlink_post_cflags=None,
2256
library_target=library_target,
2257
with_cuda=with_cuda)
2260
def _write_ninja_file(path,
2265
cuda_dlink_post_cflags,
# NOTE(review): some lines of this definition are not visible in this excerpt
# (e.g. most parameters, part of sanitize_flags, the initialisation of `build`,
# several platform `if`/`else` lines and the trailing-newline append).
2271
r"""Write a ninja file that does the desired compiling and linking.
2273
`path`: Where to write this file
2274
`cflags`: list of flags to pass to $cxx. Can be None.
2275
`post_cflags`: list of flags to append to the $cxx invocation. Can be None.
2276
`cuda_cflags`: list of flags to pass to $nvcc. Can be None.
2277
`cuda_post_cflags`: list of flags to append to the $nvcc invocation. Can be None.
`cuda_dlink_post_cflags`: list of flags to append to the $nvcc device-link invocation. Can be None; when non-empty, a device-link step producing dlink.o is added.
2278
`sources`: list of paths to source files
2279
`objects`: list of desired paths to objects, one per source.
2280
`ldflags`: list of flags to pass to linker. Can be None.
2281
`library_target`: Name of the output library. Can be None; in that case,
2283
`with_cuda`: If we should be compiling with CUDA.
2285
def sanitize_flags(flags):
2289
return [flag.strip() for flag in flags]
2291
cflags = sanitize_flags(cflags)
2292
post_cflags = sanitize_flags(post_cflags)
2293
cuda_cflags = sanitize_flags(cuda_cflags)
2294
cuda_post_cflags = sanitize_flags(cuda_post_cflags)
2295
cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)
2296
ldflags = sanitize_flags(ldflags)
2299
assert len(sources) == len(objects)
2300
assert len(sources) > 0
2302
compiler = get_cxx_compiler()
2304
# Version 1.3 is required for the `deps` directive.
2305
config = ['ninja_required_version = 1.3']
2306
config.append(f'cxx = {compiler}')
2307
if with_cuda or cuda_dlink_post_cflags:
2308
if "PYTORCH_NVCC" in os.environ:
2309
nvcc = os.getenv("PYTORCH_NVCC") # user can set nvcc compiler with ccache using the environment variable here
2311
if IS_HIP_EXTENSION:
2312
nvcc = _join_rocm_home('bin', 'hipcc')
2314
nvcc = _join_cuda_home('bin', 'nvcc')
2315
config.append(f'nvcc = {nvcc}')
2317
if IS_HIP_EXTENSION:
2318
post_cflags = COMMON_HIP_FLAGS + post_cflags
2319
flags = [f'cflags = {" ".join(cflags)}']
2320
flags.append(f'post_cflags = {" ".join(post_cflags)}')
2322
flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
2323
flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
2324
flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
2325
flags.append(f'ldflags = {" ".join(ldflags)}')
2327
# Turn into absolute paths so we can emit them into the ninja build
2328
# file wherever it is.
2329
sources = [os.path.abspath(file) for file in sources]
2331
# See https://ninja-build.org/build.ninja.html for reference.
2332
compile_rule = ['rule compile']
2334
compile_rule.append(
2335
' command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags')
2336
compile_rule.append(' deps = msvc')
2338
compile_rule.append(
2339
' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags')
2340
compile_rule.append(' depfile = $out.d')
2341
compile_rule.append(' deps = gcc')
2344
cuda_compile_rule = ['rule cuda_compile']
2346
# --generate-dependencies-with-compile is not supported by ROCm
2347
# Nvcc flag `--generate-dependencies-with-compile` is not supported by sccache, which may increase build time.
2348
if torch.version.cuda is not None and os.getenv('TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES', '0') != '1':
2349
cuda_compile_rule.append(' depfile = $out.d')
2350
cuda_compile_rule.append(' deps = gcc')
2351
# Note: non-system deps with nvcc are only supported
2352
# on Linux so use --generate-dependencies-with-compile
2353
# to make this work on Windows too.
2354
nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d'
2355
cuda_compile_rule.append(
2356
f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags')
2358
# Emit one build rule per source to enable incremental build.
2360
for source_file, object_file in zip(sources, objects):
2361
is_cuda_source = _is_cuda_file(source_file) and with_cuda
2362
rule = 'cuda_compile' if is_cuda_source else 'compile'
# Ninja requires ':' and ' ' in build-statement paths to be escaped with '$'.
2364
source_file = source_file.replace(':', '$:')
2365
object_file = object_file.replace(':', '$:')
2366
source_file = source_file.replace(" ", "$ ")
2367
object_file = object_file.replace(" ", "$ ")
2368
build.append(f'build {object_file}: {rule} {source_file}')
2370
if cuda_dlink_post_cflags:
2371
devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')
2372
devlink_rule = ['rule cuda_devlink']
2373
devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags')
2374
devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}']
# NOTE(review): `+=` extends the caller's `objects` list in place, so the
# caller also sees dlink.o appended.
2375
objects += [devlink_out]
2377
devlink_rule, devlink = [], []
2379
if library_target is not None:
2380
link_rule = ['rule link']
2382
cl_paths = subprocess.check_output(['where',
2383
'cl']).decode(*SUBPROCESS_DECODE_ARGS).split('\r\n')
# NOTE(review): str.split always returns at least one element, so this check
# is always true and the RuntimeError below is unreachable through it. A failed
# `where cl` raises CalledProcessError from check_output instead.
2384
if len(cl_paths) >= 1:
2385
cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')
2387
raise RuntimeError("MSVC is required to load C++ extensions")
2388
link_rule.append(f' command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out')
2390
link_rule.append(' command = $cxx $in $ldflags -o $out')
2392
link = [f'build {library_target}: link {" ".join(objects)}']
2394
default = [f'default {library_target}']
2396
link_rule, link, default = [], [], []
2398
# 'Blocks' should be separated by newlines, for visual benefit.
2399
blocks = [config, flags, compile_rule]
2401
blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined]
2402
blocks += [devlink_rule, link_rule, build, devlink, link, default]
2403
content = "\n\n".join("\n".join(b) for b in blocks)
2404
# Ninja requires a newline at the end of the .ninja file
2406
_maybe_write(path, content)
2408
def _join_cuda_home(*paths) -> str:
    """
    Join ``paths`` onto CUDA_HOME, raising an error if CUDA_HOME is not set.

    This is a lazy way of reporting a missing $CUDA_HOME: the error is raised
    only once a CUDA-specific path is actually needed.

    Raises:
        OSError: if CUDA_HOME is None.
    """
    if CUDA_HOME is None:
        raise OSError('CUDA_HOME environment variable is not set. '
                      'Please set it to your CUDA install root.')
    return os.path.join(CUDA_HOME, *paths)
2421
def _is_cuda_file(path: str) -> bool:
    """Tell whether ``path`` has a CUDA file extension.

    Recognised suffixes are '.cu' and '.cuh', plus '.hip' when building a HIP
    extension.
    """
    cuda_suffixes = ('.cu', '.cuh')
    if IS_HIP_EXTENSION:
        cuda_suffixes += ('.hip',)
    suffix = os.path.splitext(path)[1]
    return suffix in cuda_suffixes