onnxruntime
409 строк · 13.8 Кб
1#!/usr/bin/env python
2
3# This file is copied and adapted from https://github.com/onnx/onnx repository.
4# There was no copyright statement on the file at the time of copying.
5
6import argparse
7import os
8import pathlib
9import sys
10from collections import defaultdict
11from typing import Any, Dict, List, Sequence, Set, Text, Tuple # noqa: F401
12
13import numpy as np # type: ignore
14from onnx import AttributeProto, FunctionProto # noqa: F401
15
16import onnxruntime.capi.onnxruntime_pybind11_state as rtpy
17from onnxruntime.capi.onnxruntime_pybind11_state import schemadef # noqa: F401
18from onnxruntime.capi.onnxruntime_pybind11_state.schemadef import OpSchema
19
# Changelog links use the "-ml" flavoured pages unless the ONNX_ML environment variable is
# explicitly "0" (unset, or any other value, keeps ONNX-ML enabled).
ONNX_ML = os.getenv("ONNX_ML") != "0"
ONNX_DOMAIN = "onnx"
ONNX_ML_DOMAIN = "onnx-ml"

# Markdown file suffix appended to changelog page names when building links.
ext = "-ml.md" if ONNX_ML else ".md"
28
29
def display_number(v):  # type: (int) -> Text
    """Render an input/output count bound for the docs, using "∞" when OpSchema reports it as infinite."""
    return "∞" if OpSchema.is_infinite(v) else str(v)
34
35
def should_render_domain(domain, domain_filter=None):  # type: (Text, Sequence[Text] | None) -> bool
    """Decide whether operators from ``domain`` belong in the contrib-operator document.

    The standard ONNX domains (``ONNX_DOMAIN``, ``ONNX_ML_DOMAIN``, the empty default domain and
    "ai.onnx.ml") are never rendered.

    Args:
        domain: the schema's domain string.
        domain_filter: optional collection of domains to keep. ``None`` or an empty collection
            disables filtering.

    Returns:
        True if the domain should be rendered.
    """
    if domain in (ONNX_DOMAIN, ONNX_ML_DOMAIN, "", "ai.onnx.ml"):
        return False
    return not domain_filter or domain in domain_filter
44
45
def format_name_with_domain(domain, schema_name):  # type: (Text, Text) -> Text
    """Qualify ``schema_name`` as "<domain>.<schema_name>"; an empty/falsy domain leaves it bare."""
    return schema_name if not domain else f"{domain}.{schema_name}"
51
52
def format_name_with_version(schema_name, version):  # type: (Text, int) -> Text
    """Return "<schema_name>-<version>", e.g. ("com.microsoft.Attention", 1) -> "com.microsoft.Attention-1".

    ``version`` is the schema's integer ``since_version`` at the call site in ``display_schema``;
    any value with a string form is accepted.
    """
    return f"{schema_name}-{version}"
55
56
def display_attr_type(v):  # type: (OpSchema.AttrType) -> Text
    """Turn an ``OpSchema.AttrType`` enum member into a human-readable type name.

    The enum's string form is reduced to the text after its last "." and lower-cased. Names
    ending in "s" are presented as lists, e.g. a value printing as "AttrType.INTS" becomes
    "list of ints".
    """
    assert isinstance(v, OpSchema.AttrType)
    type_name = str(v).rpartition(".")[2].lower()
    if type_name[-1] != "s":
        return type_name
    return f"list of {type_name}"
64
65
def display_domain(domain):  # type: (Text) -> Text
    """Describe the operator set for ``domain`` in prose; an empty domain is the default ONNX set."""
    if not domain:
        return "the default ONNX operator set"
    return f"the '{domain}' operator set"
71
72
def display_domain_short(domain):  # type: (Text) -> Text
    """Short domain label for headings and the table of contents; empty means "ai.onnx (default)"."""
    return domain or "ai.onnx (default)"
78
79
def display_version_link(name, version):  # type: (Text, int) -> Text
    """Build an HTML link to the "<name>-<version>" anchor of the operator Changelog page.

    The page name is "Changelog" plus the module-level ``ext`` suffix.
    """
    anchor = f"{name}-{version}"
    target = f"Changelog{ext}#{anchor}"
    return f'<a href="{target}">{anchor}</a>'
84
85
def display_function_version_link(name, version):  # type: (Text, int) -> Text
    """Build an HTML link to the "<name>-<version>" anchor of the FunctionsChangelog page.

    The page name is "FunctionsChangelog" plus the module-level ``ext`` suffix.
    """
    anchor = f"{name}-{version}"
    target = f"FunctionsChangelog{ext}#{anchor}"
    return f'<a href="{target}">{anchor}</a>'
