StyleFeatureEditor

Форк
0
149 строк · 5.9 Кб
1
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
2
#
3
# NVIDIA CORPORATION and its licensors retain all intellectual property
4
# and proprietary rights in and to this software, related documentation
5
# and any modifications thereto.  Any use, reproduction, disclosure or
6
# distribution of this software and related documentation without an express
7
# license agreement from NVIDIA CORPORATION is strictly prohibited.
8

9
import glob
10
import hashlib
11
import importlib
12
import os
13
import shutil
14
from pathlib import Path
15

16
import torch
17
import torch.utils.cpp_extension
18
from torch.utils.file_baton import FileBaton
19

20

21
# ----------------------------------------------------------------------------
22
# Global options.
23

24
# How much get_plugin() prints: 'none' (nothing), 'brief' (a single status
# line ending in "Done." or "Failed!"), or 'full' (start/done messages plus
# a verbose extension build).
verbosity = "brief"
25

26
# ----------------------------------------------------------------------------
27
# Internal helper funcs.
28

29

30
def _find_compiler_bindir():
31
    patterns = [
32
        "C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64",
33
        "C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64",
34
        "C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64",
35
        "C:/Program Files (x86)/Microsoft Visual Studio */vc/bin",
36
    ]
37
    for pattern in patterns:
38
        matches = sorted(glob.glob(pattern))
39
        if len(matches):
40
            return matches[-1]
41
    return None
42

43

44
# ----------------------------------------------------------------------------
45
# Main entry point for compiling and loading C++/CUDA plugins.
46

47
_cached_plugins = dict()
48

49

50
def _hash_source_files(source_files):
    """Return the hex MD5 digest over the base names and contents of *source_files*.

    Base names are hashed together with contents so that renaming a file
    (without changing its contents) still produces a new digest directory that
    contains the new file name.
    """
    hash_md5 = hashlib.md5()
    for src in source_files:
        hash_md5.update(os.path.basename(src).encode("utf-8") + b"\0")
        with open(src, "rb") as f:
            hash_md5.update(f.read())
    return hash_md5.hexdigest()


def _populate_digest_dir(source_files, digest_build_dir, build_dir):
    """Copy *source_files* into *digest_build_dir* unless that directory already exists.

    Files are first copied into a uniquely named temporary directory under
    *build_dir*, which is then renamed into place with ``os.replace``. Because
    both directories live under *build_dir*, the rename is atomic on POSIX, so
    a concurrent caller never sees a half-populated digest directory, and an
    interrupted copy never leaves one behind.
    """
    import uuid  # local import: only needed on this cache-population path

    if os.path.isdir(digest_build_dir):
        return
    tmp_dir = os.path.join(build_dir, f"srctmp-{uuid.uuid4().hex}")
    os.makedirs(tmp_dir)
    try:
        for src in source_files:
            shutil.copyfile(src, os.path.join(tmp_dir, os.path.basename(src)))
    except BaseException:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    try:
        os.replace(tmp_dir, digest_build_dir)
    except OSError:
        # Most likely another process renamed its own copy into place first
        # (same digest => same contents); discard ours and use theirs.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        if not os.path.isdir(digest_build_dir):
            raise


def get_plugin(module_name, sources, **build_kwargs):
    """Compile (if needed), load, and cache a PyTorch C++/CUDA extension.

    Args:
        module_name: Name of the extension module; also the cache key.
        sources: Paths of the source files to compile.
        **build_kwargs: Extra keyword arguments forwarded to
            ``torch.utils.cpp_extension.load``.

    Returns:
        The imported extension module. Repeated calls with the same
        ``module_name`` return the cached object without rebuilding.

    Raises:
        RuntimeError: On Windows, if ``cl.exe`` is not on PATH and no MSVC
            installation can be located.
        Any exception raised while building or importing the extension is
        re-raised (after printing "Failed!" in 'brief' verbosity).
    """
    assert verbosity in ["none", "brief", "full"]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == "full":
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == "brief":
        print(f'Setting up PyTorch plugin "{module_name}"... ', end="", flush=True)

    try:
        # Make sure we can find the necessary compiler binaries.
        if os.name == "nt" and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(
                    f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".'
                )
            os.environ["PATH"] += ";" + compiler_bindir

        # Compile and load.
        verbose_build = verbosity == "full"

        # Incremental build md5sum trickery.  Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files (names and contents).  Copying is done only if that
        # digest directory does not exist yet, and it is published atomically.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = {os.path.dirname(source) for source in sources}
        if len(source_dirs_set) == 1 and "TORCH_EXTENSIONS_DIR" in os.environ:
            (source_dir,) = source_dirs_set
            # Hash every file in the custom op directory (usually .cu, .cpp,
            # .py and .h files), not just `sources`, so header edits count too.
            all_source_files = sorted(
                x for x in Path(source_dir).iterdir() if x.is_file()
            )
            digest = _hash_source_files(all_source_files)
            build_dir = torch.utils.cpp_extension._get_build_directory(  # pylint: disable=protected-access
                module_name, verbose=verbose_build
            )
            digest_build_dir = os.path.join(build_dir, digest)
            _populate_digest_dir(all_source_files, digest_build_dir, build_dir)
            digest_sources = [
                os.path.join(digest_build_dir, os.path.basename(x)) for x in sources
            ]
            torch.utils.cpp_extension.load(
                name=module_name,
                build_directory=build_dir,
                verbose=verbose_build,
                sources=digest_sources,
                **build_kwargs,
            )
        else:
            torch.utils.cpp_extension.load(
                name=module_name, verbose=verbose_build, sources=sources, **build_kwargs
            )
        module = importlib.import_module(module_name)

    except Exception:
        if verbosity == "brief":
            print("Failed!")
        raise

    # Print status and add to cache.
    if verbosity == "full":
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == "brief":
        print("Done.")
    _cached_plugins[module_name] = module
    return module
147

148

149
# ----------------------------------------------------------------------------
150

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.