BasicSR
/
setup.py
172 строки · 5.4 Кб
1#!/usr/bin/env python
2
3from setuptools import find_packages, setup4
5import os6import subprocess7import time8
# Path of the generated version module, relative to the current working
# directory; written by write_version_py() and read back by get_version().
version_file = 'basicsr/version.py'
11
def readme():
    """Return the complete contents of ``README.md`` decoded as UTF-8.

    The path is resolved against the current working directory; the text is
    used as the package's long description.
    """
    with open('README.md', encoding='utf-8') as readme_file:
        return readme_file.read()
17
def get_git_hash():
    """Return the full hash of the git commit checked out in the current directory.

    ``git rev-parse HEAD`` is run with a minimal environment (only
    ``SYSTEMROOT``, ``PATH`` and ``HOME`` are passed through) and the C locale.

    Returns:
        str: The commit hash, or ``'unknown'`` if git cannot be executed,
        exits with a non-zero status (e.g. outside a repository), prints
        non-ASCII output, or prints nothing.
    """

    def _minimal_ext_cmd(cmd):
        # construct minimal environment
        env = {}
        for k in ['SYSTEMROOT', 'PATH', 'HOME']:
            v = os.environ.get(k)
            if v is not None:
                env[k] = v
        # LANGUAGE is used on win32
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        # check=True turns a failing git (non-zero exit status) into an
        # exception instead of silently returning whatever it printed.
        return subprocess.run(cmd, stdout=subprocess.PIPE, env=env, check=True).stdout

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        sha = out.strip().decode('ascii')
    except (OSError, subprocess.CalledProcessError, UnicodeDecodeError):
        sha = ''
    # Report every failure mode with the same sentinel rather than ''.
    return sha or 'unknown'
42
def get_hash():
    """Return the 7-character abbreviated git hash, or ``'unknown'``.

    A hash is only looked up when a ``.git`` entry exists in the current
    working directory; otherwise ``'unknown'`` is returned immediately.
    Recovering the hash from a previously generated version file is
    currently not attempted.
    """
    if not os.path.exists('.git'):
        return 'unknown'
    return get_git_hash()[:7]
58
def write_version_py():
    """Generate ``version_file`` from the ``VERSION`` file and the git hash.

    The generated module defines ``__version__`` (the stripped contents of
    ``VERSION``), ``__gitsha__`` (from ``get_hash()``) and ``version_info``,
    a tuple in which all-digit components are ints and any other component
    (e.g. ``4rc1``) is a double-quoted string. Both files are resolved against
    the current working directory.
    """
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r', encoding='utf-8') as f:
        short_version = f.read().strip()
    components = short_version.split('.')
    version_info = ', '.join(x if x.isdigit() else f'"{x}"' for x in components)
    if len(components) == 1:
        # Without a trailing comma "(2)" would be a parenthesised int, not a tuple.
        version_info += ','

    version_file_str = content.format(time.asctime(), short_version, sha, version_info)
    with open(version_file, 'w', encoding='utf-8') as f:
        f.write(version_file_str)
75
def get_version():
    """Return the ``__version__`` string defined in the generated ``version_file``.

    The file is executed into a dedicated namespace dict. Reading the result
    back through the function's ``locals()`` is not reliable: since Python
    3.13 (PEP 667) each ``locals()`` call inside a function returns a fresh
    snapshot, so a name bound by ``exec`` would not be visible afterwards.

    Raises:
        KeyError: if the executed file does not define ``__version__``.
    """
    namespace = {}
    with open(version_file, 'r', encoding='utf-8') as f:
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__version__']
81
def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Create a torch C++ (or C++/CUDA) extension named ``{module}.{name}``.

    A ``CUDAExtension`` is returned when ``torch.cuda.is_available()`` is true
    or the environment variable ``FORCE_CUDA`` equals ``'1'``. In that case
    ``WITH_CUDA`` is defined, the ``__CUDA_NO_HALF*`` defines are passed to
    nvcc, and ``sources_cuda`` is compiled in addition to ``sources``.
    Otherwise a ``CppExtension`` built from ``sources`` alone is returned.

    Args:
        name (str): Extension name, appended to ``module``.
        module (str): Dotted package path; its components also form the
            directory that every source path is joined onto.
        sources (list[str]): Sources compiled in every build. Not modified.
        sources_cuda (list[str] | None): Extra sources for the CUDA build only.

    Note:
        Relies on ``torch``, ``CUDAExtension`` and ``CppExtension`` being bound
        at module level (they are imported in the ``__main__`` section) before
        this function is called.
    """
    if sources_cuda is None:
        sources_cuda = []
    define_macros = []
    extra_compile_args = {'cxx': []}

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        # Build a new list: ``sources += ...`` would extend the caller's list in place.
        sources = [*sources, *sources_cuda]
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
107
def get_requirements(filename='requirements.txt'):
    """Return the requirement specifiers listed in *filename*.

    A relative ``filename`` is resolved against the directory containing this
    script (not the current working directory); an absolute path is used
    as-is. Each line is stripped of surrounding whitespace, and blank lines
    and full-line ``#`` comments are skipped.

    Args:
        filename (str): Requirements file name or path.

    Returns:
        list[str]: One entry per requirement line, in file order.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(here, filename), 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f]
    return [line for line in lines if line and not line.startswith('#')]
114
if __name__ == '__main__':
    # The C++/CUDA ops are only built when BASICSR_EXT is exactly 'True'.
    cuda_ext = os.getenv('BASICSR_EXT')  # whether compile cuda ext
    if cuda_ext == 'True':
        try:
            # Imported at module level so make_cuda_ext() can reference these names.
            import torch
            from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
        except ImportError as err:
            # Chain the original failure so it is reported as the direct cause.
            raise ImportError('Unable to import torch - torch is needed to build cuda extensions') from err

        ext_modules = [
            make_cuda_ext(
                name='deform_conv_ext',
                module='basicsr.ops.dcn',
                sources=['src/deform_conv_ext.cpp'],
                sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
            make_cuda_ext(
                name='fused_act_ext',
                module='basicsr.ops.fused_act',
                sources=['src/fused_bias_act.cpp'],
                sources_cuda=['src/fused_bias_act_kernel.cu']),
            make_cuda_ext(
                name='upfirdn2d_ext',
                module='basicsr.ops.upfirdn2d',
                sources=['src/upfirdn2d.cpp'],
                sources_cuda=['src/upfirdn2d_kernel.cu']),
        ]
        # Register torch's BuildExtension as the build_ext command.
        setup_kwargs = dict(cmdclass={'build_ext': BuildExtension})
    else:
        ext_modules = []
        setup_kwargs = dict()

    # Regenerate the version module first; get_version() reads it back below.
    write_version_py()
    setup(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        include_package_data=True,
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='Apache License 2.0',
        setup_requires=['cython', 'numpy', 'torch'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        zip_safe=False,
        **setup_kwargs)