pytorch-image-models/bulk_runner.py
#!/usr/bin/env python3
""" Bulk Model Script Runner

Run validation or benchmark script in separate process for each model

Benchmark all 'vit*' models:
python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512

Validate all models:
python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry

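Benchmark an explicit set of models listed in a text file, one model name per line
(the file and csv names here, 'my_models.txt' and 'file_bench.csv', are only placeholders):
python bulk_runner.py --model-list my_models.txt --results-file file_bench.csv benchmark.py --amp -b 512
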
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import sys
import csv
import json
import subprocess
import time
from typing import Callable, List, Tuple, Union


from timm.models import is_model, list_models, get_pretrained_cfg


parser = argparse.ArgumentParser(description='Per-model process launcher')

# model and results args
parser.add_argument(
    '--model-list', metavar='NAME', default='',
    help='Txt file with model names to run (one per line), a wildcard filter (e.g. "vit*"), or one of "all", "all_in1k", "all_res"')
parser.add_argument(
    '--results-file', default='', type=str, metavar='FILENAME',
    help='Output csv file for validation results (summary)')
parser.add_argument(
    '--sort-key', default='', type=str, metavar='COL',
    help='Specify sort key for results csv')
parser.add_argument(
    "--pretrained", action='store_true',
    help="only run models with pretrained weights")

parser.add_argument(
    "--delay",
    type=float,
    default=0,
    help="Interval, in seconds, to delay between model invocations.",
)
parser.add_argument(
    "--start_method", type=str, default="spawn", choices=["spawn", "fork", "forkserver"],
    help="Multiprocessing start method to use when creating workers.",
)
parser.add_argument(
    "--no_python",
    help="Skip prepending the script with 'python' - just execute it directly. Useful "
         "when the script is not a Python script.",
)
parser.add_argument(
    "-m",
    "--module",
    help="Change each process to interpret the launch script as a Python module, executing "
         "with the same behavior as 'python -m'.",
)

# positional
parser.add_argument(
    "script", type=str,
    help="Full path to the program/script to be launched for each model config.",
)
parser.add_argument("script_args", nargs=argparse.REMAINDER)


def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
    # Build the base launch command: the Python interpreter (plus '-m' when --module is set),
    # or the script itself when --no_python is set.
    with_python = not args.no_python
    cmd: Union[Callable, str]
    cmd_args = []
    if with_python:
        cmd = os.getenv("PYTHON_EXEC", sys.executable)
        cmd_args.append("-u")
        if args.module:
            cmd_args.append("-m")
        cmd_args.append(args.script)
    else:
        if args.module:
            raise ValueError(
                "Don't use both the '--no_python' flag"
                " and the '--module' flag at the same time."
            )
        cmd = args.script
    cmd_args.extend(args.script_args)

    return cmd, cmd_args


def main():
    args = parser.parse_args()
    cmd, cmd_args = cmd_from_args(args)

    model_cfgs = []
    if args.model_list == 'all':
        model_names = list_models(
            pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
        )
        model_cfgs = [(n, None) for n in model_names]
    elif args.model_list == 'all_in1k':
        model_names = list_models(pretrained=True)
        model_cfgs = []
        for n in model_names:
            pt_cfg = get_pretrained_cfg(n)
            if getattr(pt_cfg, 'num_classes', 0) == 1000:
                print(n, pt_cfg.num_classes)
                model_cfgs.append((n, None))
    elif args.model_list == 'all_res':
        model_names = list_models()
        model_names += list_models(pretrained=True)
        model_cfgs = set()
        for n in model_names:
            pt_cfg = get_pretrained_cfg(n)
            if pt_cfg is None:
                print(f'Model {n} is missing pretrained cfg, skipping.')
                continue
            n = n.split('.')[0]  # strip pretrained tag, run each architecture once per resolution
            model_cfgs.add((n, pt_cfg.input_size[-1]))
            if pt_cfg.test_input_size is not None:
                model_cfgs.add((n, pt_cfg.test_input_size[-1]))
        model_cfgs = [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
    elif not is_model(args.model_list):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model_list)
        model_cfgs = [(n, None) for n in model_names]

    if not model_cfgs and os.path.exists(args.model_list):
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = [(n, None) for n in model_names]

    if len(model_cfgs):
        results_file = args.results_file or './results.csv'
        results = []
        errors = []
        model_strings = '\n'.join([f'{x[0]}, {x[1]}' for x in model_cfgs])
        print(f"Running script on these models:\n {model_strings}")
        if not args.sort_key:
            if 'benchmark' in args.script:
                if any(['train' in a for a in args.script_args]):
                    sort_key = 'train_samples_per_sec'
                else:
                    sort_key = 'infer_samples_per_sec'
            else:
                sort_key = 'top1'
        else:
            sort_key = args.sort_key
        print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')

        try:
            for m, ax in model_cfgs:
                if not m:
                    continue
                args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
                if ax is not None:
                    extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]
                    extra_args = [i for t in extra_args for i in t]
                    args_str += tuple(extra_args)
                try:
                    # child script prints a '--result' marker followed by a JSON dict, parse everything after it
                    o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
                    r = json.loads(o)
                    results.append(r)
                except Exception as e:
                    # FIXME batch_size retry loop is currently done in either validation.py or benchmark.py
                    # for further robustness (but more overhead), we may want to manage that by looping here...
                    errors.append(dict(model=m, error=str(e)))
                if args.delay:
                    time.sleep(args.delay)
        except KeyboardInterrupt:
            pass

        errors.extend(list(filter(lambda x: 'error' in x, results)))
        if errors:
            print(f'{len(errors)} models had errors during run.')
            for e in errors:
                if 'model' in e:
                    print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
                else:
                    print(e)

        results = list(filter(lambda x: 'error' not in x, results))

        no_sortkey = list(filter(lambda x: sort_key not in x, results))
        if no_sortkey:
            print(f'{len(no_sortkey)} results missing sort key, skipping sort.')
        else:
            results = sorted(results, key=lambda x: x[sort_key], reverse=True)

        if len(results):
            print(f'{len(results)} models run successfully. Saving results to {results_file}.')
            write_results(results_file, results)


def write_results(results_file, results):
    with open(results_file, mode='w') as cf:
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for r in results:
            dw.writerow(r)
            cf.flush()


if __name__ == '__main__':
    main()