pytorch-image-models / benchmark.py
#!/usr/bin/env python3
""" Model Benchmark Script

An inference and train step benchmark script for timm models.

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import csv
import json
import logging
import time
from collections import OrderedDict
from contextlib import suppress
from functools import partial

import torch
import torch.nn as nn
import torch.nn.parallel

from timm.data import resolve_data_config
from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs, \
    reparameterize_model

has_apex = False
try:
    from apex import amp
    has_apex = True
except ImportError:
    pass

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    from deepspeed.profiling.flops_profiler import get_model_profile
    has_deepspeed_profiling = True
except ImportError as e:
    has_deepspeed_profiling = False

try:
    from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis
    has_fvcore_profiling = True
except ImportError as e:
    FlopCountAnalysis = None
    has_fvcore_profiling = False

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError as e:
    has_functorch = False

has_compile = hasattr(torch, 'compile')

if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')


parser = argparse.ArgumentParser(description='PyTorch Benchmark')

# benchmark specific args
parser.add_argument('--model-list', metavar='NAME', default='',
                    help='txt file based list of model names to benchmark')
parser.add_argument('--bench', default='both', type=str,
                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
parser.add_argument('--detail', action='store_true', default=False,
                    help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--no-retry', action='store_true', default=False,
                    help='Do not decay batch size and retry on error.')
parser.add_argument('--results-file', default='', type=str,
                    help='Output csv file for validation results (summary)')
parser.add_argument('--results-format', default='csv', type=str,
                    help='Format for results file, one of (csv, json) (default: csv).')
parser.add_argument('--num-warm-iter', default=10, type=int,
                    help='Number of warmup iterations (default: 10)')
parser.add_argument('--num-bench-iter', default=40, type=int,
                    help='Number of benchmark iterations (default: 40)')
parser.add_argument('--device', default='cuda', type=str,
                    help="device to run benchmark on")

# common inference / train args
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
                    help='model architecture (default: resnet50)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
                    help='Run inference at train size, not test-input-size if it exists.')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number of classes in dataset')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.')
parser.add_argument('--precision', default='float32', type=str,
                    help='Numeric precision. One of (amp, amp_bfloat16, float32, float16, bfloat16)')
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
                    help='enable experimental fast-norm')
parser.add_argument('--reparam', default=False, action='store_true',
                    help='Reparameterize model')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

# codegen (model compilation) options
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                             help='convert model to torchscript for inference')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help="Enable AOT Autograd optimization.")
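# Note: the options in scripting_group are mutually exclusive; argparse will
# error out if more than one of --torchscript / --torchcompile / --aot-autograd is set.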

# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')


# model regularization / loss params that impact model or loss fn
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')


def timestamp(sync=False):
    return time.perf_counter()


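# CUDA kernels execute asynchronously, so the CUDA timing helper below optionally
# synchronizes the device before reading the clock; without the sync, timings
# would measure kernel launch rather than completion.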
def cuda_timestamp(sync=False, device=None):
    if sync:
        torch.cuda.synchronize(device=device)
    return time.perf_counter()


def count_params(model: nn.Module):
    return sum([m.numel() for m in model.parameters()])


def resolve_precision(precision: str):
    assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32')
    amp_dtype = None  # amp disabled
    model_dtype = torch.float32
    data_dtype = torch.float32
    if precision == 'amp':
        amp_dtype = torch.float16
    elif precision == 'amp_bfloat16':
        amp_dtype = torch.bfloat16
    elif precision == 'float16':
        model_dtype = torch.float16
        data_dtype = torch.float16
    elif precision == 'bfloat16':
        model_dtype = torch.bfloat16
        data_dtype = torch.bfloat16
    return amp_dtype, model_dtype, data_dtype
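# For reference: 'amp'/'amp_bfloat16' enable autocast over fp32 weights and data,
# while 'float16'/'bfloat16' cast the model and inputs directly and leave
# autocast disabled (amp_dtype is None).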


def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
    _, macs, _ = get_model_profile(
        model=model,
        input_shape=(batch_size,) + input_size,  # input shape/resolution
        print_profile=detailed,  # prints the model graph with the measured profile attached to each module
        detailed=detailed,  # print the detailed profile
        warm_up=10,  # the number of warm-ups before measuring the time of each module
        as_string=False,  # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
        output_file=None,  # path to the output file. If None, the profiler prints to stdout.
        ignore_modules=None)  # the list of modules to ignore in the profiling
    return macs, 0  # no activation count in DS


def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False, force_cpu=False):
    if force_cpu:
        model = model.to('cpu')
    device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
    example_input = torch.ones((batch_size,) + input_size, device=device, dtype=dtype)
    fca = FlopCountAnalysis(model, example_input)
    aca = ActivationCountAnalysis(model, example_input)
    if detailed:
        fcs = flop_count_str(fca)
        print(fcs)
    return fca.total(), aca.total()
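# Note: fvcore counts a fused multiply-add as a single operation, so fca.total()
# is effectively a MAC count; callers below report it as 'gmacs' (macs / 1e9).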


class BenchmarkRunner:
    def __init__(
            self,
            model_name,
            detail=False,
            device='cuda',
            torchscript=False,
            torchcompile=None,
            aot_autograd=False,
            reparam=False,
            precision='float32',
            fuser='',
            num_warm_iter=10,
            num_bench_iter=50,
            use_train_size=False,
            **kwargs
    ):
        self.model_name = model_name
        self.detail = detail
        self.device = device
        self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
        self.channels_last = kwargs.pop('channels_last', False)
        if self.amp_dtype is not None:
            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
        else:
            self.amp_autocast = suppress
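        # contextlib.suppress() with no arguments is a no-op context manager, so
        # `with self.amp_autocast():` works uniformly whether AMP is enabled or not.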

        if fuser:
            set_jit_fuser(fuser)
        self.model = create_model(
            model_name,
            num_classes=kwargs.pop('num_classes', None),
            in_chans=3,
            global_pool=kwargs.pop('gp', 'fast'),
            scriptable=torchscript,
            drop_rate=kwargs.pop('drop', 0.),
            drop_path_rate=kwargs.pop('drop_path', None),
            drop_block_rate=kwargs.pop('drop_block', None),
            **kwargs.pop('model_kwargs', {}),
        )
        if reparam:
            self.model = reparameterize_model(self.model)
        self.model.to(
            device=self.device,
            dtype=self.model_dtype,
            memory_format=torch.channels_last if self.channels_last else None,
        )
        self.num_classes = self.model.num_classes
        self.param_count = count_params(self.model)
        _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))

        data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
        self.input_size = data_config['input_size']
        self.batch_size = kwargs.pop('batch_size', 256)

        self.compiled = False
        if torchscript:
            self.model = torch.jit.script(self.model)
            self.compiled = True
        elif torchcompile:
            assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
            torch._dynamo.reset()
            self.model = torch.compile(self.model, backend=torchcompile)
            self.compiled = True
        elif aot_autograd:
            assert has_functorch, "functorch is needed for --aot-autograd"
            self.model = memory_efficient_fusion(self.model)
            self.compiled = True

        self.example_inputs = None
        self.num_warm_iter = num_warm_iter
        self.num_bench_iter = num_bench_iter
        self.log_freq = num_bench_iter // 5
        if 'cuda' in self.device:
            self.time_fn = partial(cuda_timestamp, device=self.device)
        else:
            self.time_fn = timestamp

    def _init_input(self):
        self.example_inputs = torch.randn(
            (self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype)
        if self.channels_last:
            self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)


class InferenceBenchmarkRunner(BenchmarkRunner):

    def __init__(
            self,
            model_name,
            device='cuda',
            torchscript=False,
            **kwargs
    ):
        super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
        self.model.eval()

    def run(self):
        def _step():
            t_step_start = self.time_fn()
            with self.amp_autocast():
                output = self.model(self.example_inputs)
            t_step_end = self.time_fn(True)
            return t_step_end - t_step_start

        _logger.info(
            f'Running inference benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        with torch.no_grad():
            self._init_input()

            for _ in range(self.num_warm_iter):
                _step()

            total_step = 0.
            num_samples = 0
            t_run_start = self.time_fn()
            for i in range(self.num_bench_iter):
                delta_fwd = _step()
                total_step += delta_fwd
                num_samples += self.batch_size
                num_steps = i + 1
                if num_steps % self.log_freq == 0:
                    _logger.info(
                        f"Infer [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
            t_run_end = self.time_fn(True)
            t_run_elapsed = t_run_end - t_run_start

        results = dict(
            samples_per_sec=round(num_samples / t_run_elapsed, 2),
            step_time=round(1000 * total_step / self.num_bench_iter, 3),
            batch_size=self.batch_size,
            img_size=self.input_size[-1],
            param_count=round(self.param_count / 1e6, 2),
        )

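        # samples_per_sec above uses total wall time, while step_time averages the
        # summed per-step deltas; the two can differ slightly (logging overhead).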
        retries = 0 if self.compiled else 2  # skip profiling if model is scripted/compiled
        while retries:
            retries -= 1
            try:
                if has_deepspeed_profiling:
                    macs, _ = profile_deepspeed(self.model, self.input_size)
                    results['gmacs'] = round(macs / 1e9, 2)
                elif has_fvcore_profiling:
                    macs, activations = profile_fvcore(self.model, self.input_size, force_cpu=not retries)
                    results['gmacs'] = round(macs / 1e9, 2)
                    results['macts'] = round(activations / 1e6, 2)
                if 'gmacs' in results:
                    break  # profiling succeeded, skip the CPU-fallback retry
            except RuntimeError as e:
                pass

        _logger.info(
            f"Inference benchmark of {self.model_name} done. "
            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")

        return results


class TrainBenchmarkRunner(BenchmarkRunner):

    def __init__(
            self,
            model_name,
            device='cuda',
            torchscript=False,
            **kwargs
    ):
        super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
        self.model.train()

        self.loss = nn.CrossEntropyLoss().to(self.device)
        self.target_shape = tuple()

        self.optimizer = create_optimizer_v2(
            self.model,
            opt=kwargs.pop('opt', 'sgd'),
            lr=kwargs.pop('lr', 1e-4))

        if kwargs.pop('grad_checkpointing', False):
            self.model.set_grad_checkpointing()

    def _gen_target(self, batch_size):
        return torch.empty(
            (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)

    def run(self):
        def _step(detail=False):
            self.optimizer.zero_grad()  # can this be ignored?
            t_start = self.time_fn()
            t_fwd_end = t_start
            t_bwd_end = t_start
            with self.amp_autocast():
                output = self.model(self.example_inputs)
                if isinstance(output, tuple):
                    output = output[0]
                if detail:
                    t_fwd_end = self.time_fn(True)
                target = self._gen_target(output.shape[0])
                self.loss(output, target).backward()
                if detail:
                    t_bwd_end = self.time_fn(True)
            self.optimizer.step()
            t_end = self.time_fn(True)
            if detail:
                delta_fwd = t_fwd_end - t_start
                delta_bwd = t_bwd_end - t_fwd_end
                delta_opt = t_end - t_bwd_end
                return delta_fwd, delta_bwd, delta_opt
            else:
                delta_step = t_end - t_start
                return delta_step

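        # Note: with detail=True, _step() above inserts device syncs at the fwd/bwd
        # boundaries, so per-phase times are accurate at a small throughput cost.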
        _logger.info(
            f'Running train benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        self._init_input()

        for _ in range(self.num_warm_iter):
            _step()

        t_run_start = self.time_fn()
        if self.detail:
            total_fwd = 0.
            total_bwd = 0.
            total_opt = 0.
            num_samples = 0
            for i in range(self.num_bench_iter):
                delta_fwd, delta_bwd, delta_opt = _step(True)
                num_samples += self.batch_size
                total_fwd += delta_fwd
                total_bwd += delta_bwd
                total_opt += delta_opt
                num_steps = (i + 1)
                if num_steps % self.log_freq == 0:
                    total_step = total_fwd + total_bwd + total_opt
                    _logger.info(
                        f"Train [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_fwd / num_steps:0.3f} ms/step fwd,"
                        f" {1000 * total_bwd / num_steps:0.3f} ms/step bwd,"
                        f" {1000 * total_opt / num_steps:0.3f} ms/step opt."
                    )
            total_step = total_fwd + total_bwd + total_opt
            t_run_elapsed = self.time_fn() - t_run_start
            results = dict(
                samples_per_sec=round(num_samples / t_run_elapsed, 2),
                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                fwd_time=round(1000 * total_fwd / self.num_bench_iter, 3),
                bwd_time=round(1000 * total_bwd / self.num_bench_iter, 3),
                opt_time=round(1000 * total_opt / self.num_bench_iter, 3),
                batch_size=self.batch_size,
                img_size=self.input_size[-1],
                param_count=round(self.param_count / 1e6, 2),
            )
        else:
            total_step = 0.
            num_samples = 0
            for i in range(self.num_bench_iter):
                delta_step = _step(False)
                num_samples += self.batch_size
                total_step += delta_step
                num_steps = (i + 1)
                if num_steps % self.log_freq == 0:
                    _logger.info(
                        f"Train [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
            t_run_elapsed = self.time_fn() - t_run_start
            results = dict(
                samples_per_sec=round(num_samples / t_run_elapsed, 2),
                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                batch_size=self.batch_size,
                img_size=self.input_size[-1],
                param_count=round(self.param_count / 1e6, 2),
            )

        _logger.info(
            f"Train benchmark of {self.model_name} done. "
            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")

        return results


class ProfileRunner(BenchmarkRunner):

    def __init__(self, model_name, device='cuda', profiler='', **kwargs):
        super().__init__(model_name=model_name, device=device, **kwargs)
        if not profiler:
            if has_deepspeed_profiling:
                profiler = 'deepspeed'
            elif has_fvcore_profiling:
                profiler = 'fvcore'
        assert profiler, "One of deepspeed or fvcore needs to be installed for profiling to work."
        self.profiler = profiler
        self.model.eval()

    def run(self):
        _logger.info(
            f'Running profiler on {self.model_name} w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        macs = 0
        activations = 0
        if self.profiler == 'deepspeed':
            macs, _ = profile_deepspeed(self.model, self.input_size, batch_size=self.batch_size, detailed=True)
        elif self.profiler == 'fvcore':
            macs, activations = profile_fvcore(self.model, self.input_size, batch_size=self.batch_size, detailed=True)

        results = dict(
            gmacs=round(macs / 1e9, 2),
            macts=round(activations / 1e6, 2),
            batch_size=self.batch_size,
            img_size=self.input_size[-1],
            param_count=round(self.param_count / 1e6, 2),
        )

        _logger.info(
            f"Profile of {self.model_name} done. "
            f"{results['gmacs']:.2f} GMACs, {results['param_count']:.2f} M params.")

        return results


def _try_run(
        model_name,
        bench_fn,
        bench_kwargs,
        initial_batch_size,
        no_batch_size_retry=False
):
    batch_size = initial_batch_size
    results = dict()
    error_str = 'Unknown'
    while batch_size:
        try:
            torch.cuda.empty_cache()
            bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
            results = bench.run()
            return results
        except RuntimeError as e:
            error_str = str(e)
            _logger.error(f'"{error_str}" while running benchmark.')
            if not check_batch_size_retry(error_str):
                _logger.error(f'Unrecoverable error encountered while benchmarking {model_name}, skipping.')
                break
            if no_batch_size_retry:
                break
        batch_size = decay_batch_step(batch_size)
        _logger.warning(f'Reducing batch size to {batch_size} for retry.')
    results['error'] = error_str
    return results


def benchmark(args):
    if args.amp:
        _logger.warning("Overriding precision to 'amp' since --amp flag set.")
        args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
    _logger.info(f'Benchmarking in {args.precision} precision. '
                 f'{"NHWC" if args.channels_last else "NCHW"} layout. '
                 f'torchscript {"enabled" if args.torchscript else "disabled"}')
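    # e.g. --amp with --amp-dtype bfloat16 yields args.precision == 'amp_bfloat16',
    # which resolve_precision() maps to autocast w/ torch.bfloat16.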

    bench_kwargs = vars(args).copy()
    bench_kwargs.pop('amp')
    model = bench_kwargs.pop('model')
    batch_size = bench_kwargs.pop('batch_size')

    bench_fns = (InferenceBenchmarkRunner,)
    prefixes = ('infer',)
    if args.bench == 'both':
        bench_fns = (
            InferenceBenchmarkRunner,
            TrainBenchmarkRunner
        )
        prefixes = ('infer', 'train')
    elif args.bench == 'train':
        bench_fns = (TrainBenchmarkRunner,)
        prefixes = ('train',)
    elif args.bench.startswith('profile'):
        # specific profiler used if included in bench mode string, otherwise default to deepspeed, fallback to fvcore
        if 'deepspeed' in args.bench:
            assert has_deepspeed_profiling, "deepspeed must be installed to use deepspeed flop counter"
            bench_kwargs['profiler'] = 'deepspeed'
        elif 'fvcore' in args.bench:
            assert has_fvcore_profiling, "fvcore must be installed to use fvcore flop counter"
            bench_kwargs['profiler'] = 'fvcore'
        bench_fns = (ProfileRunner,)
        batch_size = 1

    model_results = OrderedDict(model=model)
    for prefix, bench_fn in zip(prefixes, bench_fns):
        run_results = _try_run(
            model,
            bench_fn,
            bench_kwargs=bench_kwargs,
            initial_batch_size=batch_size,
            no_batch_size_retry=args.no_retry,
        )
        if prefix and 'error' not in run_results:
            run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
        model_results.update(run_results)
        if 'error' in run_results:
            break
    if 'error' not in model_results:
        param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
        model_results.setdefault('param_count', param_count)
        model_results.pop('train_param_count', 0)
    return model_results


def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []

    if args.fast_norm:
        set_fast_norm()

    if args.model_list:
        args.model = ''
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = [(n, None) for n in model_names]
    elif args.model == 'all':
        # validate all models in a list of names with pretrained checkpoints
        args.pretrained = True
        model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
        model_cfgs = [(n, None) for n in model_names]
    elif not is_model(args.model):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model)
        model_cfgs = [(n, None) for n in model_names]
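        # e.g. --model 'resnet*' expands to every timm model name matching the wildcard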

    if len(model_cfgs):
        _logger.info('Running bulk benchmark on the following models: {}'.format(', '.join(model_names)))
        results = []
        try:
            for m, _ in model_cfgs:
                if not m:
                    continue
                args.model = m
                r = benchmark(args)
                if r:
                    results.append(r)
                time.sleep(10)
        except KeyboardInterrupt as e:
            pass
        sort_key = 'infer_samples_per_sec'
        if 'train' in args.bench:
            sort_key = 'train_samples_per_sec'
        elif 'profile' in args.bench:
            sort_key = 'infer_gmacs'
        results = filter(lambda x: sort_key in x, results)
        results = sorted(results, key=lambda x: x[sort_key], reverse=True)
    else:
        results = benchmark(args)

    if args.results_file:
        write_results(args.results_file, results, format=args.results_format)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f'--result\n{json.dumps(results, indent=4)}')


def write_results(results_file, results, format='csv'):
    with open(results_file, mode='w') as cf:
        if format == 'json':
            json.dump(results, cf, indent=4)
        else:
            if not isinstance(results, (list, tuple)):
                results = [results]
            if not results:
                return
            dw = csv.DictWriter(cf, fieldnames=results[0].keys())
            dw.writeheader()
            for r in results:
                dw.writerow(r)
            cf.flush()


if __name__ == '__main__':
    main()