90
91
def get_attribute_value(attr):  # type: (AttributeProto) -> Any
    """Extract the Python value stored in an ONNX ``AttributeProto``.

    Scalar fields (``f``, ``i``, ``s``, ``t``, ``g``) are checked first, in that order, via
    ``HasField``. Then the first non-empty repeated field (``floats``, ``ints``, ``strings``,
    ``tensors``, ``graphs``) is returned as a list.

    An attribute whose list value has no elements (e.g. a default of ``[]``) populates no field
    at all. For that case ``attr.type`` is consulted last, and a list type yields ``[]``.

    Raises:
        ValueError: if no supported field is populated and the declared type is not a list type.
    """
    for field in ("f", "i", "s", "t", "g"):
        if attr.HasField(field):
            return getattr(attr, field)
    for field in ("floats", "ints", "strings", "tensors", "graphs"):
        values = getattr(attr, field)
        if len(values):
            return list(values)
    # Empty repeated fields look exactly like unset ones; fall back on the declared type so an
    # empty list value is reported as [] instead of aborting documentation generation.
    list_types = (
        AttributeProto.FLOATS,
        AttributeProto.INTS,
        AttributeProto.STRINGS,
        AttributeProto.TENSORS,
        AttributeProto.GRAPHS,
    )
    if attr.type in list_types:
        return []
    raise ValueError(f"Unsupported ONNX attribute: {attr}")
115
116
def _format_default_value(value):  # type: (Any) -> Text
    """Render one attribute default value as text.

    Python floats are rounded to 5 decimal places with ``np.round``. Bytes are decoded as UTF-8 so
    string defaults print without a ``b'...'`` prefix.
    """
    if isinstance(value, float):
        value = np.round(value, 5)
    if isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:  # noqa: YTT201
        value = value.decode("utf-8")
    return str(value)


def _display_attributes(attribs):  # type: (Dict[Text, Any]) -> Text
    """Render the "Attributes" section, sorted by attribute name; empty string if there are none."""
    if not attribs:
        return ""
    s = "\n#### Attributes\n\n"
    s += "<dl>\n"
    for _, attr in sorted(attribs.items()):
        # `opt` holds either "required" or a description of the default value.
        opt = ""
        if attr.required:
            opt = "required"
        elif hasattr(attr, "default_value") and attr.default_value.name:
            # The default's name is used as a presence check (presumably empty when no default is set).
            default_value = get_attribute_value(attr.default_value)
            if isinstance(default_value, list):
                default_value = [_format_default_value(val) for val in default_value]
            else:
                default_value = _format_default_value(default_value)
            opt = f"default is {default_value}"
        type_str = display_attr_type(attr.type)
        opt_str = f" ({opt})" if opt else ""
        s += f"<dt><tt>{attr.name}</tt> : {type_str}{opt_str}</dt>\n"
        s += f"<dd>{attr.description}</dd>\n"
    s += "</dl>\n"
    return s


def _display_arity_header(title, min_count, max_count):  # type: (Text, int, int) -> Text
    """Render an "Inputs"/"Outputs" heading, appending "(min - max)" when the count can vary."""
    s = f"\n#### {title}"
    if min_count != max_count:
        s += f" ({display_number(min_count)} - {display_number(max_count)})"
    return s + "\n\n"


def _display_formal_parameters(params):  # type: (Sequence[Any]) -> Text
    """Render a schema's inputs or outputs as an HTML definition list; empty string if there are none."""
    if not params:
        return ""
    s = "<dl>\n"
    for param in params:
        option_str = ""
        if OpSchema.FormalParameterOption.Optional == param.option:
            option_str = " (optional)"
        elif OpSchema.FormalParameterOption.Variadic == param.option:
            option_str = " (variadic)" if param.isHomogeneous else " (variadic, heterogeneous)"
        s += f"<dt><tt>{param.name}</tt>{option_str} : {param.typeStr}</dt>\n"
        s += f"<dd>{param.description}</dd>\n"
    s += "</dl>\n"
    return s


