pytorch-image-models
/
validate.py
497 lines · 20.6 KB
1#!/usr/bin/env python3
2""" ImageNet Validation Script
3
4This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
5models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
6canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
7
8Hacked together by Ross Wightman (https://github.com/rwightman)
9"""
10import argparse11import csv12import glob13import json14import logging15import os16import time17from collections import OrderedDict18from contextlib import suppress19from functools import partial20
21import torch22import torch.nn as nn23import torch.nn.parallel24
25from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet26from timm.layers import apply_test_time_pool, set_fast_norm27from timm.models import create_model, load_checkpoint, is_model, list_models28from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \29decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model30
31try:32from apex import amp33has_apex = True34except ImportError:35has_apex = False36
37has_native_amp = False38try:39if getattr(torch.cuda.amp, 'autocast') is not None:40has_native_amp = True41except AttributeError:42pass43
44try:45from functorch.compile import memory_efficient_fusion46has_functorch = True47except ImportError as e:48has_functorch = False49
50has_compile = hasattr(torch, 'compile')51
52_logger = logging.getLogger('validate')53
54
55parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')56parser.add_argument('data', nargs='?', metavar='DIR', const=None,57help='path to dataset (*deprecated*, use --data-dir)')58parser.add_argument('--data-dir', metavar='DIR',59help='path to dataset (root dir)')60parser.add_argument('--dataset', metavar='NAME', default='',61help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')62parser.add_argument('--split', metavar='NAME', default='validation',63help='dataset split (default: validation)')64parser.add_argument('--num-samples', default=None, type=int,65metavar='N', help='Manually specify num samples in dataset split, for IterableDatasets.')66parser.add_argument('--dataset-download', action='store_true', default=False,67help='Allow download of dataset for torch/ and tfds/ datasets that support it.')68parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',69help='path to class to idx mapping file (default: "")')70parser.add_argument('--input-key', default=None, type=str,71help='Dataset key for input images.')72parser.add_argument('--input-img-mode', default=None, type=str,73help='Dataset image conversion mode for input images.')74parser.add_argument('--target-key', default=None, type=str,75help='Dataset key for target labels.')76
77parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',78help='model architecture (default: dpn92)')79parser.add_argument('--pretrained', dest='pretrained', action='store_true',80help='use pre-trained model')81parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',82help='number of data loading workers (default: 4)')83parser.add_argument('-b', '--batch-size', default=256, type=int,84metavar='N', help='mini-batch size (default: 256)')85parser.add_argument('--img-size', default=None, type=int,86metavar='N', help='Input image dimension, uses model default if empty')87parser.add_argument('--in-chans', type=int, default=None, metavar='N',88help='Image input channels (default: None => 3)')89parser.add_argument('--input-size', default=None, nargs=3, type=int,90metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')91parser.add_argument('--use-train-size', action='store_true', default=False,92help='force use of train input size, even when test size is specified in pretrained cfg')93parser.add_argument('--crop-pct', default=None, type=float,94metavar='N', help='Input image center crop pct')95parser.add_argument('--crop-mode', default=None, type=str,96metavar='N', help='Input image crop mode (squash, border, center). 
Model default if None.')97parser.add_argument('--crop-border-pixels', type=int, default=None,98help='Crop pixels from image border.')99parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',100help='Override mean pixel value of dataset')101parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',102help='Override std deviation of of dataset')103parser.add_argument('--interpolation', default='', type=str, metavar='NAME',104help='Image resize interpolation type (overrides model)')105parser.add_argument('--num-classes', type=int, default=None,106help='Number classes in dataset')107parser.add_argument('--gp', default=None, type=str, metavar='POOL',108help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')109parser.add_argument('--log-freq', default=10, type=int,110metavar='N', help='batch logging frequency (default: 10)')111parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',112help='path to latest checkpoint (default: none)')113parser.add_argument('--num-gpu', type=int, default=1,114help='Number of GPUS to use')115parser.add_argument('--test-pool', dest='test_pool', action='store_true',116help='enable test time pool')117parser.add_argument('--no-prefetcher', action='store_true', default=False,118help='disable fast prefetcher')119parser.add_argument('--pin-mem', action='store_true', default=False,120help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')121parser.add_argument('--channels-last', action='store_true', default=False,122help='Use channels_last memory layout')123parser.add_argument('--device', default='cuda', type=str,124help="Device (accelerator) to use.")125parser.add_argument('--amp', action='store_true', default=False,126help='use NVIDIA Apex AMP or Native AMP for mixed precision training')127parser.add_argument('--amp-dtype', default='float16', type=str,128help='lower precision AMP dtype (default: 
float16)')129parser.add_argument('--amp-impl', default='native', type=str,130help='AMP impl to use, "native" or "apex" (default: native)')131parser.add_argument('--tf-preprocessing', action='store_true', default=False,132help='Use Tensorflow preprocessing pipeline (require CPU TF installed')133parser.add_argument('--use-ema', dest='use_ema', action='store_true',134help='use ema version of weights if present')135parser.add_argument('--fuser', default='', type=str,136help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")137parser.add_argument('--fast-norm', default=False, action='store_true',138help='enable experimental fast-norm')139parser.add_argument('--reparam', default=False, action='store_true',140help='Reparameterize model')141parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)142
143
144scripting_group = parser.add_mutually_exclusive_group()145scripting_group.add_argument('--torchscript', default=False, action='store_true',146help='torch.jit.script the full model')147scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',148help="Enable compilation w/ specified backend (default: inductor).")149scripting_group.add_argument('--aot-autograd', default=False, action='store_true',150help="Enable AOT Autograd support.")151
152parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',153help='Output csv file for validation results (summary)')154parser.add_argument('--results-format', default='csv', type=str,155help='Format for results file one of (csv, json) (default: csv).')156parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',157help='Real labels JSON file for imagenet evaluation')158parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',159help='Valid label indices txt file for validation of partial label space')160parser.add_argument('--retry', default=False, action='store_true',161help='Enable batch size decay & retry for single model validation')162
163
164def validate(args):165# might as well try to validate something166args.pretrained = args.pretrained or not args.checkpoint167args.prefetcher = not args.no_prefetcher168
169if torch.cuda.is_available():170torch.backends.cuda.matmul.allow_tf32 = True171torch.backends.cudnn.benchmark = True172
173device = torch.device(args.device)174
175# resolve AMP arguments based on PyTorch / Apex availability176use_amp = None177amp_autocast = suppress178if args.amp:179if args.amp_impl == 'apex':180assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'181assert args.amp_dtype == 'float16'182use_amp = 'apex'183_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')184else:185assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'186assert args.amp_dtype in ('float16', 'bfloat16')187use_amp = 'native'188amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16189amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)190_logger.info('Validating in mixed precision with native PyTorch AMP.')191else:192_logger.info('Validating in float32. AMP not enabled.')193
194if args.fuser:195set_jit_fuser(args.fuser)196
197if args.fast_norm:198set_fast_norm()199
200# create model201in_chans = 3202if args.in_chans is not None:203in_chans = args.in_chans204elif args.input_size is not None:205in_chans = args.input_size[0]206
207model = create_model(208args.model,209pretrained=args.pretrained,210num_classes=args.num_classes,211in_chans=in_chans,212global_pool=args.gp,213scriptable=args.torchscript,214**args.model_kwargs,215)216if args.num_classes is None:217assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'218args.num_classes = model.num_classes219
220if args.checkpoint:221load_checkpoint(model, args.checkpoint, args.use_ema)222
223if args.reparam:224model = reparameterize_model(model)225
226param_count = sum([m.numel() for m in model.parameters()])227_logger.info('Model %s created, param count: %d' % (args.model, param_count))228
229data_config = resolve_data_config(230vars(args),231model=model,232use_test_size=not args.use_train_size,233verbose=True,234)235test_time_pool = False236if args.test_pool:237model, test_time_pool = apply_test_time_pool(model, data_config)238
239model = model.to(device)240if args.channels_last:241model = model.to(memory_format=torch.channels_last)242
243if args.torchscript:244assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'245model = torch.jit.script(model)246elif args.torchcompile:247assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'248torch._dynamo.reset()249model = torch.compile(model, backend=args.torchcompile)250elif args.aot_autograd:251assert has_functorch, "functorch is needed for --aot-autograd"252model = memory_efficient_fusion(model)253
254if use_amp == 'apex':255model = amp.initialize(model, opt_level='O1')256
257if args.num_gpu > 1:258model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))259
260criterion = nn.CrossEntropyLoss().to(device)261
262root_dir = args.data or args.data_dir263if args.input_img_mode is None:264input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'265else:266input_img_mode = args.input_img_mode267dataset = create_dataset(268root=root_dir,269name=args.dataset,270split=args.split,271download=args.dataset_download,272load_bytes=args.tf_preprocessing,273class_map=args.class_map,274num_samples=args.num_samples,275input_key=args.input_key,276input_img_mode=input_img_mode,277target_key=args.target_key,278)279
280if args.valid_labels:281with open(args.valid_labels, 'r') as f:282valid_labels = [int(line.rstrip()) for line in f]283else:284valid_labels = None285
286if args.real_labels:287real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)288else:289real_labels = None290
291crop_pct = 1.0 if test_time_pool else data_config['crop_pct']292loader = create_loader(293dataset,294input_size=data_config['input_size'],295batch_size=args.batch_size,296use_prefetcher=args.prefetcher,297interpolation=data_config['interpolation'],298mean=data_config['mean'],299std=data_config['std'],300num_workers=args.workers,301crop_pct=crop_pct,302crop_mode=data_config['crop_mode'],303crop_border_pixels=args.crop_border_pixels,304pin_memory=args.pin_mem,305device=device,306tf_preprocessing=args.tf_preprocessing,307)308
309batch_time = AverageMeter()310losses = AverageMeter()311top1 = AverageMeter()312top5 = AverageMeter()313
314model.eval()315with torch.no_grad():316# warmup, reduce variability of first batch time, especially for comparing torchscript vs non317input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)318if args.channels_last:319input = input.contiguous(memory_format=torch.channels_last)320with amp_autocast():321model(input)322
323end = time.time()324for batch_idx, (input, target) in enumerate(loader):325if args.no_prefetcher:326target = target.to(device)327input = input.to(device)328if args.channels_last:329input = input.contiguous(memory_format=torch.channels_last)330
331# compute output332with amp_autocast():333output = model(input)334
335if valid_labels is not None:336output = output[:, valid_labels]337loss = criterion(output, target)338
339if real_labels is not None:340real_labels.add_result(output)341
342# measure accuracy and record loss343acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))344losses.update(loss.item(), input.size(0))345top1.update(acc1.item(), input.size(0))346top5.update(acc5.item(), input.size(0))347
348# measure elapsed time349batch_time.update(time.time() - end)350end = time.time()351
352if batch_idx % args.log_freq == 0:353_logger.info(354'Test: [{0:>4d}/{1}] '355'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '356'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '357'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '358'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(359batch_idx,360len(loader),361batch_time=batch_time,362rate_avg=input.size(0) / batch_time.avg,363loss=losses,364top1=top1,365top5=top5366)367)368
369if real_labels is not None:370# real labels mode replaces topk values at the end371top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)372else:373top1a, top5a = top1.avg, top5.avg374results = OrderedDict(375model=args.model,376top1=round(top1a, 4), top1_err=round(100 - top1a, 4),377top5=round(top5a, 4), top5_err=round(100 - top5a, 4),378param_count=round(param_count / 1e6, 2),379img_size=data_config['input_size'][-1],380crop_pct=crop_pct,381interpolation=data_config['interpolation'],382)383
384_logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(385results['top1'], results['top1_err'], results['top5'], results['top5_err']))386
387return results388
389
390def _try_run(args, initial_batch_size):391batch_size = initial_batch_size392results = OrderedDict()393error_str = 'Unknown'394while batch_size:395args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case396try:397if torch.cuda.is_available() and 'cuda' in args.device:398torch.cuda.empty_cache()399results = validate(args)400return results401except RuntimeError as e:402error_str = str(e)403_logger.error(f'"{error_str}" while running validation.')404if not check_batch_size_retry(error_str):405break406batch_size = decay_batch_step(batch_size)407_logger.warning(f'Reducing batch size to {batch_size} for retry.')408results['error'] = error_str409_logger.error(f'{args.model} failed to validate ({error_str}).')410return results411
412
413_NON_IN1K_FILTERS = ['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae', '*seer']414
415
416def main():417setup_default_logging()418args = parser.parse_args()419model_cfgs = []420model_names = []421if os.path.isdir(args.checkpoint):422# validate all checkpoints in a path with same model423checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')424checkpoints += glob.glob(args.checkpoint + '/*.pth')425model_names = list_models(args.model)426model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]427else:428if args.model == 'all':429# validate all models in a list of names with pretrained checkpoints430args.pretrained = True431model_names = list_models(432pretrained=True,433exclude_filters=_NON_IN1K_FILTERS,434)435model_cfgs = [(n, '') for n in model_names]436elif not is_model(args.model):437# model name doesn't exist, try as wildcard filter438model_names = list_models(439args.model,440pretrained=True,441)442model_cfgs = [(n, '') for n in model_names]443
444if not model_cfgs and os.path.isfile(args.model):445with open(args.model) as f:446model_names = [line.rstrip() for line in f]447model_cfgs = [(n, None) for n in model_names if n]448
449if len(model_cfgs):450_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))451results = []452try:453initial_batch_size = args.batch_size454for m, c in model_cfgs:455args.model = m456args.checkpoint = c457r = _try_run(args, initial_batch_size)458if 'error' in r:459continue460if args.checkpoint:461r['checkpoint'] = args.checkpoint462results.append(r)463except KeyboardInterrupt as e:464pass465results = sorted(results, key=lambda x: x['top1'], reverse=True)466else:467if args.retry:468results = _try_run(args, args.batch_size)469else:470results = validate(args)471
472if args.results_file:473write_results(args.results_file, results, format=args.results_format)474
475# output results in JSON to stdout w/ delimiter for runner script476print(f'--result\n{json.dumps(results, indent=4)}')477
478
479def write_results(results_file, results, format='csv'):480with open(results_file, mode='w') as cf:481if format == 'json':482json.dump(results, cf, indent=4)483else:484if not isinstance(results, (list, tuple)):485results = [results]486if not results:487return488dw = csv.DictWriter(cf, fieldnames=results[0].keys())489dw.writeheader()490for r in results:491dw.writerow(r)492cf.flush()493
494
495
496if __name__ == '__main__':497main()498