pytorch-image-models
/
avg_checkpoints.py
152 lines · 5.8 KB
1#!/usr/bin/env python3
2""" Checkpoint Averaging Script
3
4This script averages all model weights for checkpoints in specified path that match
5the specified filter wildcard. All checkpoints must be from the exact same model.
6
7For any hope of decent results, the checkpoints should be from the same or child
8(via resumes) training session. This can be viewed as similar to maintaining running
9EMA (exponential moving average) of the model weights or performing SWA (stochastic
10weight averaging), but post-training.
11
12Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
13"""
14import torch15import argparse16import os17import glob18import hashlib19from timm.models import load_state_dict20try:21import safetensors.torch22_has_safetensors = True23except ImportError:24_has_safetensors = False25
# Default output paths; the safetensors default is substituted in main() when
# --safetensors is passed without an explicit --output.
DEFAULT_OUTPUT = "./averaged.pth"
DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"

# Command-line interface for the checkpoint averager.
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
                    help='path to base input folder containing checkpoints')
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                    help='checkpoint filter (path wildcard)')
parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
                    help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                    help='Force not using ema version of weights (if present)')
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
                    help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
parser.add_argument('-n', type=int, default=10, metavar='N',
                    help='Number of checkpoints to average')
parser.add_argument('--safetensors', action='store_true',
                    help='Save weights using safetensors instead of the default torch way (pickle).')
45
46def checkpoint_metric(checkpoint_path):47if not checkpoint_path or not os.path.isfile(checkpoint_path):48return {}49print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))50checkpoint = torch.load(checkpoint_path, map_location='cpu')51metric = None52if 'metric' in checkpoint:53metric = checkpoint['metric']54elif 'metrics' in checkpoint and 'metric_name' in checkpoint:55metrics = checkpoint['metrics']56print(metrics)57metric = metrics[checkpoint['metric_name']]58return metric59
60
def main():
    """Average all matching checkpoints and save the result to one file.

    Globs checkpoints under --input matching --filter, optionally selects the
    top-n by their stored metric, averages their (optionally EMA) weights in
    float64, and writes a float32 state_dict as .pth (pickle) or .safetensors.
    Exits with status 1 when the output already exists or nothing matched.
    """
    args = parser.parse_args()
    # By default use the EMA weights (if present).
    args.use_ema = not args.no_use_ema
    # By default sort by checkpoint metric (if present) and avg top n checkpoints.
    args.sort = not args.no_sort

    if args.safetensors and args.output == DEFAULT_OUTPUT:
        # Default output path changes when saving via safetensors.
        args.output = DEFAULT_SAFE_OUTPUT

    output, output_ext = os.path.splitext(args.output)
    if not output_ext:
        output_ext = ('.safetensors' if args.safetensors else '.pth')
    output = output + output_ext

    if args.safetensors and not output_ext == ".safetensors":
        print(
            "Warning: saving weights as safetensors but output file extension is not "
            f"set to '.safetensors': {args.output}"
        )

    if os.path.exists(output):
        print("Error: Output filename ({}) already exists.".format(output))
        exit(1)

    # Build the glob pattern from the input folder plus the filter wildcard,
    # inserting a separator only when neither side provides one.
    pattern = args.input
    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)

    if args.sort:
        # Rank checkpoints by their stored metric and keep the top n.
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        checkpoint_metrics = list(sorted(checkpoint_metrics))
        checkpoint_metrics = checkpoint_metrics[-args.n:]
        if checkpoint_metrics:
            print("Selected checkpoints:")
            for m, c in checkpoint_metrics:
                print(m, c)
        avg_checkpoints = [c for m, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        if avg_checkpoints:
            print("Selected checkpoints:")
            for c in avg_checkpoints:
                print(c)

    if not avg_checkpoints:
        print('Error: No checkpoints found to average.')
        exit(1)

    # Accumulate in float64 to reduce rounding error across many additions.
    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            print(f"Error: Checkpoint ({c}) doesn't exist")
            continue
        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    # Per-key counts allow for checkpoints that failed to load above.
    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

    if args.safetensors:
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(final_state_dict, output)
    else:
        torch.save(final_state_dict, output)

    # Report a SHA256 of the written file for reproducibility checks.
    with open(output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    print(f"=> Saved state_dict to '{output}', SHA256: {sha_hash}")


if __name__ == '__main__':
    main()