pytorch-image-models

Форк
0
/
validate.py 
497 строк · 20.6 Кб
1
#!/usr/bin/env python3
""" ImageNet Validation Script

This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import csv
import glob
import json
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from functools import partial

import torch
import torch.nn as nn
import torch.nn.parallel

from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.layers import apply_test_time_pool, set_fast_norm
from timm.models import create_model, load_checkpoint, is_model, list_models
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
    decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model

# Optional acceleration backends are probed at import time so that the CLI can
# later fail with a clear assertion message instead of an ImportError here.

# NVIDIA Apex AMP (legacy mixed-precision impl, selected via --amp-impl apex)
try:
    from apex import amp
    has_apex = True
except ImportError:
    has_apex = False

# Native PyTorch AMP (torch.cuda.amp.autocast, available in PyTorch >= 1.6)
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

# functorch memory-efficient fusion, used for --aot-autograd
try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError as e:
    has_functorch = False

# torch.compile() exists in PyTorch >= 2.0; gates --torchcompile
has_compile = hasattr(torch, 'compile')

_logger = logging.getLogger('validate')
53

54

55
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')

# Dataset / data-loading parameters
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
                    help='path to dataset (*deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR',
                    help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--num-samples', default=None, type=int,
                    metavar='N', help='Manually specify num samples in dataset split, for IterableDatasets.')
parser.add_argument('--dataset-download', action='store_true', default=False,
                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--input-key', default=None, type=str,
                   help='Dataset key for input images.')
parser.add_argument('--input-img-mode', default=None, type=str,
                   help='Dataset image conversion mode for input images.')
parser.add_argument('--target-key', default=None, type=str,
                   help='Dataset key for target labels.')

# Model / preprocessing parameters
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
                    help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
                    help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--crop-mode', default=None, type=str,
                    metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--crop-border-pixels', type=int, default=None,
                    help='Crop pixels from image border.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float,  nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
# Device / precision / runtime options
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
                    help='enable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
                    help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
parser.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                    help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use ema version of weights if present')
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
                    help='enable experimental fast-norm')
parser.add_argument('--reparam', default=False, action='store_true',
                    help='Reparameterize model')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)


# Model compilation options — at most one of these may be selected
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help="Enable AOT Autograd support.")

# Result output / evaluation options
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                    help='Output csv file for validation results (summary)')
parser.add_argument('--results-format', default='csv', type=str,
                    help='Format for results file one of (csv, json) (default: csv).')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
                    help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                    help='Valid label indices txt file for validation of partial label space')
parser.add_argument('--retry', default=False, action='store_true',
                    help='Enable batch size decay & retry for single model validation')
163

164
def validate(args):
    """Run a full validation pass for a single model configuration.

    Builds the model (optionally loading a checkpoint), resolves the data
    config, constructs the dataset/loader, runs the evaluation loop under the
    requested precision/compile settings, and returns an OrderedDict summary
    (model name, top1/top5 accuracy + error, param count, img size, crop pct,
    interpolation).

    Args:
        args: parsed argparse Namespace from the module-level parser. Mutated
            in place (``pretrained``, ``prefetcher``, ``num_classes``).

    Returns:
        OrderedDict of summary metrics for this validation run.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    if torch.cuda.is_available():
        # inference only, so TF32 matmul and cudnn autotuning are safe speedups
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(args.device)

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_autocast = suppress  # no-op context manager when AMP is disabled
    if args.amp:
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            assert args.amp_dtype == 'float16'
            use_amp = 'apex'
            _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
        else:
            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
            assert args.amp_dtype in ('float16', 'bfloat16')
            use_amp = 'native'
            amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
            _logger.info('Validating in mixed precision with native PyTorch AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.fuser:
        set_jit_fuser(args.fuser)

    if args.fast_norm:
        set_fast_norm()

    # create model
    # input channel count: explicit --in-chans wins, else first dim of --input-size, else 3
    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=in_chans,
        global_pool=args.gp,
        scriptable=args.torchscript,
        **args.model_kwargs,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    if args.reparam:
        # fuse/reparameterize eligible layers for inference (e.g. RepVGG-style blocks)
        model = reparameterize_model(model)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    # resolve input size / interpolation / normalization from model cfg + CLI overrides
    data_config = resolve_data_config(
        vars(args),
        model=model,
        use_test_size=not args.use_train_size,
        verbose=True,
    )
    test_time_pool = False
    if args.test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    model = model.to(device)
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # apply at most one of torchscript / torch.compile / aot-autograd (CLI group is mutually exclusive)
    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        model = torch.jit.script(model)
    elif args.torchcompile:
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        torch._dynamo.reset()
        model = torch.compile(model, backend=args.torchcompile)
    elif args.aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)

    if use_amp == 'apex':
        model = amp.initialize(model, opt_level='O1')

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().to(device)

    # positional `data` arg (deprecated) takes precedence over --data-dir
    root_dir = args.data or args.data_dir
    if args.input_img_mode is None:
        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
    else:
        input_img_mode = args.input_img_mode
    dataset = create_dataset(
        root=root_dir,
        name=args.dataset,
        split=args.split,
        download=args.dataset_download,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map,
        num_samples=args.num_samples,
        input_key=args.input_key,
        input_img_mode=input_img_mode,
        target_key=args.target_key,
    )

    # optional restriction of the output label space to a subset of indices
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = [int(line.rstrip()) for line in f]
    else:
        valid_labels = None

    # optional ImageNet-ReaL relabelling for more accurate eval
    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    # test-time pooling operates on the full (uncropped) image
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        crop_mode=data_config['crop_mode'],
        crop_border_pixels=args.crop_border_pixels,
        pin_memory=args.pin_mem,
        device=device,
        tf_preprocessing=args.tf_preprocessing,
    )

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        with amp_autocast():
            model(input)

        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                # prefetcher moves tensors to device itself; do it manually otherwise
                target = target.to(device)
                input = input.to(device)
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

                if valid_labels is not None:
                    output = output[:, valid_labels]
                loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5
                    )
                )

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        model=args.model,
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'],
    )

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
388

