pytorch-image-models

Форк
0
/
clean_checkpoint.py 
115 строк · 4.1 Кб
1
#!/usr/bin/env python3
2
""" Checkpoint Cleaning Script
3

4
Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc.
5
and outputs a CPU  tensor checkpoint with only the `state_dict` along with SHA256
6
calculation for model zoo compatibility.
7

8
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
9
"""
10
import torch
11
import argparse
12
import os
13
import hashlib
14
import shutil
15
import tempfile
16
from timm.models import load_state_dict
17
try:
18
    import safetensors.torch
19
    _has_safetensors = True
20
except ImportError:
21
    _has_safetensors = False
22

23
# Command-line interface for the checkpoint cleaning script.
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='output path')
# NOTE: help text previously said "use ema version of weights if present",
# which is the opposite of what this negative flag does.
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                    help='do not use ema version of weights even if present')
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
                    help='no hash in output filename')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                    help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
parser.add_argument('--safetensors', action='store_true',
                    help='Save weights using safetensors instead of the default torch way (pickle).')
36

37

38
def main():
    """CLI entry point: parse arguments, refuse to clobber an existing
    output file, then run the cleaning pass."""
    args = parser.parse_args()

    # Never overwrite an existing output file.
    if os.path.exists(args.output):
        print("Error: Output filename ({}) already exists.".format(args.output))
        exit(1)

    clean_checkpoint(
        args.checkpoint,
        args.output,
        use_ema=not args.no_use_ema,
        no_hash=args.no_hash,
        clean_aux_bn=args.clean_aux_bn,
        safe_serialization=args.safetensors,
    )
53

54

55
def clean_checkpoint(
        checkpoint,
        output,
        use_ema=True,
        no_hash=False,
        clean_aux_bn=False,
        safe_serialization: bool = False,
):
    """Strip a training checkpoint down to a bare CPU state_dict and re-save it.

    Loads an existing checkpoint, drops optimizer state / extra keys (via
    ``load_state_dict``), removes ``module.`` DataParallel prefixes and
    optionally ``aux_bn`` keys, then writes the cleaned weights with a SHA256
    fragment in the filename for model-zoo compatibility.

    Args:
        checkpoint: Path to the input checkpoint file.
        output: Desired output path; if empty, name/location derive from input.
        use_ema: Prefer the EMA version of the weights if present.
        no_hash: If True, omit the SHA256 fragment from the output filename.
        clean_aux_bn: Remove auxiliary batch-norm keys (from SplitBN training).
        safe_serialization: Save with safetensors instead of torch pickle.

    Returns:
        The final output filename (basename only), or '' on error.
    """
    # Early-out on a missing/invalid input path instead of nesting the whole body.
    if not checkpoint or not os.path.isfile(checkpoint):
        print("Error: Checkpoint ({}) doesn't exist".format(checkpoint))
        return ''

    print("=> Loading checkpoint '{}'".format(checkpoint))
    state_dict = load_state_dict(checkpoint, use_ema=use_ema)
    new_state_dict = {}
    for k, v in state_dict.items():
        if clean_aux_bn and 'aux_bn' in k:
            # If all aux_bn keys are removed, the SplitBN layers will end up as normal and
            # load with the unmodified model using BatchNorm2d.
            continue
        # Strip the 'module.' prefix added by (Distributed)DataParallel wrappers.
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    print("=> Loaded state_dict from '{}'".format(checkpoint))

    # Work out the output directory, base name, and (optional) extension.
    ext = ''
    if output:
        checkpoint_root, checkpoint_base = os.path.split(output)
        checkpoint_base, ext = os.path.splitext(checkpoint_base)
    else:
        checkpoint_root = ''
        checkpoint_base = os.path.splitext(os.path.split(checkpoint)[1])[0]

    # Save to a temp name first; renamed once the hash is known.
    temp_filename = '__' + checkpoint_base
    if safe_serialization:
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(new_state_dict, temp_filename)
    else:
        torch.save(new_state_dict, temp_filename)

    # Hash in bounded-size chunks so multi-GB checkpoints are not read fully into RAM.
    hasher = hashlib.sha256()
    with open(temp_filename, 'rb') as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            hasher.update(chunk)
    sha_hash = hasher.hexdigest()

    # Keep a caller-supplied extension; otherwise pick one matching the format.
    final_ext = ext if ext else ('.safetensors' if safe_serialization else '.pth')

    if no_hash:
        final_filename = checkpoint_base + final_ext
    else:
        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext

    shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
    print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
    return final_filename
112

113

114
# Run the CLI only when executed as a script, not on import.
if __name__ == '__main__':
    main()
116

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.