4
from os import path as osp
5
from torch.serialization import _is_zipfile, _open_file_like
9
print('# Update sha ...')
10
for idx, path in enumerate(paths):
11
print(f'{idx+1:03d}: Processing {path}')
12
net = torch.load(path, map_location=torch.device('cpu'))
13
basename = osp.basename(path)
14
if 'params' not in net and 'params_ema' not in net:
15
user_response = input(f'WARN: Model {basename} does not have "params"/"params_ema" key. '
16
'Do you still want to continue? Y/N\n')
17
if user_response.lower() == 'y':
19
elif user_response.lower() == 'n':
20
raise ValueError('Please modify..')
22
raise ValueError('Wrong input. Only accepts Y/N.')
26
old_sha = basename.split('-')[1].split('.')[0]
27
new_sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
28
if old_sha != new_sha:
29
final_file = path.split('-')[0] + f'-{new_sha}.pth'
30
print(f'\tSave from {path} to {final_file}')
31
subprocess.Popen(['mv', path, final_file])
33
sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
34
final_file = path.split('.pth')[0] + f'-{sha}.pth'
35
print(f'\tSave from {path} to {final_file}')
36
subprocess.Popen(['mv', path, final_file])
def convert_to_backward_compatible_models(paths):
    """Re-save zip-format ``.pth`` checkpoints in the legacy serialization format.

    Since PyTorch 1.6, ``torch.save`` writes a zip-based container by default.
    To stay loadable by earlier PyTorch versions, every file in ``paths`` whose
    header identifies it as a zip container is loaded onto the CPU and written
    back to the same path with ``_use_new_zipfile_serialization=False``. Files
    that are not zip containers are left untouched.

    Args:
        paths: Iterable of checkpoint file paths to inspect.
    """
    print('# Convert to backward compatible pth files ...')
    for number, ckpt_path in enumerate(paths, start=1):
        print(f'{number:03d}: Processing {ckpt_path}')

        # Only the file header is inspected here; legacy-format files never
        # get unpickled or rewritten.
        with _open_file_like(ckpt_path, 'rb') as handle:
            uses_zip_container = _is_zipfile(handle)
        if not uses_zip_container:
            continue

        # NOTE: torch.load unpickles the file, so it should only be pointed at
        # trusted checkpoints.
        # NOTE(review): on PyTorch >= 2.6 torch.load defaults to
        # weights_only=True; checkpoints holding arbitrary Python objects may
        # fail to load here — confirm against the installed version.
        state = torch.load(ckpt_path, map_location=torch.device('cpu'))
        print('\tConverting to compatible pth file...')
        # Overwrites the original checkpoint in place.
        torch.save(state, ckpt_path, _use_new_zipfile_serialization=False)
if __name__ == '__main__':
    # Without ``recursive=True`` the ``**`` wildcard matches exactly one
    # directory level (it behaves like ``*``), so these two patterns together
    # collect .pth files directly under pretrained_models/ and one level below.
    top_level_ckpts = glob.glob('experiments/pretrained_models/*.pth')
    nested_ckpts = glob.glob('experiments/pretrained_models/**/*.pth')
    paths = top_level_ckpts + nested_ckpts
    convert_to_backward_compatible_models(paths)