pytorch-image-models

Форк
0
/
onnx_validate.py 
110 строк · 4.4 Кб
1
""" ONNX-runtime validation script
2

3
This script was created to verify accuracy and performance of exported ONNX
4
models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
5
pipeline for a fair comparison against the originals.
6

7
Copyright 2020 Ross Wightman
8
"""
9
import argparse
10
import numpy as np
11
import onnxruntime
12
from timm.data import create_loader, resolve_data_config, create_dataset
13
from timm.utils import AverageMeter
14
import time
15

16
parser = argparse.ArgumentParser(description='ONNX Validation')
17
parser.add_argument('data', metavar='DIR',
18
                    help='path to dataset')
19
parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
20
                    help='path to onnx model/weights file')
21
parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
22
                    help='path to output optimized onnx graph')
23
parser.add_argument('--profile', action='store_true', default=False,
24
                    help='Enable profiler output.')
25
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
26
                    help='number of data loading workers (default: 2)')
27
parser.add_argument('-b', '--batch-size', default=256, type=int,
28
                    metavar='N', help='mini-batch size (default: 256)')
29
parser.add_argument('--img-size', default=None, type=int,
30
                    metavar='N', help='Input image dimension, uses model default if empty')
31
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
32
                    help='Override mean pixel value of dataset')
33
parser.add_argument('--std', type=float,  nargs='+', default=None, metavar='STD',
34
                    help='Override std deviation of of dataset')
35
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
36
                    help='Override default crop pct of 0.875')
37
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
38
                    help='Image resize interpolation type (overrides model)')
39
parser.add_argument('--print-freq', '-p', default=10, type=int,
40
                    metavar='N', help='print frequency (default: 10)')
41

42

43
def main():
44
    args = parser.parse_args()
45
    args.gpu_id = 0
46

47
    # Set graph optimization level
48
    sess_options = onnxruntime.SessionOptions()
49
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
50
    if args.profile:
51
        sess_options.enable_profiling = True
52
    if args.onnx_output_opt:
53
        sess_options.optimized_model_filepath = args.onnx_output_opt
54

55
    session = onnxruntime.InferenceSession(args.onnx_input, sess_options)
56

57
    data_config = resolve_data_config(vars(args))
58
    loader = create_loader(
59
        create_dataset('', args.data),
60
        input_size=data_config['input_size'],
61
        batch_size=args.batch_size,
62
        use_prefetcher=False,
63
        interpolation=data_config['interpolation'],
64
        mean=data_config['mean'],
65
        std=data_config['std'],
66
        num_workers=args.workers,
67
        crop_pct=data_config['crop_pct']
68
    )
69

70
    input_name = session.get_inputs()[0].name
71

72
    batch_time = AverageMeter()
73
    top1 = AverageMeter()
74
    top5 = AverageMeter()
75
    end = time.time()
76
    for i, (input, target) in enumerate(loader):
77
        # run the net and return prediction
78
        output = session.run([], {input_name: input.data.numpy()})
79
        output = output[0]
80

81
        # measure accuracy and record loss
82
        prec1, prec5 = accuracy_np(output, target.numpy())
83
        top1.update(prec1.item(), input.size(0))
84
        top5.update(prec5.item(), input.size(0))
85

86
        # measure elapsed time
87
        batch_time.update(time.time() - end)
88
        end = time.time()
89

90
        if i % args.print_freq == 0:
91
            print(
92
                f'Test: [{i}/{len(loader)}]\t'
93
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {input.size(0) / batch_time.avg:.3f}/s, '
94
                f'{100 * batch_time.avg / input.size(0):.3f} ms/sample) \t'
95
                f'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
96
                f'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'
97
            )
98

99
    print(f' * Prec@1 {top1.avg:.3f} ({100-top1.avg:.3f}) Prec@5 {top5.avg:.3f} ({100.-top5.avg:.3f})')
100

101

102
def accuracy_np(output, target):
103
    max_indices = np.argsort(output, axis=1)[:, ::-1]
104
    top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
105
    top1 = 100 * np.equal(max_indices[:, 0], target).mean()
106
    return top1, top5
107

108

109
if __name__ == '__main__':
110
    main()
111

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.