def display_schema(schema, versions):  # type: (OpSchema, Sequence[OpSchema]) -> Text
    """Render the Markdown/HTML documentation body for one operator schema.

    Args:
        schema: the schema to document (``main`` passes the latest version).
        versions: every registered version of the operator, sorted by ``since_version``. All
            entries except the last are listed under "Other versions of this operator".

    Returns:
        The doc text and version section. Unless the schema is deprecated, also the attributes,
        inputs, outputs and type-constraint sections.
    """
    s = ""

    # doc, with every line prefixed by a space
    schemadoc = schema.doc
    if schemadoc:
        s += "\n"
        s += "\n".join(" " + line for line in schemadoc.lstrip().splitlines())
        s += "\n"

    # since version
    s += "\n#### Version\n"
    if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
        s += "\nNo versioning maintained for experimental ops."
    else:
        status = "deprecated" if schema.deprecated else "available"
        s += f"\nThis version of the operator has been {status} since version {schema.since_version}"
        s += f" of {display_domain(schema.domain)}.\n"
        if len(versions) > 1:
            # TODO: link to the Changelog.md
            other_versions = ", ".join(
                format_name_with_version(format_name_with_domain(v.domain, v.name), v.since_version)
                for v in versions[:-1]
            )
            s += f"\nOther versions of this operator: {other_versions}\n"

    # If this schema is deprecated, don't display any of the following sections
    if schema.deprecated:
        return s

    s += _display_attributes(schema.attributes)

    s += _display_arity_header("Inputs", schema.min_input, schema.max_input)
    s += _display_formal_parameters(schema.inputs)

    s += _display_arity_header("Outputs", schema.min_output, schema.max_output)
    s += _display_formal_parameters(schema.outputs)

    # type constraints (the heading is emitted even when there are none)
    s += "\n#### Type Constraints"
    s += "\n\n"
    typecons = schema.type_constraints
    if typecons:
        s += "<dl>\n"
        for type_constraint in typecons:
            allowed_type_str = ", ".join(type_constraint.allowed_type_strs)
            s += f"<dt><tt>{type_constraint.type_param_str}</tt> : {allowed_type_str}</dt>\n"
            s += f"<dd>{type_constraint.description}</dd>\n"
        s += "</dl>\n"

    return s
246
247
def display_function(function, versions, domain=ONNX_DOMAIN):  # type: (FunctionProto, List[int], Text) -> Text
    """Render the Markdown/HTML documentation body for an ONNX ``FunctionProto``.

    ``main`` in this script does not call this function.

    Args:
        function: the function to document; its doc_string, since_version, name, input, output
            and attribute fields are read.
        versions: the since_version values known for this function. Every value other than
            ``function.since_version`` is linked under "Other versions of this function".
        domain: the function's domain. It prefixes changelog anchors ("<domain>.<name>") and is
            described via ``display_domain``. An empty domain means the default operator set.

    Returns:
        The rendered documentation text.
    """
    s = ""

    domain_prefix = f"{domain}." if domain else ""

    # doc, with every line prefixed by a space
    if function.doc_string:
        s += "\n"
        s += "\n".join(" " + line for line in function.doc_string.lstrip().splitlines())
        s += "\n"

    # since version
    s += "\n#### Version\n"
    s += f"\nThis version of the function has been available since version {function.since_version}"
    s += f" of {display_domain(domain)}.\n"
    if len(versions) > 1:
        other_versions = ", ".join(
            display_function_version_link(domain_prefix + function.name, v)
            for v in versions
            if v != function.since_version
        )
        s += f"\nOther versions of this function: {other_versions}\n"

    # inputs
    s += "\n#### Inputs"
    s += "\n\n"
    if function.input:
        s += "<dl>\n"
        for input_name in function.input:
            s += f"<dt>{input_name}; </dt>\n"
        s += "<br/></dl>\n"

    # outputs
    s += "\n#### Outputs"
    s += "\n\n"
    if function.output:
        s += "<dl>\n"
        for output_name in function.output:
            s += f"<dt>{output_name}; </dt>\n"
        s += "<br/></dl>\n"

    # attributes
    if function.attribute:
        s += "\n#### Attributes\n\n"
        s += "<dl>\n"
        for attr in function.attribute:
            s += f"<dt>{attr};<br/></dt>\n"
        s += "</dl>\n"

    return s
302
303
def support_level_str(level):  # type: (OpSchema.SupportType) -> Text
    """Return the "experimental" badge (with a trailing space) for experimental ops, else an empty string."""
    if level == OpSchema.SupportType.EXPERIMENTAL:
        return "<sub>experimental</sub> "
    return ""
