pytorch-image-models

Форк
0
/
bulk_runner.py 
210 строк · 7.4 Кб
1
#!/usr/bin/env python3
2
""" Bulk Model Script Runner
3

4
Run validation or benchmark script in separate process for each model
5

6
Benchmark all 'vit*' models:
7
python bulk_runner.py  --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512
8

9
Validate all models:
10
python bulk_runner.py  --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry
11

12
Hacked together by Ross Wightman (https://github.com/rwightman)
13
"""
14
import argparse
15
import os
16
import sys
17
import csv
18
import json
19
import subprocess
20
import time
21
from typing import Callable, List, Tuple, Union
22

23

24
from timm.models import is_model, list_models, get_pretrained_cfg
25

26

27
parser = argparse.ArgumentParser(description='Per-model process launcher')
28

29
# model and results args
30
parser.add_argument(
31
    '--model-list', metavar='NAME', default='',
32
    help='txt file based list of model names to benchmark')
33
parser.add_argument(
34
    '--results-file', default='', type=str, metavar='FILENAME',
35
    help='Output csv file for validation results (summary)')
36
parser.add_argument(
37
    '--sort-key', default='', type=str, metavar='COL',
38
    help='Specify sort key for results csv')
39
parser.add_argument(
40
    "--pretrained", action='store_true',
41
    help="only run models with pretrained weights")
42

43
parser.add_argument(
44
    "--delay",
45
    type=float,
46
    default=0,
47
    help="Interval, in seconds, to delay between model invocations.",
48
)
49
parser.add_argument(
50
    "--start_method", type=str, default="spawn", choices=["spawn", "fork", "forkserver"],
51
    help="Multiprocessing start method to use when creating workers.",
52
)
53
parser.add_argument(
54
    "--no_python",
55
    help="Skip prepending the script with 'python' - just execute it directly. Useful "
56
         "when the script is not a Python script.",
57
)
58
parser.add_argument(
59
    "-m",
60
    "--module",
61
    help="Change each process to interpret the launch script as a Python module, executing "
62
         "with the same behavior as 'python -m'.",
63
)
64

65
# positional
66
parser.add_argument(
67
    "script", type=str,
68
    help="Full path to the program/script to be launched for each model config.",
69
)
70
parser.add_argument("script_args", nargs=argparse.REMAINDER)
71

72

73
def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
74
    # If ``args`` not passed, defaults to ``sys.argv[:1]``
75
    with_python = not args.no_python
76
    cmd: Union[Callable, str]
77
    cmd_args = []
78
    if with_python:
79
        cmd = os.getenv("PYTHON_EXEC", sys.executable)
80
        cmd_args.append("-u")
81
        if args.module:
82
            cmd_args.append("-m")
83
        cmd_args.append(args.script)
84
    else:
85
        if args.module:
86
            raise ValueError(
87
                "Don't use both the '--no_python' flag"
88
                " and the '--module' flag at the same time."
89
            )
90
        cmd = args.script
91
    cmd_args.extend(args.script_args)
92

93
    return cmd, cmd_args
94

95

96
def main():
97
    args = parser.parse_args()
98
    cmd, cmd_args = cmd_from_args(args)
99

100
    model_cfgs = []
101
    if args.model_list == 'all':
102
        model_names = list_models(
103
            pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
104
        )
105
        model_cfgs = [(n, None) for n in model_names]
106
    elif args.model_list == 'all_in1k':
107
        model_names = list_models(pretrained=True)
108
        model_cfgs = []
109
        for n in model_names:
110
            pt_cfg = get_pretrained_cfg(n)
111
            if getattr(pt_cfg, 'num_classes', 0) == 1000:
112
                print(n, pt_cfg.num_classes)
113
                model_cfgs.append((n, None))
114
    elif args.model_list == 'all_res':
115
        model_names = list_models()
116
        model_names += list_models(pretrained=True)
117
        model_cfgs = set()
118
        for n in model_names:
119
            pt_cfg = get_pretrained_cfg(n)
120
            if pt_cfg is None:
121
                print(f'Model {n} is missing pretrained cfg, skipping.')
122
                continue
123
            n = n.split('.')[0]
124
            model_cfgs.add((n, pt_cfg.input_size[-1]))
125
            if pt_cfg.test_input_size is not None:
126
                model_cfgs.add((n, pt_cfg.test_input_size[-1]))
127
        model_cfgs = [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
128
    elif not is_model(args.model_list):
129
        # model name doesn't exist, try as wildcard filter
130
        model_names = list_models(args.model_list)
131
        model_cfgs = [(n, None) for n in model_names]
132

133
    if not model_cfgs and os.path.exists(args.model_list):
134
        with open(args.model_list) as f:
135
            model_names = [line.rstrip() for line in f]
136
            model_cfgs = [(n, None) for n in model_names]
137

138
    if len(model_cfgs):
139
        results_file = args.results_file or './results.csv'
140
        results = []
141
        errors = []
142
        model_strings = '\n'.join([f'{x[0]}, {x[1]}' for x in model_cfgs])
143
        print(f"Running script on these models:\n {model_strings}")
144
        if not args.sort_key:
145
            if 'benchmark' in args.script:
146
                if any(['train' in a for a in args.script_args]):
147
                    sort_key = 'train_samples_per_sec'
148
                else:
149
                    sort_key = 'infer_samples_per_sec'
150
            else:
151
                sort_key = 'top1'
152
        else:
153
            sort_key = args.sort_key
154
        print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')
155

156
        try:
157
            for m, ax in model_cfgs:
158
                if not m:
159
                    continue
160
                args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
161
                if ax is not None:
162
                    extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]
163
                    extra_args = [i for t in extra_args for i in t]
164
                    args_str += tuple(extra_args)
165
                try:
166
                    o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
167
                    r = json.loads(o)
168
                    results.append(r)
169
                except Exception as e:
170
                    # FIXME batch_size retry loop is currently done in either validation.py or benchmark.py
171
                    # for further robustness (but more overhead), we may want to manage that by looping here...
172
                    errors.append(dict(model=m, error=str(e)))
173
                if args.delay:
174
                    time.sleep(args.delay)
175
        except KeyboardInterrupt as e:
176
            pass
177

178
        errors.extend(list(filter(lambda x: 'error' in x, results)))
179
        if errors:
180
            print(f'{len(errors)} models had errors during run.')
181
            for e in errors:
182
                if 'model' in e:
183
                    print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
184
                else:
185
                    print(e)
186

187
        results = list(filter(lambda x: 'error' not in x, results))
188

189
        no_sortkey = list(filter(lambda x: sort_key not in x, results))
190
        if no_sortkey:
191
            print(f'{len(no_sortkey)} results missing sort key, skipping sort.')
192
        else:
193
            results = sorted(results, key=lambda x: x[sort_key], reverse=True)
194

195
        if len(results):
196
            print(f'{len(results)} models run successfully. Saving results to {results_file}.')
197
            write_results(results_file, results)
198

199

200
def write_results(results_file, results):
201
    with open(results_file, mode='w') as cf:
202
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
203
        dw.writeheader()
204
        for r in results:
205
            dw.writerow(r)
206
        cf.flush()
207

208

209
if __name__ == '__main__':
210
    main()
211

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.