BasicSR

train.py 
import datetime
import logging
import math
import time
import torch
from os import path as osp

from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
                           init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
from basicsr.utils.options import copy_opt_file, dict2str, parse_options


def init_tb_loggers(opt):
    # initialize wandb logger before tensorboard logger to allow proper sync
    if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
                                                     is not None) and ('debug' not in opt['name']):
        assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
        init_wandb_logger(opt)
    tb_logger = None
    if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
    return tb_logger


def create_train_val_dataloader(opt, logger):
    # create train and val dataloaders
    train_loader, val_loaders = None, []
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
            train_set = build_dataset(dataset_opt)
            train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
            train_loader = build_dataloader(
                train_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=train_sampler,
                seed=opt['manual_seed'])

            num_iter_per_epoch = math.ceil(
                len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
            total_iters = int(opt['train']['total_iter'])
            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
            logger.info('Training statistics:'
                        f'\n\tNumber of train images: {len(train_set)}'
                        f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
                        f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
                        f'\n\tWorld size (gpu number): {opt["world_size"]}'
                        f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
                        f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
        elif phase.split('_')[0] == 'val':
            val_set = build_dataset(dataset_opt)
            val_loader = build_dataloader(
                val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
            logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
            val_loaders.append(val_loader)
        else:
            raise ValueError(f'Dataset phase {phase} is not recognized.')

    return train_loader, train_sampler, val_loaders, total_epochs, total_iters


def load_resume_state(opt):
    resume_state_path = None
    if opt['auto_resume']:
        state_path = osp.join('experiments', opt['name'], 'training_states')
        if osp.isdir(state_path):
            states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
            if len(states) != 0:
                states = [float(v.split('.state')[0]) for v in states]
                resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
                opt['path']['resume_state'] = resume_state_path
    else:
        if opt['path'].get('resume_state'):
            resume_state_path = opt['path']['resume_state']

    if resume_state_path is None:
        resume_state = None
    else:
        device_id = torch.cuda.current_device()
        resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
        check_resume(opt, resume_state['iter'])
    return resume_state


def train_pipeline(root_path):
    # parse options, set distributed setting, set random seed
    opt, args = parse_options(root_path, is_train=True)
    opt['root_path'] = root_path

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # load resume states if necessary
    resume_state = load_resume_state(opt)
    # mkdir for experiments and logger
    if resume_state is None:
        make_exp_dirs(opt)
        if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
            mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))

    # copy the yml file to the experiment root
    copy_opt_file(args.opt, opt['path']['experiments_root'])

    # WARNING: do not call get_root_logger in the code above (including the functions it calls);
    # otherwise the logger will not be initialized properly
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))
    # initialize wandb and tb loggers
    tb_logger = init_tb_loggers(opt)

    # create train and validation dataloaders
    result = create_train_val_dataloader(opt, logger)
    train_loader, train_sampler, val_loaders, total_epochs, total_iters = result

    # create model
    model = build_model(opt)
    if resume_state:  # resume training
        model.resume_training(resume_state)  # handle optimizers and schedulers
        logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
        start_epoch = resume_state['epoch']
        current_iter = resume_state['iter']
    else:
        start_epoch = 0
        current_iter = 0

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    # dataloader prefetcher
    prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
    if prefetch_mode is None or prefetch_mode == 'cpu':
        prefetcher = CPUPrefetcher(train_loader)
    elif prefetch_mode == 'cuda':
        prefetcher = CUDAPrefetcher(train_loader, opt)
        logger.info(f'Use {prefetch_mode} prefetch dataloader')
        if opt['datasets']['train'].get('pin_memory') is not True:
            raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
    else:
        raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")

    # training
    logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
    data_timer, iter_timer = AvgTimer(), AvgTimer()
    start_time = time.time()

    for epoch in range(start_epoch, total_epochs + 1):
        train_sampler.set_epoch(epoch)
        prefetcher.reset()
        train_data = prefetcher.next()

        while train_data is not None:
            data_timer.record()

            current_iter += 1
            if current_iter > total_iters:
                break
            # update learning rate
            model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_iter)
            iter_timer.record()
            if current_iter == 1:
                # reset start time in msg_logger for more accurate eta_time
                # does not work in resume mode
                msg_logger.reset_start_time()
            # log
            if current_iter % opt['logger']['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': model.get_current_learning_rate()})
                log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            # save models and training states
            if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(epoch, current_iter)

            # validation
            if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
                if len(val_loaders) > 1:
                    logger.warning('Multiple validation datasets are *only* supported by SRModel.')
                for val_loader in val_loaders:
                    model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])

            data_timer.start()
            iter_timer.start()
            train_data = prefetcher.next()
        # end of iter

    # end of epoch

    consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    logger.info(f'End of training. Time consumed: {consumed_time}')
    logger.info('Save the latest model.')
    model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
    if opt.get('val') is not None:
        for val_loader in val_loaders:
            model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
    if tb_logger:
        tb_logger.close()


if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
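# Usage sketch: parse_options() reads the experiment configuration from the "-opt" command-line
# argument (see args.opt above). The YAML path below is a hypothetical placeholder, shown only to
# illustrate the expected invocation, not a file guaranteed to exist in the repository:
#
#     python basicsr/train.py -opt options/train/example_train.yml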

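# Minimal sketch of the option keys this script reads from that YAML config (values below are
# illustrative assumptions; real BasicSR configs also carry model, network, path, and dataset
# type/dataroot fields that are consumed elsewhere):
#
#     name: example_experiment
#     num_gpu: 1
#     manual_seed: 0
#     datasets:
#       train:
#         batch_size_per_gpu: 16
#         dataset_enlarge_ratio: 1
#         prefetch_mode: cuda   # requires pin_memory: true (enforced for CUDAPrefetcher above)
#         pin_memory: true
#       val:
#         name: example_val_set
#     train:
#       total_iter: 400000
#       warmup_iter: -1
#     val:
#       val_freq: 5000
#       save_img: false
#     logger:
#       print_freq: 100
#       save_checkpoint_freq: 5000
#       use_tb_logger: true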