# BasicSR
#
# Fork: 0
# / options.py
# 218 lines · 6.8 KB
1
import argparse
2
import os
3
import random
4
import torch
5
import yaml
6
from collections import OrderedDict
7
from os import path as osp
8

9
from basicsr.utils import set_random_seed
10
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
11

12

13
def ordered_yaml():
    """Support OrderedDict for yaml.

    Registers two handlers on PyYAML's Loader/Dumper classes (the C-accelerated
    ``CLoader``/``CDumper`` when importable, otherwise the pure-Python ones):

    * a mapping constructor, so YAML mappings load as ``OrderedDict`` and keep
      their key order;
    * a representer, so ``OrderedDict`` dumps as a plain YAML mapping.

    The handlers are registered on the classes themselves (not on instances),
    so every later use of these classes sees them; calling this function again
    just re-registers the same handlers.

    Returns:
        tuple: yaml Loader and Dumper.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def dict_representer(dumper, data):
        return dumper.represent_dict(data.items())

    def dict_constructor(loader, node):
        # Resolve YAML merge keys (``<<: *anchor``) before collecting the pairs.
        # construct_pairs alone does not do this, unlike PyYAML's default
        # mapping constructor, so merges would otherwise fail to load.
        loader.flatten_mapping(node)
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, dict_representer)
    Loader.add_constructor(_mapping_tag, dict_constructor)
    return Loader, Dumper
36

37

38
def yaml_load(f):
    """Load yaml file or string.

    If ``f`` names an existing file, that file is read as UTF-8 and parsed.
    Otherwise ``f`` itself is parsed as YAML text. Mappings are loaded as
    ``OrderedDict`` via the loader from ``ordered_yaml()``.

    NOTE(review): the loader is PyYAML's full (non-safe) ``Loader``/``CLoader``,
    which can construct arbitrary Python objects. Only load trusted option files.

    Args:
        f (str): File path or a python string.

    Returns:
        dict: Loaded dict.
    """
    loader = ordered_yaml()[0]
    if os.path.isfile(f):
        # Explicit UTF-8: YAML files are UTF-8, and the platform default
        # encoding (e.g. cp1252 on Windows) would mis-decode non-ASCII text.
        with open(f, 'r', encoding='utf-8') as stream:
            return yaml.load(stream, Loader=loader)
    # Not an existing file path: treat the argument as YAML source text.
    return yaml.load(f, Loader=loader)
52

53

54
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Each entry is printed on its own line as ``key: value``, indented by two
    spaces per indent level. A nested dict is printed as ``key:[``, followed by
    its entries one level deeper, followed by a closing ``]`` line.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing. It always starts with a newline.
    """
    indent = ' ' * (indent_level * 2)
    parts = ['\n']
    for k, v in opt.items():
        # f-strings stringify the key: YAML can yield non-str keys (e.g. ints),
        # for which plain ``+`` concatenation would raise TypeError.
        if isinstance(v, dict):
            parts.append(f'{indent}{k}:[')
            parts.append(dict2str(v, indent_level + 1))
            parts.append(f'{indent}]\n')
        else:
            parts.append(f'{indent}{k}: {v}\n')
    return ''.join(parts)
73

74

75
def _postprocess_yml_value(value):
76
    # None
77
    if value == '~' or value.lower() == 'none':
78
        return None
79
    # bool
80
    if value.lower() == 'true':
81
        return True
82
    elif value.lower() == 'false':
83
        return False
84
    # !!float number
85
    if value.startswith('!!float'):
86
        return float(value.replace('!!float', ''))
87
    # number
88
    if value.isdigit():
89
        return int(value)
90
    elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
91
        return float(value)
92
    # list
93
    if value.startswith('['):
94
        return eval(value)
95
    # str
96
    return value
97

98

99
def _apply_force_yml(opt, entries):
    """Apply ``--force_yml`` overrides of the form ``key1:key2=value`` to ``opt`` in place.

    Keys are separated by ``:`` and each key is stripped of surrounding
    whitespace. Only the first ``=`` separates the key path from the value, so
    the value itself may contain ``=``. The value is converted with
    ``_postprocess_yml_value``. Every key except the last must already exist
    (a missing one raises ``KeyError``). The last key is assigned, and is
    therefore created if absent.

    Args:
        opt (dict): Option dict to update.
        entries (list[str]): Raw ``--force_yml`` entries.

    Raises:
        ValueError: If an entry contains no ``=``.
    """
    for entry in entries:
        key_path, sep, raw_value = entry.partition('=')
        if not sep:
            raise ValueError(f'Invalid --force_yml entry {entry!r}: expected "key1:key2=value".')
        value = _postprocess_yml_value(raw_value.strip())
        *parents, leaf = [key.strip() for key in key_path.strip().split(':')]
        target = opt
        for key in parents:
            target = target[key]
        # Plain item assignment instead of building and exec-ing source code,
        # so keys containing quotes/brackets can neither break nor inject code.
        target[leaf] = value


def parse_options(root_path, is_train=True):
    """Parse command-line arguments and the option YAML into an option dict.

    Args:
        root_path (str): Root directory used for the default ``experiments``
            (training) or ``results`` (testing) directory when the YAML does
            not set one.
        is_train (bool): Whether options are parsed for training. This selects
            which output paths are filled in, and is stored as
            ``opt['is_train']``. Default: True.

    Returns:
        tuple: ``(opt, args)`` - the option dict and the argparse namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
    args = parser.parse_args()

    # parse yml to dict
    opt = yaml_load(args.opt)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
    opt['rank'], opt['world_size'] = get_dist_info()

    # force to update yml options. This runs before seeding so that an
    # overridden ``manual_seed`` is the seed actually used, not just recorded.
    if args.force_yml is not None:
        _apply_force_yml(opt, args.force_yml)

    # random seed
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    opt['auto_resume'] = args.auto_resume
    opt['is_train'] = is_train

    # debug setting
    if args.debug and not opt['name'].startswith('debug'):
        opt['name'] = 'debug_' + opt['name']

    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # datasets
    for phase, dataset in opt['datasets'].items():
        # for multiple datasets, e.g., val_1, val_2; test_1, test_2
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # paths: expand "~" in resume / pretrained-weight paths
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = opt['path'].get('experiments_root')
        if experiments_root is None:
            experiments_root = osp.join(root_path, 'experiments')
        experiments_root = osp.join(experiments_root, opt['name'])

        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = opt['path'].get('results_root')
        if results_root is None:
            results_root = osp.join(root_path, 'results')
        results_root = osp.join(results_root, opt['name'])

        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt, args
202

203

204
@master_only
def copy_opt_file(opt_file, experiments_root):
    """Copy the option YAML into ``experiments_root`` with a provenance header.

    The copy keeps the original file's base name. It is prefixed with comment
    lines recording the generation time (``time.asctime()``) and the command
    line (``sys.argv`` joined by spaces).

    Args:
        opt_file (str): Path to the option YAML file.
        experiments_root (str): Directory the copy is written into.
    """
    import sys
    import time
    from shutil import copyfile

    cmd = ' '.join(sys.argv)
    filename = osp.join(experiments_root, osp.basename(opt_file))
    copyfile(opt_file, filename)

    with open(filename, 'r', encoding='utf-8') as f:
        content = f.read()
    # Rewrite in 'w' mode, which truncates. The previous 'r+' rewrite never
    # truncated: text-mode newline translation can make the rewritten text
    # shorter than the bytes on disk (e.g. CRLF files on POSIX), leaving stale
    # trailing bytes from the original copy.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
        f.write(content)
219

# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to JSC SberTech processing your personal data in order to improve our website and the GitVerse service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.