# StyleFeatureEditor
# 149 lines · 5.9 KB
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import glob
import hashlib
import importlib
import os
import shutil
import tempfile
from pathlib import Path

import torch
import torch.utils.cpp_extension
from torch.utils.file_baton import FileBaton

# ----------------------------------------------------------------------------
# Global options.

# How much get_plugin() prints while setting up a plugin.
# Allowed values: "none", "brief", "full".
verbosity = "brief"
# ----------------------------------------------------------------------------
# Internal helper functions.


30def _find_compiler_bindir():31patterns = [32"C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64",33"C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64",34"C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64",35"C:/Program Files (x86)/Microsoft Visual Studio */vc/bin",36]37for pattern in patterns:38matches = sorted(glob.glob(pattern))39if len(matches):40return matches[-1]41return None42

# ----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

# In-process cache of loaded extension modules, keyed by module name.
# get_plugin() stores each module here after a successful build and import.
_cached_plugins = {}
49
def _ensure_msvc_on_path():
    """On Windows, make sure ``cl.exe`` is reachable, extending ``PATH`` if needed.

    Does nothing on non-Windows platforms, or when ``where cl.exe`` already
    succeeds.

    Raises:
        RuntimeError: If ``cl.exe`` is not on ``PATH`` and
            ``_find_compiler_bindir()`` finds no MSVC installation.
    """
    # Short-circuit keeps os.system() from running on non-Windows platforms.
    if os.name != "nt" or os.system("where cl.exe >nul 2>nul") == 0:
        return
    compiler_bindir = _find_compiler_bindir()
    if compiler_bindir is None:
        raise RuntimeError(
            f'Could not find MSVC installation on this computer. Check _find_compiler_bindir() in "{__file__}".'
        )
    os.environ["PATH"] += os.pathsep + compiler_bindir


def _compute_sources_digest(source_files):
    """Return an md5 hex digest over the base names and contents of ``source_files``."""
    hash_md5 = hashlib.md5()
    for src in source_files:
        # Hash the file name too. Otherwise adding an empty file, or renaming
        # or swapping files, leaves the digest unchanged and stale copies
        # are reused.
        hash_md5.update(os.path.basename(src).encode("utf-8"))
        with open(src, "rb") as f:
            hash_md5.update(f.read())
    return hash_md5.hexdigest()


def _populate_digest_dir(source_files, build_dir, digest_build_dir):
    """Copy ``source_files`` into ``digest_build_dir``, making it appear atomically.

    Files are first copied into a fresh temporary directory under ``build_dir``.
    That directory is then renamed into place with ``os.replace``. Concurrent
    processes therefore never observe a half-populated digest directory. A
    process killed mid-copy leaves only a stray temporary directory, not an
    incomplete cache entry that later runs would trust.

    Raises:
        OSError: If copying fails and no other process has produced
            ``digest_build_dir`` in the meantime.
    """
    if os.path.isdir(digest_build_dir):
        return
    os.makedirs(build_dir, exist_ok=True)
    tmp_dir = tempfile.mkdtemp(prefix="srctmp-", dir=build_dir)
    try:
        for src in source_files:
            shutil.copyfile(src, os.path.join(tmp_dir, os.path.basename(src)))
        os.replace(tmp_dir, digest_build_dir)
    except OSError:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        # Losing the rename race to another process is fine because its
        # directory is complete. Any other failure is a real error.
        if not os.path.isdir(digest_build_dir):
            raise


def get_plugin(module_name, sources, **build_kwargs):
    """Build (if needed), import and cache a PyTorch C++/CUDA extension.

    Args:
        module_name: Extension module name. It is passed to
            ``torch.utils.cpp_extension.load`` and ``importlib.import_module``,
            and used as the key into the in-process cache.
        sources: Paths of the source files to compile.
        **build_kwargs: Extra keyword arguments forwarded to
            ``torch.utils.cpp_extension.load``.

    Returns:
        The imported extension module. Later calls with the same
        ``module_name`` return the cached module without rebuilding.

    Raises:
        AssertionError: If the module-level ``verbosity`` is not one of
            "none", "brief" or "full".
        RuntimeError: On Windows, if no MSVC compiler can be located.
        Exception: Any error raised while building or importing is re-raised
            (after printing "Failed!" in "brief" mode).
    """
    assert verbosity in ["none", "brief", "full"]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == "full":
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == "brief":
        print(f'Setting up PyTorch plugin "{module_name}"... ', end="", flush=True)

    try:
        # Make sure we can find the necessary compiler binaries.
        _ensure_msvc_on_path()
        verbose_build = verbosity == "full"

        # Incremental build md5sum trickery. All files in the source directory
        # are copied into a cached build directory named after a combined md5
        # digest of their names and contents, and copying happens only when
        # that digest changes. This keeps input file timestamps and filenames
        # the same as in previous extension builds, allowing fast incremental
        # rebuilds.
        #
        # This is done only when all source files reside in a single directory
        # (for simplicity) and TORCH_EXTENSIONS_DIR is set (taken as a signal
        # that the user actually cares about this).
        source_dirs = {os.path.dirname(source) for source in sources}
        if len(source_dirs) == 1 and "TORCH_EXTENSIONS_DIR" in os.environ:
            source_dir = next(iter(source_dirs))
            all_source_files = sorted(x for x in Path(source_dir).iterdir() if x.is_file())
            digest = _compute_sources_digest(all_source_files)
            build_dir = torch.utils.cpp_extension._get_build_directory(  # pylint: disable=protected-access
                module_name, verbose=verbose_build
            )
            digest_build_dir = os.path.join(build_dir, digest)
            _populate_digest_dir(all_source_files, build_dir, digest_build_dir)
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(
                name=module_name,
                build_directory=build_dir,
                verbose=verbose_build,
                sources=digest_sources,
                **build_kwargs,
            )
        else:
            torch.utils.cpp_extension.load(
                name=module_name, verbose=verbose_build, sources=sources, **build_kwargs
            )
        module = importlib.import_module(module_name)

    except Exception:
        if verbosity == "brief":
            print("Failed!")
        raise

    # Print status and add to cache.
    if verbosity == "full":
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == "brief":
        print("Done.")
    _cached_plugins[module_name] = module
    return module

# ----------------------------------------------------------------------------
