onnxruntime

Форк
0
/
pytorch_export_helpers.py 
131 строка · 5.7 Кб
1
# Copyright (c) Microsoft Corporation. All rights reserved.
2
# Licensed under the MIT License.
3

4
import inspect
5
from collections import abc
6

7
import torch
8

9

10
def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs):
    """Compute the flattened input names for a call to a module's ``forward``.

    Walks the ``forward`` signature in declaration order and assigns a name to every
    non-None leaf value supplied for it:

    * Sequences (other than ``str``) and mappings are expanded recursively, appending the
      element index or key to the parent name (``x`` -> ``x_0``, ``x_1``; ``d`` -> ``d_<key>``).
    * Values consumed by ``*args`` are named ``<varargs name>_<i>``.
    * Remaining ``**kwargs`` entries (not bound to a declared parameter) are named after their key.
    * ``None`` values are dropped and produce no name.

    :param all_input_parameters: ``inspect.Parameter`` objects of ``forward``, in declaration order.
    :param inputs: Positional argument values.
    :param kwargs: Keyword argument values.
    :return: List of input names, in signature order.
    """
    # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433

    input_names = []

    def _add_input(name, value):
        """Append the name(s) for ``value`` to ``input_names``, expanding containers recursively."""
        if value is None:
            # None inputs are dropped entirely.
            return

        # A str is a Sequence whose items are length-1 strs (themselves Sequences), so
        # expanding it would recurse forever; treat it as a single leaf value instead.
        if isinstance(value, abc.Sequence) and not isinstance(value, str):
            # The sequence itself is not a valid input; each element becomes its own input.
            for i, item in enumerate(value):
                _add_input(f"{name}_{i}", item)
            return

        if isinstance(value, abc.Mapping):
            # The mapping itself is not a valid input; each value becomes its own input.
            for key, item in value.items():
                _add_input(f"{name}_{key}", item)
            return

        input_names.append(name)

    # Names of explicitly declared parameters. Their keyword values are handled by name in
    # the loop below, so the **kwargs branch must skip them. Checking ``input_names`` is not
    # enough: a list/dict value is recorded under expanded names such as ``x_0``.
    named_parameters = set()
    var_positional_count = 0

    for input_idx, input_parameter in enumerate(all_input_parameters):
        kind = input_parameter.kind
        if kind == inspect.Parameter.VAR_POSITIONAL:
            # *args swallows every remaining positional value.
            for args_i in range(input_idx, len(inputs)):
                _add_input(f"{input_parameter.name}_{var_positional_count}", inputs[args_i])
                var_positional_count += 1
        elif kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            name = input_parameter.name
            named_parameters.add(name)
            # Only KEYWORD_ONLY parameters can follow *args. Shifting their index by the
            # number of values *args consumed places it past the end of `inputs`, so those
            # parameters are always looked up in kwargs.
            position = input_idx + var_positional_count
            value = None
            if position < len(inputs) and inputs[position] is not None:
                value = inputs[position]
            elif kwargs.get(name) is not None:
                value = kwargs[name]
            _add_input(name, value)
        elif kind == inspect.Parameter.VAR_KEYWORD:
            # **kwargs is always the last parameter of forward().
            for name, value in kwargs.items():
                if name not in named_parameters:
                    _add_input(name, value)

    return input_names
86

87

88
def _flatten_module_input(names, args, kwargs):
    """Pack positional and keyword inputs into one tuple suitable for ``torch.onnx.export``.

    Plain ``int``/``bool``/``float`` values (exact types only) are wrapped with
    ``torch.tensor``; all other values are passed through unchanged. Keyword values are
    appended in the order their names appear in ``names``; keyword entries whose name is
    not listed are omitted.

    :param names: Ordered input names, used to select and order keyword values.
    :param args: Positional input values.
    :param kwargs: Keyword input values.
    :return: Tuple of positional values followed by the selected keyword values, with a
        trailing empty dict when ``kwargs`` is empty.
    """
    # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110

    def _as_export_value(obj):
        # Exact type check on purpose: subclasses of the primitive types pass through untouched.
        if type(obj) in (int, bool, float):
            return torch.tensor(obj)
        return obj

    flattened = list(map(_as_export_value, args))
    flattened.extend(_as_export_value(kwargs[key]) for key in names if key in kwargs)

    # With no kwargs, a trailing empty dict stops the exporter from mistaking a final
    # dictionary input for the keyword-arguments dict.
    if not kwargs:
        flattened.append({})

    return tuple(flattened)
109

110

111
def infer_input_info(module: torch.nn.Module, *inputs, **kwargs):
    """
    Infer the input names and their order from the arguments used to execute a PyTorch module, for use when
    exporting the model via torch.onnx.export.
    Assumes model is on CPU. Use `module.to(torch.device('cpu'))` if it isn't.

    Example usage:
    input_names, inputs_as_tuple = infer_input_info(module, ...)
    torch.onnx.export(module, inputs_as_tuple, 'model.onnx', input_names=input_names, output_names=[...], ...)

    :param module: Module whose `forward` signature is inspected to name the inputs.
    :param inputs: Positional inputs
    :param kwargs: Keyword argument inputs
    :return: Tuple of ordered input names and input values. These can be used directly with torch.onnx.export as the
            `input_names` and `args` arguments. None inputs are dropped from the names, list/dict inputs are
            expanded into one name per element, plain int/bool/float values are converted to tensors, and an empty
            dict is appended to the values when no keyword arguments are given.
    """
    forward_parameters = inspect.signature(module.forward).parameters.values()
    input_names = _parse_inputs_for_onnx_export(forward_parameters, inputs, kwargs)
    inputs_as_tuple = _flatten_module_input(input_names, inputs, kwargs)

    return input_names, inputs_as_tuple
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.