StyleFeatureEditor

Форк
0
169 строк · 7.1 Кб
1
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
#
3
# This work is made available under the Nvidia Source Code License-NC.
4
# To view a copy of this license, visit
5
# https://nvlabs.github.io/stylegan2/license.html
6

7
"""TensorFlow custom ops builder.
8
"""
9

10
import os
11
import re
12
import uuid
13
import hashlib
14
import tempfile
15
import shutil
16
import tensorflow as tf
17
from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
18

19
#----------------------------------------------------------------------------
20
# Global options.
21

22
# Directory (next to this file) where compiled CUDA binaries are cached.
cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache')
# Bump this tag to invalidate every previously cached binary.
cuda_cache_version_tag = 'v1'
do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
verbose = True # Print status messages to stdout.

# Candidate MSVC host-compiler locations probed on Windows; on other
# platforms nvcc's default host compiler is used when none of these exist.
compiler_bindir_search_path = [
    'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64',
    'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64',
    'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin',
]
32

33
#----------------------------------------------------------------------------
34
# Internal helper funcs.
35

36
def _find_compiler_bindir():
    """Return the first existing directory from compiler_bindir_search_path, or None."""
    return next((path for path in compiler_bindir_search_path if os.path.isdir(path)), None)
41

42
def _get_compute_cap(device):
43
    caps_str = device.physical_device_desc
44
    m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
45
    major = m.group(1)
46
    minor = m.group(2)
47
    return (major, minor)
48

49
def _get_cuda_gpu_arch_string():
    """Return an nvcc --gpu-architecture value (e.g. 'sm_75') for the first local GPU."""
    gpus = [dev for dev in device_lib.list_local_devices() if dev.device_type == 'GPU']
    if not gpus:
        raise RuntimeError('No GPU devices found')
    major, minor = _get_compute_cap(gpus[0])
    return 'sm_%s%s' % (major, minor)
55

56
def _run_cmd(cmd):
57
    with os.popen(cmd) as pipe:
58
        output = pipe.read()
59
        status = pipe.close()
60
    if status is not None:
61
        raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
62

63
def _prepare_nvcc_cli(opts):
    """Assemble the full nvcc command line (TF include paths, host compiler, stderr capture) for the given options string."""
    tf_include = tf.sysconfig.get_include()
    include_dirs = [
        tf_include,
        os.path.join(tf_include, 'external', 'protobuf_archive', 'src'),
        os.path.join(tf_include, 'external', 'com_google_absl'),
        os.path.join(tf_include, 'external', 'eigen_archive'),
    ]
    parts = ['nvcc ' + opts.strip(), '--disable-warnings']
    parts += ['--include-path "%s"' % path for path in include_dirs]

    compiler_bindir = _find_compiler_bindir()
    if compiler_bindir is not None:
        parts.append('--compiler-bindir "%s"' % compiler_bindir)
    elif os.name == 'nt':
        # An explicit host compiler is mandatory on Windows; on Linux nvcc
        # falls back to whatever its default is.
        raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
    parts.append('2>&1') # fold stderr into stdout so _run_cmd captures everything
    return ' '.join(parts)
81

82
#----------------------------------------------------------------------------
83
# Main entry point.
84

85
# Per-process cache: maps the cuda_file path to its loaded op-library module.
_plugin_cache = dict()

def get_plugin(cuda_file):
    """Compile, cache, and load the TensorFlow op library built from cuda_file.

    The binary is stored in cuda_cache_path under a name derived from an MD5
    hash of the CUDA source (and, unless do_not_hash_included_headers is set,
    its preprocessed includes) plus the nvcc command line, tf.VERSION, and
    cuda_cache_version_tag — so recompilation happens only when any of those
    inputs change.

    Args:
        cuda_file: Path to the .cu source file implementing the custom ops.

    Returns:
        The module object returned by tf.load_op_library for the compiled
        binary (also memoized in _plugin_cache keyed by cuda_file).
    """
    cuda_file_base = os.path.basename(cuda_file)
    cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)

    # Already in cache?
    if cuda_file in _plugin_cache:
        return _plugin_cache[cuda_file]

    # Setup plugin.
    if verbose:
        print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
    try:
        # Hash CUDA source.
        md5 = hashlib.md5()
        with open(cuda_file, 'rb') as f:
            md5.update(f.read())
        md5.update(b'\n')

        # Hash headers included by the CUDA code by running it through the preprocessor.
        if not do_not_hash_included_headers:
            if verbose:
                print('Preprocessing... ', end='', flush=True)
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
                _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
                with open(tmp_file, 'rb') as f:
                    # Normalize absolute source paths embedded by __FILE__ back
                    # to the bare file name, so the hash is machine-independent.
                    bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
                    good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
                    for ln in f:
                        if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
                            ln = ln.replace(bad_file_str, good_file_str)
                            md5.update(ln)
                    md5.update(b'\n')

        # Select compiler options.
        compile_opts = ''
        if os.name == 'nt':
            compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
        elif os.name == 'posix':
            compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so')
            compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\''
        else:
            assert False # not Windows or Linux, w00t?
        compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string()
        compile_opts += ' --use_fast_math'
        nvcc_cmd = _prepare_nvcc_cli(compile_opts)

        # Hash build configuration, so a toolchain/TF upgrade forces a rebuild.
        md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
        md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
        md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')

        # Compile if not already compiled.
        bin_file_ext = '.dll' if os.name == 'nt' else '.so'
        bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
        if not os.path.isfile(bin_file):
            if verbose:
                print('Compiling... ', end='', flush=True)
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
                _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
                os.makedirs(cuda_cache_path, exist_ok=True)
                # Copy to a uniquely named file first, then rename into place:
                # os.rename is atomic, so concurrent processes never observe a
                # partially written binary at bin_file.
                intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
                shutil.copyfile(tmp_file, intermediate_file)
                os.rename(intermediate_file, bin_file) # atomic

        # Load.
        if verbose:
            print('Loading... ', end='', flush=True)
        plugin = tf.load_op_library(bin_file)

        # Add to cache.
        _plugin_cache[cuda_file] = plugin
        if verbose:
            print('Done.', flush=True)
        return plugin

    # NOTE: bare except is deliberate here — it reports failure for ANY
    # exception (including KeyboardInterrupt) and then re-raises unchanged.
    except:
        if verbose:
            print('Failed!', flush=True)
        raise
168

169
#----------------------------------------------------------------------------
170

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.