pytorch-image-models
/
onnx_validate.py
110 строк · 4.4 Кб
""" ONNX-runtime validation script

This script was created to verify accuracy and performance of exported ONNX
models running with onnxruntime. It uses the PyTorch dataloader/processing
pipeline for a fair comparison against the original models.

Copyright 2020 Ross Wightman
"""
9import argparse10import numpy as np11import onnxruntime12from timm.data import create_loader, resolve_data_config, create_dataset13from timm.utils import AverageMeter14import time15
# Command-line interface for the ONNX validation script.
# main() forwards vars(args) to timm's resolve_data_config, so the
# preprocessing options (--img-size, --mean, --std, --crop-pct,
# --interpolation) act as overrides of the resolved data config.
parser = argparse.ArgumentParser(description='ONNX Validation')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
                    help='path to onnx model/weights file')
parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
                    help='path to output optimized onnx graph')
parser.add_argument('--profile', action='store_true', default=False,
                    help='Enable profiler output.')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
                    help='Override default crop pct of 0.875')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
43def main():44args = parser.parse_args()45args.gpu_id = 046
47# Set graph optimization level48sess_options = onnxruntime.SessionOptions()49sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL50if args.profile:51sess_options.enable_profiling = True52if args.onnx_output_opt:53sess_options.optimized_model_filepath = args.onnx_output_opt54
55session = onnxruntime.InferenceSession(args.onnx_input, sess_options)56
57data_config = resolve_data_config(vars(args))58loader = create_loader(59create_dataset('', args.data),60input_size=data_config['input_size'],61batch_size=args.batch_size,62use_prefetcher=False,63interpolation=data_config['interpolation'],64mean=data_config['mean'],65std=data_config['std'],66num_workers=args.workers,67crop_pct=data_config['crop_pct']68)69
70input_name = session.get_inputs()[0].name71
72batch_time = AverageMeter()73top1 = AverageMeter()74top5 = AverageMeter()75end = time.time()76for i, (input, target) in enumerate(loader):77# run the net and return prediction78output = session.run([], {input_name: input.data.numpy()})79output = output[0]80
81# measure accuracy and record loss82prec1, prec5 = accuracy_np(output, target.numpy())83top1.update(prec1.item(), input.size(0))84top5.update(prec5.item(), input.size(0))85
86# measure elapsed time87batch_time.update(time.time() - end)88end = time.time()89
90if i % args.print_freq == 0:91print(92f'Test: [{i}/{len(loader)}]\t'93f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {input.size(0) / batch_time.avg:.3f}/s, '94f'{100 * batch_time.avg / input.size(0):.3f} ms/sample) \t'95f'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'96f'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'97)98
99print(f' * Prec@1 {top1.avg:.3f} ({100-top1.avg:.3f}) Prec@5 {top5.avg:.3f} ({100.-top5.avg:.3f})')100
101
def accuracy_np(output, target):
    """Return top-1 and top-5 accuracy, in percent, for one batch.

    Args:
        output: 2-D array of per-class scores, one row per sample.
        target: array of class indices, one per row of ``output``
            (presumably 1-D of shape (batch,) — confirm against the loader).

    Returns:
        ``(top1, top5)`` as NumPy float scalars in [0, 100]. A sample counts
        toward top-5 if its label is among its five highest-scoring classes.
    """
    # Per row: class indices ordered from highest to lowest score.
    ranked = np.flip(np.argsort(output, axis=1), axis=1)
    labels_column = np.expand_dims(target, axis=1)

    top5_hits = ranked[:, :5] == labels_column
    top1_hits = ranked[:, 0] == target

    top1 = top1_hits.mean() * 100
    top5 = top5_hits.sum(axis=1).mean() * 100
    return top1, top5
if __name__ == "__main__":
    # Run validation only when executed as a script, not when imported.
    main()