BasicSR

Форк
0
/
setup.py 
172 строки · 5.4 Кб
1
#!/usr/bin/env python
2

3
from setuptools import find_packages, setup
4

5
import os
6
import subprocess
7
import time
8

9
# Generated module that records the package version: write_version_py()
# writes it and get_version() reads it back. The path is relative to the
# current working directory.
version_file = 'basicsr/version.py'


def readme(filename='README.md'):
    """Return the README text used as the package ``long_description``.

    The path is resolved relative to the directory containing this setup
    script (the same way ``get_requirements`` locates its file), so the
    result does not depend on the current working directory.

    Args:
        filename (str): README path relative to the setup script directory.
            An absolute path is used as-is. Default: 'README.md'.

    Returns:
        str: The file contents decoded as UTF-8.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    # os.path.join discards `here` when `filename` is already absolute.
    with open(os.path.join(here, filename), encoding='utf-8') as f:
        return f.read()


def get_git_hash():
    """Return the full SHA of the current git ``HEAD``.

    Runs ``git rev-parse HEAD`` in the current working directory with a
    minimal, C-locale environment.

    Returns:
        str: The commit hash, or ``'unknown'`` when git is not installed,
        the command exits with an error (e.g. not inside a repository, or a
        repository with no commits), or it prints nothing.
    """

    def _minimal_ext_cmd(cmd):
        # Construct a minimal environment: keep only the variables needed to
        # locate and run git, and force the C locale.
        env = {k: os.environ[k] for k in ('SYSTEMROOT', 'PATH', 'HOME') if k in os.environ}
        # LANGUAGE is used on win32
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        # check=True turns a non-zero exit into CalledProcessError instead of
        # silently yielding empty (or misleading) stdout; stderr is discarded
        # so git's error messages do not leak into the build output.
        return subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, env=env, check=True).stdout

    try:
        sha = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']).strip().decode('ascii')
    except (OSError, subprocess.CalledProcessError):
        # OSError: git executable not found; CalledProcessError: git failed.
        sha = ''
    return sha or 'unknown'


def get_hash():
    """Return the short (first 7 characters) git SHA of ``HEAD``.

    Only a ``.git`` entry in the current working directory is consulted;
    without one the result is ``'unknown'``. Falling back to the SHA stored
    in the generated version file is currently disabled.
    """
    if not os.path.exists('.git'):
        return 'unknown'
    return get_git_hash()[:7]


def write_version_py():
    """Generate ``version_file`` from the ``VERSION`` file and the git SHA.

    Reads the dotted version string from ``VERSION`` (in the current working
    directory) and writes a Python module defining ``__version__``,
    ``__gitsha__`` and the tuple ``version_info``, stamped with the local
    generation time. Numeric components of ``version_info`` are emitted as
    ints and the rest as quoted strings, e.g. ``1.3.0rc1`` ->
    ``(1, 3, "0rc1")``.
    """
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r', encoding='utf-8') as f:
        SHORT_VERSION = f.read().strip()
    parts = [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]
    VERSION_INFO = ', '.join(parts)
    if len(parts) == 1:
        # "(1)" would be the int 1; the trailing comma keeps it a 1-tuple.
        VERSION_INFO += ','

    version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
    with open(version_file, 'w', encoding='utf-8') as f:
        f.write(version_file_str)


def get_version(filename=None):
    """Return ``__version__`` from the generated version module.

    Args:
        filename (str | None): Path of the version module to read. Defaults
            to the module-level ``version_file``.

    Returns:
        str: The ``__version__`` value defined in that file.
    """
    if filename is None:
        filename = version_file
    # Execute into an explicit namespace. On Python 3.13+ (PEP 667), names
    # bound by exec() in a function's implicit locals are not visible via a
    # later locals() call, so reading them back that way raises KeyError.
    # The file is produced by write_version_py(), not untrusted input.
    namespace = {}
    with open(filename, 'r', encoding='utf-8') as f:
        exec(compile(f.read(), filename, 'exec'), namespace)
    return namespace['__version__']


def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Create a torch C++/CUDA extension named ``{module}.{name}``.

    A ``CUDAExtension`` (with the ``WITH_CUDA`` macro defined and
    ``sources_cuda`` appended) is built when torch reports CUDA as available
    or the environment variable ``FORCE_CUDA`` is ``'1'``; otherwise a
    CPU-only ``CppExtension`` is built from ``sources`` alone.

    Args:
        name (str): Extension name (last component of its import path).
        module (str): Dotted package path; source paths are resolved
            relative to the matching directory.
        sources (list[str]): Sources always compiled. Not modified.
        sources_cuda (list[str] | None): Extra sources compiled only for the
            CUDA build.

    Returns:
        The constructed ``CUDAExtension`` or ``CppExtension``.
    """
    # Imported locally so this helper does not depend on names that are
    # otherwise only bound inside the `__main__` block.
    import torch
    from torch.utils.cpp_extension import CppExtension, CUDAExtension

    if sources_cuda is None:
        sources_cuda = []
    # Work on a copy: extending in place would mutate the caller's list.
    sources = list(sources)
    define_macros = []
    extra_compile_args = {'cxx': []}

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        sources += list(sources_cuda)
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)


def get_requirements(filename='requirements.txt'):
    """Read requirement specifiers from a pip-style requirements file.

    The path is resolved relative to the directory containing this setup
    script (an absolute ``filename`` is used as-is). Each line is stripped of
    surrounding whitespace; blank lines and full-line ``#`` comments are
    skipped so they are not passed on as requirements.

    Args:
        filename (str): Requirements file name. Default: 'requirements.txt'.

    Returns:
        list[str]: The requirement lines, in file order.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(here, filename), 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line and not line.startswith('#')]


if __name__ == '__main__':
    # Build entry point. Compiled extensions are only configured when the
    # environment variable BASICSR_EXT is exactly 'True'.
    cuda_ext = os.getenv('BASICSR_EXT')  # whether compile cuda ext
    if cuda_ext == 'True':
        try:
            # These become module globals, which make_cuda_ext() can use.
            import torch
            from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
        except ImportError as err:
            # Chain the original error so its underlying cause (e.g. a broken
            # shared library rather than a missing package) is preserved.
            raise ImportError('Unable to import torch - torch is needed to build cuda extensions') from err

        ext_modules = [
            make_cuda_ext(
                name='deform_conv_ext',
                module='basicsr.ops.dcn',
                sources=['src/deform_conv_ext.cpp'],
                sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
            make_cuda_ext(
                name='fused_act_ext',
                module='basicsr.ops.fused_act',
                sources=['src/fused_bias_act.cpp'],
                sources_cuda=['src/fused_bias_act_kernel.cu']),
            make_cuda_ext(
                name='upfirdn2d_ext',
                module='basicsr.ops.upfirdn2d',
                sources=['src/upfirdn2d.cpp'],
                sources_cuda=['src/upfirdn2d_kernel.cu']),
        ]
        # torch's BuildExtension drives compilation of the extensions above.
        setup_kwargs = dict(cmdclass={'build_ext': BuildExtension})
    else:
        ext_modules = []
        setup_kwargs = dict()

    # Regenerate the version module before setup() reads the version back.
    write_version_py()
    setup(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        include_package_data=True,
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='Apache License 2.0',
        # NOTE(review): setup_requires is deprecated by setuptools; consider
        # declaring build dependencies in pyproject.toml [build-system].
        setup_requires=['cython', 'numpy', 'torch'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        zip_safe=False,
        **setup_kwargs)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.