BasicSR

Форк
0
/
publish_models.py 
63 строки · 2.6 Кб
1
import glob
2
import subprocess
3
import torch
4
from os import path as osp
5
from torch.serialization import _is_zipfile, _open_file_like
6

7

8
def update_sha(paths):
9
    print('# Update sha ...')
10
    for idx, path in enumerate(paths):
11
        print(f'{idx+1:03d}: Processing {path}')
12
        net = torch.load(path, map_location=torch.device('cpu'))
13
        basename = osp.basename(path)
14
        if 'params' not in net and 'params_ema' not in net:
15
            user_response = input(f'WARN: Model {basename} does not have "params"/"params_ema" key. '
16
                                  'Do you still want to continue? Y/N\n')
17
            if user_response.lower() == 'y':
18
                pass
19
            elif user_response.lower() == 'n':
20
                raise ValueError('Please modify..')
21
            else:
22
                raise ValueError('Wrong input. Only accepts Y/N.')
23

24
        if '-' in basename:
25
            # check whether the sha is the latest
26
            old_sha = basename.split('-')[1].split('.')[0]
27
            new_sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
28
            if old_sha != new_sha:
29
                final_file = path.split('-')[0] + f'-{new_sha}.pth'
30
                print(f'\tSave from {path} to {final_file}')
31
                subprocess.Popen(['mv', path, final_file])
32
        else:
33
            sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
34
            final_file = path.split('.pth')[0] + f'-{sha}.pth'
35
            print(f'\tSave from {path} to {final_file}')
36
            subprocess.Popen(['mv', path, final_file])
37

38

39
def convert_to_backward_compatible_models(paths):
40
    """Convert to backward compatible pth files.
41

42
    PyTorch 1.6 uses a updated version of torch.save. In order to be compatible
43
    with previous PyTorch version, save it with
44
    _use_new_zipfile_serialization=False.
45
    """
46
    print('# Convert to backward compatible pth files ...')
47
    for idx, path in enumerate(paths):
48
        print(f'{idx+1:03d}: Processing {path}')
49
        flag_need_conversion = False
50
        with _open_file_like(path, 'rb') as opened_file:
51
            if _is_zipfile(opened_file):
52
                flag_need_conversion = True
53

54
        if flag_need_conversion:
55
            net = torch.load(path, map_location=torch.device('cpu'))
56
            print('\tConverting to compatible pth file...')
57
            torch.save(net, path, _use_new_zipfile_serialization=False)
58

59

60
if __name__ == '__main__':
61
    paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob('experiments/pretrained_models/**/*.pth')
62
    convert_to_backward_compatible_models(paths)
63
    update_sha(paths)
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.