# onnxruntime (131 lines · 5.7 KB)
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import inspect
from collections import abc

import torch
def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs):
    """Derive the ordered, flattened ONNX input names from a forward() signature
    plus the actual call arguments.

    Sequence and mapping arguments are expanded recursively into per-element
    inputs named ``<name>_<index>`` / ``<name>_<key>``; ``None`` values are
    dropped; ``*args`` and ``**kwargs`` parameters are unrolled against the
    values actually supplied.

    :param all_input_parameters: ``inspect.Parameter`` objects of the forward method.
    :param inputs: Positional argument values.
    :param kwargs: Keyword argument values.
    :return: Ordered list of flattened input names.
    """
    # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433

    input_names = []

    def _add_input(name, value):
        """Register `value` under `name`; returns the number of non-None leaf inputs added."""
        if value is None:
            # None inputs are dropped entirely.
            return 0

        if isinstance(value, abc.Sequence):
            # A list/tuple is not a valid graph input by itself: expand it so each
            # element becomes an input named with its index appended.
            return sum(_add_input(f"{name}_{idx}", item) for idx, item in enumerate(value))
        if isinstance(value, abc.Mapping):
            # Likewise a dict is expanded so each value becomes an input named
            # with its key appended.
            return sum(_add_input(f"{name}_{key}", item) for key, item in value.items())

        # Leaf input: record the name irrespective of whether it ends up being
        # part of the onnx graph.
        input_names.append(name)
        return 1

    var_positional_idx = 0
    # Retained from upstream for parity; not consumed by this extracted version.
    num_expanded_non_none_positional_inputs = 0

    for input_idx, input_parameter in enumerate(all_input_parameters):
        kind = input_parameter.kind
        if kind == inspect.Parameter.VAR_POSITIONAL:
            # *args carries every remaining positional value from the original
            # forward call; each becomes its own named input.
            for args_i in range(input_idx, len(inputs)):
                flat_name = f"{input_parameter.name}_{var_positional_idx}"
                num_expanded_non_none_positional_inputs += _add_input(flat_name, inputs[args_i])
                var_positional_idx += 1
        elif kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            # Plain named parameter: prefer the positional value, otherwise fall
            # back to a keyword value of the same name.
            name = input_parameter.name
            inp = None
            is_positional = True
            # Shift the lookup index past any values already consumed by *args above.
            input_idx += var_positional_idx  # noqa: PLW2901
            if input_idx < len(inputs) and inputs[input_idx] is not None:
                inp = inputs[input_idx]
            elif name in kwargs and kwargs[name] is not None:
                inp = kwargs[name]
                is_positional = False
            expanded = _add_input(name, inp)
            if is_positional:
                num_expanded_non_none_positional_inputs += expanded
        elif kind == inspect.Parameter.VAR_KEYWORD:
            # **kwargs is always the last parameter of forward(): pick up any
            # keyword value not already claimed by a named parameter.
            for extra_name, extra_value in kwargs.items():
                if extra_name not in input_names:
                    _add_input(extra_name, extra_value)

    return input_names
87
def _flatten_module_input(names, args, kwargs):
    """Flatten args and kwargs into a single tuple of values, wrapping bare
    int/bool/float scalars in tensors.

    :param names: Flattened input names (used to select and order kwargs).
    :param args: Positional argument values.
    :param kwargs: Keyword argument values.
    :return: Tuple of inputs suitable for torch.onnx.export.
    """
    # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110

    def _as_tensor_if_primitive(value):
        # Exact-type check on purpose: only plain int/bool/float are wrapped,
        # everything else (tensors, strings, subclasses) passes through untouched.
        return torch.tensor(value) if type(value) in (int, bool, float) else value

    flattened = [_as_tensor_if_primitive(arg) for arg in args]
    flattened.extend(_as_tensor_if_primitive(kwargs[name]) for name in names if name in kwargs)

    # if kwargs is empty, append an empty dictionary at the end of the sample inputs
    # to make exporter happy. This is because the exporter is confused with kwargs
    # and dictionary inputs otherwise.
    if not kwargs:
        flattened.append({})

    return tuple(flattened)
110
def infer_input_info(module: torch.nn.Module, *inputs, **kwargs):
    """Infer the input names and flattened input values needed to export a PyTorch
    module via torch.onnx.export.

    Assumes the model is on CPU. Use ``module.to(torch.device('cpu'))`` if it isn't.

    Example usage::

        input_names, inputs_as_tuple = infer_input_info(module, ...)
        torch.onnx.export(module, inputs_as_tuple, 'model.onnx',
                          input_names=input_names, output_names=[...], ...)

    :param module: Module
    :param inputs: Positional inputs
    :param kwargs: Keyword argument inputs
    :return: Tuple of ordered input names and input values. These can be used
        directly with torch.onnx.export as the `input_names` and `inputs` arguments.
    """
    # Names come from forward()'s signature expanded against the actual arguments.
    forward_parameters = inspect.signature(module.forward).parameters.values()
    input_names = _parse_inputs_for_onnx_export(forward_parameters, inputs, kwargs)
    # Flatten args/kwargs (scalars become tensors) into the tuple the exporter expects.
    inputs_as_tuple = _flatten_module_input(input_names, inputs, kwargs)

    return input_names, inputs_as_tuple