pytorch-image-models

Форк
0
/
onnx_export.py 
102 строки · 4.6 Кб
1
""" ONNX export script
2

3
Export PyTorch models as ONNX graphs.
4

5
This export script originally started as an adaptation of code snippets found at
6
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
7

8
The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph
9
for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible
10
with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback
11
flags are currently required.
12

13
Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for
14
caffe2 compatibility, but they produce a model that does not run as fast on ONNX Runtime.
15

16
Most new releases of PyTorch and ONNX cause some sort of breakage in the export/usage of ONNX models.
17
Please do your research and search the ONNX and PyTorch issue trackers before asking me. Thanks.
18

19
Copyright 2020 Ross Wightman
20
"""
21
import argparse
22

23
import timm
24
from timm.utils.model import reparameterize_model
25
from timm.utils.onnx import onnx_export
26

27
# Command-line interface for the ONNX export script.
parser = argparse.ArgumentParser(description='PyTorch ONNX Export')
parser.add_argument('output', metavar='ONNX_FILE',
                    help='output model filename')
parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100',
                    help='model architecture (default: mobilenetv3_large_100)')
# default=None: no opset is forced here; the value is passed through to the exporter as-is.
parser.add_argument('--opset', type=int, default=None,
                    help='ONNX opset to use (default: None, use the exporter default)')
parser.add_argument('--keep-init', action='store_true', default=False,
                    help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
parser.add_argument('--aten-fallback', action='store_true', default=False,
                    help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
parser.add_argument('--dynamic-size', action='store_true', default=False,
                    help='Export model with dynamic width/height. Not recommended for "tf" models with SAME padding.')
parser.add_argument('--check-forward', action='store_true', default=False,
                    help='Do a full check of torch vs onnx forward after export.')
parser.add_argument('-b', '--batch-size', default=1, type=int,
                    metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
# NOTE: --mean / --std are accepted for CLI compatibility but are not currently read by main().
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--num-classes', type=int, default=1000,
                    help='Number of classes in dataset')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to checkpoint (default: none)')
parser.add_argument('--reparam', default=False, action='store_true',
                    help='Reparameterize model')
parser.add_argument('--training', default=False, action='store_true',
                    help='Export in training mode (default is eval)')
parser.add_argument('--verbose', default=False, action='store_true',
                    help='Extra stdout output')
parser.add_argument('--dynamo', default=False, action='store_true',
                    help='Use torch dynamo export.')


def resolve_input_size(img_size, in_chans=3):
    """Return the ``(C, H, W)`` export input size for a square ``--img-size``.

    Args:
        img_size: Square input height/width from ``--img-size``, or ``None``.
        in_chans: Number of input channels (``main`` creates models with 3).

    Returns:
        ``(in_chans, img_size, img_size)``, or ``None`` when ``img_size`` is
        ``None``. Returning ``None`` instead of ``(in_chans, None, None)`` keeps
        the ``--img-size`` promise of "uses model default if empty".
        NOTE(review): assumes ``onnx_export`` treats ``input_size=None`` as
        "use the model's default input size" -- confirm in ``timm.utils.onnx``.
    """
    if img_size is None:
        return None
    return in_chans, img_size, img_size


def main():
    """Script entry point: build a timm model from CLI args and export it to ONNX.

    Pretrained weights are requested unless ``--checkpoint`` is given. The
    model is optionally reparameterized (``--reparam``). Then the output path
    and all export options are passed to ``onnx_export``.
    """
    args = parser.parse_args()

    # Request pretrained weights only when no explicit checkpoint path is given.
    args.pretrained = not args.checkpoint

    print("==> Creating PyTorch {} model".format(args.model))
    # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
    # for models using SAME padding
    model = timm.create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint,
        exportable=True,
    )

    if args.reparam:
        model = reparameterize_model(model)

    onnx_export(
        model,
        args.output,
        opset=args.opset,
        dynamic_size=args.dynamic_size,
        aten_fallback=args.aten_fallback,
        keep_initializers=args.keep_init,
        check_forward=args.check_forward,
        training=args.training,
        verbose=args.verbose,
        use_dynamo=args.dynamo,
        # None (not (3, None, None)) when --img-size is omitted.
        input_size=resolve_input_size(args.img_size, in_chans=3),
        batch_size=args.batch_size,
    )


if __name__ == '__main__':
    # Run the export CLI only when this file is executed as a script, not on import.
    main()
103

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.