#!/usr/bin/env python3
""" Model Benchmark Script

An inference and train step benchmark script for timm models.

Hacked together by Ross Wightman (https://github.com/rwightman)
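
Example usage (a sketch; the flags are defined below, file names are illustrative):

    python benchmark.py --model resnet50 --bench inference --amp
    python benchmark.py --model-list models.txt --results-file bench.csv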
"""
import argparse
import csv
import json
import logging
import time
from collections import OrderedDict
from contextlib import suppress
from functools import partial

import torch
import torch.nn as nn
import torch.nn.parallel

from timm.data import resolve_data_config
from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs, \
    reparameterize_model

has_apex = False
try:
    from apex import amp
    has_apex = True
except ImportError:
    pass

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    from deepspeed.profiling.flops_profiler import get_model_profile
    has_deepspeed_profiling = True
except ImportError:
    has_deepspeed_profiling = False

try:
    from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis
    has_fvcore_profiling = True
except ImportError:
    FlopCountAnalysis = None
    has_fvcore_profiling = False

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError:
    has_functorch = False

has_compile = hasattr(torch, 'compile')
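
# TF32 matmuls and the cudnn autotuner are enabled globally below; on Ampere or
# newer GPUs this trades a little float32 precision for benchmark speed.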
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

_logger = logging.getLogger('benchmark')


parser = argparse.ArgumentParser(description='PyTorch Benchmark')

# benchmark specific args
parser.add_argument('--model-list', metavar='NAME', default='',
                    help='txt file based list of model names to benchmark')
parser.add_argument('--bench', default='both', type=str,
                    help="Benchmark mode. One of 'inference', 'train', 'both', 'profile'. Defaults to 'both'")
parser.add_argument('--detail', action='store_true', default=False,
                    help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--no-retry', action='store_true', default=False,
                    help='Do not decay batch size and retry on error.')
parser.add_argument('--results-file', default='', type=str,
                    help='Output csv file for benchmark results (summary)')
parser.add_argument('--results-format', default='csv', type=str,
                    help='Format for results file one of (csv, json) (default: csv).')
parser.add_argument('--num-warm-iter', default=10, type=int,
                    help='Number of warmup iterations (default: 10)')
parser.add_argument('--num-bench-iter', default=40, type=int,
                    help='Number of benchmark iterations (default: 40)')
parser.add_argument('--device', default='cuda', type=str,
                    help="device to run benchmark on")

# common inference / train args
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
                    help='model architecture (default: resnet50)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
                    help='Run inference at train size instead of test-input-size, when the latter exists.')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number of classes in dataset')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.')
parser.add_argument('--precision', default='float32', type=str,
                    help='Numeric precision. One of (amp, amp_bfloat16, float32, float16, bfloat16)')
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
                    help='enable experimental fast-norm')
parser.add_argument('--reparam', default=False, action='store_true',
                    help='Reparameterize model')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

# codegen (model compilation) options
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                             help='convert model to torchscript for inference')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help="Enable AOT Autograd optimization.")

# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')


# model regularization / loss params that impact model or loss fn
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')


def timestamp(sync=False):
    return time.perf_counter()


def cuda_timestamp(sync=False, device=None):
    if sync:
        torch.cuda.synchronize(device=device)
    return time.perf_counter()


def count_params(model: nn.Module):
    return sum(m.numel() for m in model.parameters())


def resolve_precision(precision: str):
    assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32')
    amp_dtype = None  # amp disabled
    model_dtype = torch.float32
    data_dtype = torch.float32
    if precision == 'amp':
        amp_dtype = torch.float16
    elif precision == 'amp_bfloat16':
        amp_dtype = torch.bfloat16
    elif precision == 'float16':
        model_dtype = torch.float16
        data_dtype = torch.float16
    elif precision == 'bfloat16':
        model_dtype = torch.bfloat16
        data_dtype = torch.bfloat16
    return amp_dtype, model_dtype, data_dtype
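
# For reference, resolve_precision() maps precision -> (amp_dtype, model_dtype, data_dtype):
#   'amp'          -> (torch.float16,  torch.float32, torch.float32)
#   'amp_bfloat16' -> (torch.bfloat16, torch.float32, torch.float32)
#   'float16'      -> (None, torch.float16,  torch.float16)
#   'bfloat16'     -> (None, torch.bfloat16, torch.bfloat16)
#   'float32'      -> (None, torch.float32,  torch.float32)

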
def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
    _, macs, _ = get_model_profile(
        model=model,
        input_shape=(batch_size,) + input_size,  # input shape/resolution
        print_profile=detailed,  # prints the model graph with the measured profile attached to each module
        detailed=detailed,  # print the detailed profile
        warm_up=10,  # the number of warm-ups before measuring the time of each module
        as_string=False,  # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
        output_file=None,  # path to the output file. If None, the profiler prints to stdout.
        ignore_modules=None)  # the list of modules to ignore in the profiling
    return macs, 0  # no activation count in DS


def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False, force_cpu=False):
    if force_cpu:
        model = model.to('cpu')
    device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
    example_input = torch.ones((batch_size,) + input_size, device=device, dtype=dtype)
    fca = FlopCountAnalysis(model, example_input)
    aca = ActivationCountAnalysis(model, example_input)
    if detailed:
        fcs = flop_count_str(fca)
        print(fcs)
    return fca.total(), aca.total()


class BenchmarkRunner:
    def __init__(
            self,
            model_name,
            detail=False,
            device='cuda',
            torchscript=False,
            torchcompile=None,
            aot_autograd=False,
            reparam=False,
            precision='float32',
            fuser='',
            num_warm_iter=10,
            num_bench_iter=50,
            use_train_size=False,
            **kwargs
    ):
        self.model_name = model_name
        self.detail = detail
        self.device = device
        self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
        self.channels_last = kwargs.pop('channels_last', False)
        if self.amp_dtype is not None:
            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
        else:
            self.amp_autocast = suppress

        if fuser:
            set_jit_fuser(fuser)
        self.model = create_model(
            model_name,
            num_classes=kwargs.pop('num_classes', None),
            in_chans=3,
            global_pool=kwargs.pop('gp', 'fast'),
            scriptable=torchscript,
            drop_rate=kwargs.pop('drop', 0.),
            drop_path_rate=kwargs.pop('drop_path', None),
            drop_block_rate=kwargs.pop('drop_block', None),
            **kwargs.pop('model_kwargs', {}),
        )
        if reparam:
            self.model = reparameterize_model(self.model)
        self.model.to(
            device=self.device,
            dtype=self.model_dtype,
            memory_format=torch.channels_last if self.channels_last else None,
        )
        self.num_classes = self.model.num_classes
        self.param_count = count_params(self.model)
        _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))

        data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
        self.input_size = data_config['input_size']
        self.batch_size = kwargs.pop('batch_size', 256)

        self.compiled = False
        if torchscript:
            self.model = torch.jit.script(self.model)
            self.compiled = True
        elif torchcompile:
            assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
            torch._dynamo.reset()
            self.model = torch.compile(self.model, backend=torchcompile)
            self.compiled = True
        elif aot_autograd:
            assert has_functorch, "functorch is needed for --aot-autograd"
            self.model = memory_efficient_fusion(self.model)
            self.compiled = True

        self.example_inputs = None
        self.num_warm_iter = num_warm_iter
        self.num_bench_iter = num_bench_iter
        self.log_freq = max(num_bench_iter // 5, 1)  # avoid modulo-by-zero for very short runs
        if 'cuda' in self.device:
            self.time_fn = partial(cuda_timestamp, device=self.device)
        else:
            self.time_fn = timestamp

    def _init_input(self):
        self.example_inputs = torch.randn(
            (self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype)
        if self.channels_last:
            self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)


class InferenceBenchmarkRunner(BenchmarkRunner):

    def __init__(
            self,
            model_name,
            device='cuda',
            torchscript=False,
            **kwargs
    ):
        super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
        self.model.eval()

    def run(self):
        def _step():
            t_step_start = self.time_fn()
            with self.amp_autocast():
                output = self.model(self.example_inputs)
            t_step_end = self.time_fn(True)
            return t_step_end - t_step_start

        _logger.info(
            f'Running inference benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        with torch.no_grad():
            self._init_input()

            for _ in range(self.num_warm_iter):
                _step()

            total_step = 0.
            num_samples = 0
            t_run_start = self.time_fn()
            for i in range(self.num_bench_iter):
                delta_fwd = _step()
                total_step += delta_fwd
                num_samples += self.batch_size
                num_steps = i + 1
                if num_steps % self.log_freq == 0:
                    _logger.info(
                        f"Infer [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
            t_run_end = self.time_fn(True)
            t_run_elapsed = t_run_end - t_run_start

        results = dict(
            samples_per_sec=round(num_samples / t_run_elapsed, 2),
            step_time=round(1000 * total_step / self.num_bench_iter, 3),
            batch_size=self.batch_size,
            img_size=self.input_size[-1],
            param_count=round(self.param_count / 1e6, 2),
        )

        retries = 0 if self.compiled else 2  # skip profiling if model is torchscripted/compiled
        while retries:
            retries -= 1
            try:
                if has_deepspeed_profiling:
                    macs, _ = profile_deepspeed(self.model, self.input_size)
                    results['gmacs'] = round(macs / 1e9, 2)
                elif has_fvcore_profiling:
                    macs, activations = profile_fvcore(self.model, self.input_size, force_cpu=not retries)
                    results['gmacs'] = round(macs / 1e9, 2)
                    results['macts'] = round(activations / 1e6, 2)
            except RuntimeError:
                pass  # last retry runs fvcore on CPU; give up on profiling if that fails too

        _logger.info(
            f"Inference benchmark of {self.model_name} done. "
            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")

        return results
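
# Direct (non-CLI) use is possible too; a sketch, assuming a CUDA device:
#   results = InferenceBenchmarkRunner('resnet50', device='cuda', batch_size=64).run()
#   train_results = TrainBenchmarkRunner('resnet50', device='cuda', batch_size=32).run()
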

class TrainBenchmarkRunner(BenchmarkRunner):

    def __init__(
            self,
            model_name,
            device='cuda',
            torchscript=False,
            **kwargs
    ):
        super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
        self.model.train()

        self.loss = nn.CrossEntropyLoss().to(self.device)
        self.target_shape = tuple()

        self.optimizer = create_optimizer_v2(
            self.model,
            opt=kwargs.pop('opt', 'sgd'),
            lr=kwargs.pop('lr', 1e-4))

        if kwargs.pop('grad_checkpointing', False):
            self.model.set_grad_checkpointing()

    def _gen_target(self, batch_size):
        return torch.empty(
            (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)

    def run(self):
        def _step(detail=False):
            self.optimizer.zero_grad()  # can this be ignored?
            t_start = self.time_fn()
            t_fwd_end = t_start
            t_bwd_end = t_start
            with self.amp_autocast():
                output = self.model(self.example_inputs)
                if isinstance(output, tuple):
                    output = output[0]
                if detail:
                    t_fwd_end = self.time_fn(True)
                target = self._gen_target(output.shape[0])
                self.loss(output, target).backward()
                if detail:
                    t_bwd_end = self.time_fn(True)
            self.optimizer.step()
            t_end = self.time_fn(True)
            if detail:
                delta_fwd = t_fwd_end - t_start
                delta_bwd = t_bwd_end - t_fwd_end
                delta_opt = t_end - t_bwd_end
                return delta_fwd, delta_bwd, delta_opt
            else:
                delta_step = t_end - t_start
                return delta_step

        _logger.info(
            f'Running train benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        self._init_input()

        for _ in range(self.num_warm_iter):
            _step()

        t_run_start = self.time_fn()
        if self.detail:
            total_fwd = 0.
            total_bwd = 0.
            total_opt = 0.
            num_samples = 0
            for i in range(self.num_bench_iter):
                delta_fwd, delta_bwd, delta_opt = _step(True)
                num_samples += self.batch_size
                total_fwd += delta_fwd
                total_bwd += delta_bwd
                total_opt += delta_opt
                num_steps = (i + 1)
                if num_steps % self.log_freq == 0:
                    total_step = total_fwd + total_bwd + total_opt
                    _logger.info(
                        f"Train [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_fwd / num_steps:0.3f} ms/step fwd,"
                        f" {1000 * total_bwd / num_steps:0.3f} ms/step bwd,"
                        f" {1000 * total_opt / num_steps:0.3f} ms/step opt."
                    )
            total_step = total_fwd + total_bwd + total_opt
            t_run_elapsed = self.time_fn() - t_run_start
            results = dict(
                samples_per_sec=round(num_samples / t_run_elapsed, 2),
                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                fwd_time=round(1000 * total_fwd / self.num_bench_iter, 3),
                bwd_time=round(1000 * total_bwd / self.num_bench_iter, 3),
                opt_time=round(1000 * total_opt / self.num_bench_iter, 3),
                batch_size=self.batch_size,
                img_size=self.input_size[-1],
                param_count=round(self.param_count / 1e6, 2),
            )
        else:
            total_step = 0.
            num_samples = 0
            for i in range(self.num_bench_iter):
                delta_step = _step(False)
                num_samples += self.batch_size
                total_step += delta_step
                num_steps = (i + 1)
                if num_steps % self.log_freq == 0:
                    _logger.info(
                        f"Train [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
            t_run_elapsed = self.time_fn() - t_run_start
            results = dict(
                samples_per_sec=round(num_samples / t_run_elapsed, 2),
                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                batch_size=self.batch_size,
                img_size=self.input_size[-1],
                param_count=round(self.param_count / 1e6, 2),
            )

        _logger.info(
            f"Train benchmark of {self.model_name} done. "
            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")

        return results


class ProfileRunner(BenchmarkRunner):

    def __init__(self, model_name, device='cuda', profiler='', **kwargs):
        super().__init__(model_name=model_name, device=device, **kwargs)
        if not profiler:
            if has_deepspeed_profiling:
                profiler = 'deepspeed'
            elif has_fvcore_profiling:
                profiler = 'fvcore'
        assert profiler, "One of deepspeed or fvcore needs to be installed for profiling to work."
        self.profiler = profiler
        self.model.eval()

    def run(self):
        _logger.info(
            f'Running profiler on {self.model_name} w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        macs = 0
        activations = 0
        if self.profiler == 'deepspeed':
            macs, _ = profile_deepspeed(self.model, self.input_size, batch_size=self.batch_size, detailed=True)
        elif self.profiler == 'fvcore':
            macs, activations = profile_fvcore(self.model, self.input_size, batch_size=self.batch_size, detailed=True)

        results = dict(
            gmacs=round(macs / 1e9, 2),
            macts=round(activations / 1e6, 2),
            batch_size=self.batch_size,
            img_size=self.input_size[-1],
            param_count=round(self.param_count / 1e6, 2),
        )

        _logger.info(
            f"Profile of {self.model_name} done. "
            f"{results['gmacs']:.2f} GMACs, {results['param_count']:.2f} M params.")

        return results
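

# Run one benchmark config, retrying at smaller batch sizes on recoverable
# errors (e.g. CUDA OOM) unless no_batch_size_retry is set; decay_batch_step()
# supplies each reduced batch size until success or the size decays to 0.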
def _try_run(
        model_name,
        bench_fn,
        bench_kwargs,
        initial_batch_size,
        no_batch_size_retry=False
):
    batch_size = initial_batch_size
    results = dict()
    error_str = 'Unknown'
    while batch_size:
        try:
            torch.cuda.empty_cache()
            bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
            results = bench.run()
            return results
        except RuntimeError as e:
            error_str = str(e)
            _logger.error(f'"{error_str}" while running benchmark.')
            if not check_batch_size_retry(error_str):
                _logger.error(f'Unrecoverable error encountered while benchmarking {model_name}, skipping.')
                break
            if no_batch_size_retry:
                break
        batch_size = decay_batch_step(batch_size)
        _logger.warning(f'Reducing batch size to {batch_size} for retry.')
    results['error'] = error_str
    return results
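

# For --bench both, the merged result dict looks roughly like (a sketch):
#   {'model': 'resnet50', 'infer_samples_per_sec': ..., 'infer_step_time': ...,
#    'train_samples_per_sec': ..., 'train_step_time': ..., 'param_count': ...}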
def benchmark(args):
    if args.amp:
        _logger.warning("Overriding precision to 'amp' since --amp flag set.")
        args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
    _logger.info(f'Benchmarking in {args.precision} precision. '
                 f'{"NHWC" if args.channels_last else "NCHW"} layout. '
                 f'torchscript {"enabled" if args.torchscript else "disabled"}')

    bench_kwargs = vars(args).copy()
    bench_kwargs.pop('amp')
    model = bench_kwargs.pop('model')
    batch_size = bench_kwargs.pop('batch_size')

    bench_fns = (InferenceBenchmarkRunner,)
    prefixes = ('infer',)
    if args.bench == 'both':
        bench_fns = (
            InferenceBenchmarkRunner,
            TrainBenchmarkRunner
        )
        prefixes = ('infer', 'train')
    elif args.bench == 'train':
        bench_fns = (TrainBenchmarkRunner,)
        prefixes = ('train',)
    elif args.bench.startswith('profile'):
        # specific profiler used if included in bench mode string, otherwise default to deepspeed, fallback to fvcore
        if 'deepspeed' in args.bench:
            assert has_deepspeed_profiling, "deepspeed must be installed to use deepspeed flop counter"
            bench_kwargs['profiler'] = 'deepspeed'
        elif 'fvcore' in args.bench:
            assert has_fvcore_profiling, "fvcore must be installed to use fvcore flop counter"
            bench_kwargs['profiler'] = 'fvcore'
        bench_fns = (ProfileRunner,)
        batch_size = 1

    model_results = OrderedDict(model=model)
    for prefix, bench_fn in zip(prefixes, bench_fns):
        run_results = _try_run(
            model,
            bench_fn,
            bench_kwargs=bench_kwargs,
            initial_batch_size=batch_size,
            no_batch_size_retry=args.no_retry,
        )
        if prefix and 'error' not in run_results:
            run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
        model_results.update(run_results)
        if 'error' in run_results:
            break
    if 'error' not in model_results:
        # the default arg below is evaluated eagerly, so both infer_ and train_
        # param counts are popped here; the infer count takes precedence if present
        param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
        model_results.setdefault('param_count', param_count)
    return model_results


def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []

    if args.fast_norm:
        set_fast_norm()

    if args.model_list:
        args.model = ''
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = [(n, None) for n in model_names]
    elif args.model == 'all':
        # benchmark all models in the list of names with pretrained checkpoints
        args.pretrained = True
        model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
        model_cfgs = [(n, None) for n in model_names]
    elif not is_model(args.model):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model)
        model_cfgs = [(n, None) for n in model_names]

    if model_cfgs:
        _logger.info('Running bulk benchmark on the following models: {}'.format(', '.join(model_names)))
        results = []
        try:
            for m, _ in model_cfgs:
                if not m:
                    continue
                args.model = m
                r = benchmark(args)
                if r:
                    results.append(r)
                time.sleep(10)
        except KeyboardInterrupt:
            pass
        sort_key = 'infer_samples_per_sec'
        if 'train' in args.bench:
            sort_key = 'train_samples_per_sec'
        elif 'profile' in args.bench:
            sort_key = 'infer_gmacs'
        results = filter(lambda x: sort_key in x, results)
        results = sorted(results, key=lambda x: x[sort_key], reverse=True)
    else:
        results = benchmark(args)

    if args.results_file:
        write_results(args.results_file, results, format=args.results_format)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f'--result\n{json.dumps(results, indent=4)}')
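

# Note: CSV field names are taken from the first result dict, so in bulk runs
# the columns follow whatever keys the first benchmarked model produced.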
def write_results(results_file, results, format='csv'):
    with open(results_file, mode='w', newline='') as cf:  # newline='' avoids blank rows in csv output on Windows
        if format == 'json':
            json.dump(results, cf, indent=4)
        else:
            if not isinstance(results, (list, tuple)):
                results = [results]
            if not results:
                return
            dw = csv.DictWriter(cf, fieldnames=results[0].keys())
            dw.writeheader()
            for r in results:
                dw.writerow(r)
            cf.flush()


if __name__ == '__main__':
    main()