389

390
def _try_run(args, initial_batch_size):
    """Run ``validate(args)``, shrinking the batch size and retrying on
    recoverable RuntimeErrors (e.g. CUDA OOM).

    Returns the validation results dict on success, or an OrderedDict with an
    ``'error'`` entry describing the last failure.
    """
    results = OrderedDict()
    error_str = 'Unknown'
    current_bs = initial_batch_size
    while current_bs:
        # DataParallel splits the batch across GPUs, so scale it up accordingly
        args.batch_size = current_bs * args.num_gpu
        try:
            if torch.cuda.is_available() and 'cuda' in args.device:
                # free cached blocks before (re)trying so a prior OOM doesn't linger
                torch.cuda.empty_cache()
            return validate(args)
        except RuntimeError as exc:
            error_str = str(exc)
            _logger.error(f'"{error_str}" while running validation.')
            if not check_batch_size_retry(error_str):
                # not a batch-size related failure, no point retrying
                break
        current_bs = decay_batch_step(current_bs)
        _logger.warning(f'Reducing batch size to {current_bs} for retry.')
    results['error'] = error_str
    _logger.error(f'{args.model} failed to validate ({error_str}).')
    return results
411

412

413
# Wildcard filters used with `--model all` to exclude pretrained weights that
# were not trained (or fine-tuned) on ImageNet-1k, since their label space
# would not match the ImageNet-1k validation targets.
_NON_IN1K_FILTERS = ['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae', '*seer']
414

415

416
def main():
    """CLI entry point.

    Depending on the arguments, validates either:
      * every checkpoint file in a directory (``--checkpoint <dir>``), or
      * every pretrained model (``--model all``), or
      * every pretrained model matching a wildcard / listed in a text file, or
      * a single model configuration.

    Writes an optional results file and always prints a JSON summary to stdout
    behind a ``--result`` delimiter for downstream runner scripts.
    """
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(
                pretrained=True,
                exclude_filters=_NON_IN1K_FILTERS,
            )
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(
                args.model,
                pretrained=True,
            )
            model_cfgs = [(n, '') for n in model_names]

        if not model_cfgs and os.path.isfile(args.model):
            # --model pointed at a text file: treat it as one model name per line
            with open(args.model) as f:
                model_names = [line.rstrip() for line in f]
            model_cfgs = [(n, None) for n in model_names if n]

    if len(model_cfgs):
        # bulk mode: validate each (model, checkpoint) pair, skipping failures
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        results = []
        try:
            initial_batch_size = args.batch_size
            for m, c in model_cfgs:
                args.model = m
                args.checkpoint = c
                r = _try_run(args, initial_batch_size)
                if 'error' in r:
                    continue
                if args.checkpoint:
                    r['checkpoint'] = args.checkpoint
                results.append(r)
        except KeyboardInterrupt:
            # allow Ctrl-C to stop the sweep but still report partial results
            pass
        results = sorted(results, key=lambda x: x['top1'], reverse=True)
    else:
        # single model: optionally retry with decayed batch sizes on failure
        if args.retry:
            results = _try_run(args, args.batch_size)
        else:
            results = validate(args)

    if args.results_file:
        write_results(args.results_file, results, format=args.results_format)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f'--result\n{json.dumps(results, indent=4)}')
477

478

479
def write_results(results_file, results, format='csv'):
    """Write validation results to a file as JSON or CSV.

    Args:
        results_file: destination file path (truncated/created).
        results: a single result dict or a list/tuple of result dicts.
        format: 'json' for a JSON dump; any other value writes CSV with a
            header row taken from the first result's keys.
    """
    if format == 'json':
        with open(results_file, mode='w') as cf:
            json.dump(results, cf, indent=4)
        return
    if not isinstance(results, (list, tuple)):
        results = [results]
    # newline='' is required by the csv module so it controls row terminators
    # itself (prevents blank interleaved rows on Windows)
    with open(results_file, mode='w', newline='') as cf:
        if not results:
            # nothing to write; file is still truncated, matching prior behavior
            return
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for r in results:
            dw.writerow(r)
493

494

495

496
if __name__ == '__main__':
497
    main()
498

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.