306
307
308# def function_status_str(status=OperatorStatus.Value("EXPERIMENTAL")): # type: ignore
309# return \
310# "<sub>experimental</sub> " if status == OperatorStatus.Value('EXPERIMENTAL') else "" # type: ignore
311
312
def _collect_operator_schemas(domain_filter):  # type: (Sequence[Text] | None) -> List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
    """Group all registered operator schemas for rendering.

    Returns:
        [(domain, [(support_level, [(op name, latest schema, all version schemas)])])], with
        domains, support levels and op names sorted and version lists sorted by since_version.
        Domains rejected by ``should_render_domain`` are omitted.
    """
    # domain -> support level -> name -> [schema]
    index = defaultdict(
        lambda: defaultdict(lambda: defaultdict(list))
    )  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
    for schema in rtpy.get_all_operator_schema():
        index[schema.domain][int(schema.support_level)][schema.name].append(schema)

    operator_schemas = []  # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
    # Key on (domain, name): operators in different domains may share a name (each gets its own
    # domain-qualified anchor) and must not hide each other. Within a domain, a name listed under
    # several support levels is still rendered only once.
    existing_ops = set()  # type: Set[Tuple[Text, Text]]
    for domain, supportmap in sorted(index.items()):
        if not should_render_domain(domain, domain_filter):
            continue

        processed_supportmap = []
        for support, namemap in sorted(supportmap.items()):
            processed_namemap = []
            for name, unsorted_versions in sorted(namemap.items()):
                versions = sorted(unsorted_versions, key=lambda s: s.since_version)
                schema = versions[-1]
                key = (domain, schema.name)
                if key in existing_ops:
                    continue
                existing_ops.add(key)
                processed_namemap.append((name, schema, versions))
            processed_supportmap.append((support, processed_namemap))
        operator_schemas.append((domain, processed_supportmap))
    return operator_schemas


def main(output_path: "str | os.PathLike[str]", domain_filter: "Sequence[str] | None"):
    """Write the contrib operator Markdown document.

    Args:
        output_path: destination file. It is overwritten and written as UTF-8 with no newline
            translation.
        domain_filter: if non-empty, only these domains are documented (see
            ``should_render_domain``).
    """
    # Collect before opening the file, so a failure here does not leave a truncated document.
    operator_schemas = _collect_operator_schemas(domain_filter)

    with open(output_path, "w", newline="", encoding="utf-8") as fout:
        fout.write("## Contrib Operator Schemas\n")
        fout.write(
            "*This file is automatically generated from the registered contrib operator schemas by "
            "[this script](https://github.com/microsoft/onnxruntime/blob/main/tools/python/gen_contrib_doc.py).\n"
            "Do not modify directly.*\n"
        )
        fout.write("\n")

        # Table of contents
        for domain, supportmap in operator_schemas:
            fout.write(f"* {display_domain_short(domain)}\n")
            for _, namemap in supportmap:
                for name, schema, _ in namemap:
                    qualified = format_name_with_domain(domain, name)
                    badge = support_level_str(schema.support_level)
                    fout.write(f' * {badge}<a href="#{qualified}">{qualified}</a>\n')

        fout.write("\n")

        # One section per domain, one subsection per operator
        for domain, supportmap in operator_schemas:
            fout.write(f"## {display_domain_short(domain)}\n")

            for _, namemap in supportmap:
                for op_type, schema, versions in namemap:
                    qualified = format_name_with_domain(domain, op_type)
                    lower_anchor = format_name_with_domain(domain, op_type.lower())
                    deprecated = " (deprecated)" if schema.deprecated else ""
                    # Two anchors: the exact name (used by the table of contents) and a
                    # lower-cased variant.
                    s = (
                        f"### {support_level_str(schema.support_level)}"
                        f'<a name="{qualified}"></a>'
                        f'<a name="{lower_anchor}">**{qualified}**{deprecated}</a>\n'
                    )
                    s += display_schema(schema, versions)
                    s += "\n\n"
                    fout.write(s)
388
389fout.write(s)
390
391
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ONNX Runtime Contrib Operator Documentation Generator")
    parser.add_argument(
        "--domains",
        nargs="+",
        help="Filter to specified domains. e.g. `--domains com.microsoft com.microsoft.nchwc`",
    )
    # Not `required`: argparse never uses `default` for a required option, so the default below
    # (ContribOperators.md next to this script) only takes effect when the flag is optional.
    # argparse also runs a string default through `type`, so it arrives as a pathlib.Path.
    parser.add_argument(
        "--output_path",
        help="output markdown file path",
        type=pathlib.Path,
        default=os.path.join(os.path.dirname(os.path.realpath(__file__)), "ContribOperators.md"),
    )
    args = parser.parse_args()
    output_path = args.output_path.resolve()

    main(output_path, args.domains)
410