pytorch-image-models / avg_checkpoints.py
#!/usr/bin/env python3
""" Checkpoint Averaging Script

This script averages all model weights for checkpoints in the specified path that
match the specified filter wildcard. All checkpoints must be from the exact same model.

For any hope of decent results, the checkpoints should be from the same or child
(via resumes) training session. This can be viewed as similar to maintaining a running
EMA (exponential moving average) of the model weights or performing SWA (stochastic
weight averaging), but post-training.
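
Example invocations (hypothetical paths, shown for illustration; the options are
defined by the argparse setup below):

    ./avg_checkpoints.py --input ./output/train/my_session -n 5
    ./avg_checkpoints.py --input ./output/train/my_session --filter '*.pth.tar' --safetensors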

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import glob
import hashlib
import os
import sys

import torch

from timm.models import load_state_dict

try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

DEFAULT_OUTPUT = "./averaged.pth"
DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
                    help='path to base input folder containing checkpoints')
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                    help='checkpoint filter (path wildcard)')
parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
                    help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                    help='Force not using ema version of weights (if present)')
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
                    help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
parser.add_argument('-n', type=int, default=10, metavar='N',
                    help='Number of checkpoints to average')
parser.add_argument('--safetensors', action='store_true',
                    help='Save weights using safetensors instead of the default torch way (pickle).')


def checkpoint_metric(checkpoint_path):
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        # no metric for a missing file; main() filters out None results
        return None
    print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    metric = None
    if 'metric' in checkpoint:
        metric = checkpoint['metric']
    elif 'metrics' in checkpoint and 'metric_name' in checkpoint:
        metrics = checkpoint['metrics']
        print(metrics)
        metric = metrics[checkpoint['metric_name']]
    return metric
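
# Note: the 'metric' / 'metrics' + 'metric_name' keys read above match the
# checkpoint layout written by timm's training scripts (an assumption if the
# checkpoints come from elsewhere); files without a stored metric are excluded
# from the sorted top-n selection in main().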


def main():
    args = parser.parse_args()
    # by default use the EMA weights (if present)
    args.use_ema = not args.no_use_ema
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

    if args.safetensors and args.output == DEFAULT_OUTPUT:
        # Default path changes if using safetensors
        args.output = DEFAULT_SAFE_OUTPUT

    output, output_ext = os.path.splitext(args.output)
    if not output_ext:
        output_ext = ('.safetensors' if args.safetensors else '.pth')
    output = output + output_ext

    if args.safetensors and output_ext != ".safetensors":
        print(
            "Warning: saving weights as safetensors but output file extension is not "
            f"set to '.safetensors': {args.output}"
        )

    if os.path.exists(output):
        print("Error: Output filename ({}) already exists.".format(output))
        sys.exit(1)

    pattern = args.input
    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)
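
    # The top-n selection below assumes the stored metric is higher-is-better:
    # (metric, path) tuples sort ascending and only the tail is kept.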
    if args.sort:
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        checkpoint_metrics = list(sorted(checkpoint_metrics))
        checkpoint_metrics = checkpoint_metrics[-args.n:]
        if checkpoint_metrics:
            print("Selected checkpoints:")
            for m, c in checkpoint_metrics:
                print(m, c)
        avg_checkpoints = [c for m, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        if avg_checkpoints:
            print("Selected checkpoints:")
            for c in checkpoints:
                print(c)

    if not avg_checkpoints:
        print('Error: No checkpoints found to average.')
        sys.exit(1)
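
    # Plain arithmetic mean per parameter: avg[k] = sum_i w_i[k] / count[k],
    # accumulated in float64 to limit rounding error; per-key counts guard
    # against checkpoints that are missing some keys.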
    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            print(f"Error: Checkpoint ({c}) is empty or invalid, skipping.")
            continue
        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)
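
    # safetensors serializes raw tensor data without pickle, so the saved file
    # can be loaded without executing arbitrary code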
    if args.safetensors:
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(final_state_dict, output)
    else:
        torch.save(final_state_dict, output)
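
    # hash the saved file so the result can be verified or referenced by digest,
    # similar to the short-hash suffixes on timm's published checkpoint files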
    with open(output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    print(f"=> Saved state_dict to '{output}', SHA256: {sha_hash}")


if __name__ == '__main__':
    main()