onnxruntime

Форк
0
/
gen_contrib_doc.py 
409 строк · 13.8 Кб
1
#!/usr/bin/env python
2

3
# This file is copied and adapted from https://github.com/onnx/onnx repository.
4
# There was no copyright statement on the file at the time of copying.
5

6
import argparse
7
import os
8
import pathlib
9
import sys
10
from collections import defaultdict
11
from typing import Any, Dict, List, Optional, Sequence, Set, Text, Tuple  # noqa: F401
12

13
import numpy as np  # type: ignore
14
from onnx import AttributeProto, FunctionProto  # noqa: F401
15

16
import onnxruntime.capi.onnxruntime_pybind11_state as rtpy
17
from onnxruntime.capi.onnxruntime_pybind11_state import schemadef  # noqa: F401
18
from onnxruntime.capi.onnxruntime_pybind11_state.schemadef import OpSchema
19

20
# Documentation is generated for the ONNX-ML flavour unless the ONNX_ML
# environment variable is explicitly set to "0".
ONNX_ML = os.getenv("ONNX_ML") != "0"
ONNX_DOMAIN = "onnx"
ONNX_ML_DOMAIN = "onnx-ml"

# Suffix of the changelog markdown files that version links point to.
ext = "-ml.md" if ONNX_ML else ".md"
28

29

30
def display_number(v):  # type: (int) -> Text
    """Render a formal-parameter count as text.

    Counts that ``OpSchema.is_infinite`` reports as unbounded are shown as "∞";
    every other count is rendered with ``str``.
    """
    if OpSchema.is_infinite(v):
        return "∞"
    return str(v)
34

35

36
def should_render_domain(domain, domain_filter):  # type: (Text, Optional[Sequence[Text]]) -> bool
    """Decide whether operators of ``domain`` belong in the contrib documentation.

    The standard ONNX domains (``ONNX_DOMAIN``, ``ONNX_ML_DOMAIN``, the empty
    default domain and "ai.onnx.ml") are always excluded. When ``domain_filter``
    is non-empty, only domains listed in it are rendered; an empty or ``None``
    filter accepts every remaining domain.
    """
    if domain in (ONNX_DOMAIN, ONNX_ML_DOMAIN, "", "ai.onnx.ml"):
        return False
    return not domain_filter or domain in domain_filter
44

45

46
def format_name_with_domain(domain, schema_name):  # type: (Text, Text) -> Text
    """Qualify ``schema_name`` as ``"<domain>.<schema_name>"``.

    A falsy ``domain`` (empty string or ``None``) returns the bare name.
    """
    if domain:
        return f"{domain}.{schema_name}"
    return schema_name
51

52

53
def format_name_with_version(schema_name, version):  # type: (Text, int) -> Text
    """Return ``"<schema_name>-<version>"``, e.g. ``"com.microsoft.Foo-1"``.

    ``version`` is normally an ``OpSchema.since_version`` integer, but any value
    with a string form is accepted.
    """
    return f"{schema_name}-{version}"
55

56

57
def display_attr_type(v):  # type: (OpSchema.AttrType) -> Text
    """Convert an ``OpSchema.AttrType`` enum member into a readable type name.

    The text after the last "." of ``str(v)`` is lower-cased. Names ending in
    "s" (presumably the list-valued attribute types, e.g. ``INTS``) get a
    "list of " prefix.
    """
    assert isinstance(v, OpSchema.AttrType)
    s = str(v)
    s = s[s.rfind(".") + 1 :].lower()
    if s[-1] == "s":
        s = "list of " + s
    return s
64

65

66
def display_domain(domain):  # type: (Text) -> Text
    """Describe an operator-set domain in prose for the "Version" section.

    A falsy ``domain`` refers to the default ONNX operator set.
    """
    if domain:
        return f"the '{domain}' operator set"
    return "the default ONNX operator set"
71

72

73
def display_domain_short(domain):  # type: (Text) -> Text
    """Return the domain name used in headings; a falsy domain means the default ONNX set."""
    if domain:
        return domain
    return "ai.onnx (default)"
78

79

80
def display_version_link(name, version):  # type: (Text, int) -> Text
    """Return an HTML link to the ``<name>-<version>`` anchor in the operator changelog.

    The changelog file name is ``"Changelog" + ext``, where ``ext`` depends on the
    ONNX_ML build flavour.
    """
    changelog_md = "Changelog" + ext
    name_with_ver = f"{name}-{version}"
    return f'<a href="{changelog_md}#{name_with_ver}">{name_with_ver}</a>'
84

85

86
def display_function_version_link(name, version):  # type: (Text, int) -> Text
    """Return an HTML link to the ``<name>-<version>`` anchor in the functions changelog.

    The changelog file name is ``"FunctionsChangelog" + ext``, where ``ext``
    depends on the ONNX_ML build flavour.
    """
    changelog_md = "FunctionsChangelog" + ext
    name_with_ver = f"{name}-{version}"
    return f'<a href="{changelog_md}#{name_with_ver}">{name_with_ver}</a>'
90

91

92
def get_attribute_value(attr):  # type: (AttributeProto) -> Any
    """Extract the Python value stored in an ``AttributeProto``.

    The singular fields ``f``, ``i``, ``s``, ``t``, ``g`` are checked first (via
    ``HasField``), in that order. Then the first non-empty repeated field among
    ``floats``, ``ints``, ``strings``, ``tensors``, ``graphs`` is returned as a
    list.

    Raises:
        ValueError: if none of those fields is populated.
    """
    for scalar_field in ("f", "i", "s", "t", "g"):
        if attr.HasField(scalar_field):
            return getattr(attr, scalar_field)
    for repeated_field in ("floats", "ints", "strings", "tensors", "graphs"):
        values = getattr(attr, repeated_field)
        if len(values):
            return list(values)
    raise ValueError(f"Unsupported ONNX attribute: {attr}")
115

116

117
def _format_attribute_default(value):  # type: (Any) -> Text
    """Render one attribute default value as text.

    Floats are rounded to 5 decimal places with ``np.round``, ``bytes``/``bytearray``
    are decoded as UTF-8, and everything is finally passed through ``str``.
    """
    if isinstance(value, float):
        value = np.round(value, 5)
    if isinstance(value, (bytes, bytearray)):
        value = value.decode("utf-8")
    return str(value)


def _display_attributes(attribs):  # type: (Dict[Text, Any]) -> Text
    """Render the "Attributes" section for a schema's attribute map ("" when it is empty)."""
    if not attribs:
        return ""
    s = "\n#### Attributes\n\n"
    s += "<dl>\n"
    for _, attr in sorted(attribs.items()):
        # opt holds either the "required" marker or the rendered default value.
        opt = ""
        if attr.required:
            opt = "required"
        elif hasattr(attr, "default_value") and attr.default_value.name:
            default_value = get_attribute_value(attr.default_value)
            if isinstance(default_value, list):
                default_value = [_format_attribute_default(val) for val in default_value]
            else:
                default_value = _format_attribute_default(default_value)
            opt = f"default is {default_value}"
        suffix = f" ({opt})" if opt else ""
        s += f"<dt><tt>{attr.name}</tt> : {display_attr_type(attr.type)}{suffix}</dt>\n"
        s += f"<dd>{attr.description}</dd>\n"
    s += "</dl>\n"
    return s


def _display_formal_parameters(params):  # type: (Sequence[Any]) -> Text
    """Render formal parameters (schema inputs or outputs) as an HTML ``<dl>``.

    Returns "" when there are no parameters, so no unbalanced ``</dl>`` is emitted.
    """
    if not params:
        return ""
    s = "<dl>\n"
    for param in params:
        option_str = ""
        if OpSchema.FormalParameterOption.Optional == param.option:
            option_str = " (optional)"
        elif OpSchema.FormalParameterOption.Variadic == param.option:
            option_str = " (variadic)" if param.isHomogeneous else " (variadic, heterogeneous)"
        s += f"<dt><tt>{param.name}</tt>{option_str} : {param.typeStr}</dt>\n"
        s += f"<dd>{param.description}</dd>\n"
    s += "</dl>\n"
    return s


def _display_type_constraints(typecons):  # type: (Sequence[Any]) -> Text
    """Render type constraints as an HTML ``<dl>`` ("" when there are none)."""
    if not typecons:
        return ""
    s = "<dl>\n"
    for type_constraint in typecons:
        allowed_type_str = ", ".join(type_constraint.allowed_type_strs)
        s += f"<dt><tt>{type_constraint.type_param_str}</tt> : {allowed_type_str}</dt>\n"
        s += f"<dd>{type_constraint.description}</dd>\n"
    s += "</dl>\n"
    return s


def display_schema(schema, versions):  # type: (OpSchema, Sequence[OpSchema]) -> Text
    """Render the markdown/HTML documentation body for one operator schema.

    Args:
        schema: the schema to document.
        versions: all versions of the operator ordered by ``since_version``; every
            entry except the last is listed under "Other versions".

    Returns:
        The rendered text. For deprecated schemas only the doc string and the
        version section are emitted.
    """
    s = ""

    # doc
    schemadoc = schema.doc
    if schemadoc:
        s += "\n"
        s += "\n".join("  " + line for line in schemadoc.lstrip().splitlines())
        s += "\n"

    # since version
    s += "\n#### Version\n"
    if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
        s += "\nNo versioning maintained for experimental ops."
    else:
        status = "deprecated" if schema.deprecated else "available"
        s += f"\nThis version of the operator has been {status} since version {schema.since_version}"
        s += f" of {display_domain(schema.domain)}.\n"
        if len(versions) > 1:
            # TODO: link to the Changelog.md
            other_versions = ", ".join(
                format_name_with_version(format_name_with_domain(v.domain, v.name), v.since_version)
                for v in versions[:-1]
            )
            s += f"\nOther versions of this operator: {other_versions}\n"

    # If this schema is deprecated, don't display any of the following sections
    if schema.deprecated:
        return s

    # attributes
    s += _display_attributes(schema.attributes)

    # inputs
    s += "\n#### Inputs"
    if schema.min_input != schema.max_input:
        s += f" ({display_number(schema.min_input)} - {display_number(schema.max_input)})"
    s += "\n\n"
    s += _display_formal_parameters(schema.inputs)

    # outputs
    s += "\n#### Outputs"
    if schema.min_output != schema.max_output:
        s += f" ({display_number(schema.min_output)} - {display_number(schema.max_output)})"
    s += "\n\n"
    s += _display_formal_parameters(schema.outputs)

    # type constraints
    s += "\n#### Type Constraints"
    s += "\n\n"
    s += _display_type_constraints(schema.type_constraints)

    return s
246

247

248
def display_function(function, versions, domain=ONNX_DOMAIN):  # type: (FunctionProto, List[int], Text) -> Text
    """Render the markdown/HTML documentation body for one ONNX function.

    Args:
        function: the function proto to document.
        versions: the ``since_version`` numbers of all versions of this function;
            every one other than ``function.since_version`` is linked under
            "Other versions".
        domain: domain used to qualify the function name in links and text.

    Returns:
        The rendered text (doc string, version, inputs, outputs, attributes).
    """
    s = ""

    # Qualify names with the domain actually passed in, not a hard-coded
    # ONNX_ML_DOMAIN. display_domain() gets the bare domain so no trailing "."
    # leaks into the sentence.
    domain_prefix = f"{domain}." if domain else ""

    # doc
    if function.doc_string:
        s += "\n"
        s += "\n".join("  " + line for line in function.doc_string.lstrip().splitlines())
        s += "\n"

    # since version
    s += "\n#### Version\n"
    s += f"\nThis version of the function has been available since version {function.since_version}"
    s += f" of {display_domain(domain)}.\n"
    if len(versions) > 1:
        other_versions = ", ".join(
            display_function_version_link(domain_prefix + function.name, v)
            for v in versions
            if v != function.since_version
        )
        s += f"\nOther versions of this function: {other_versions}\n"

    # inputs
    s += "\n#### Inputs"
    s += "\n\n"
    if function.input:
        s += "<dl>\n"
        for input_name in function.input:
            s += f"<dt>{input_name}; </dt>\n"
        s += "<br/></dl>\n"

    # outputs
    s += "\n#### Outputs"
    s += "\n\n"
    if function.output:
        s += "<dl>\n"
        for output_name in function.output:
            s += f"<dt>{output_name}; </dt>\n"
        s += "<br/></dl>\n"

    # attributes
    if function.attribute:
        s += "\n#### Attributes\n\n"
        s += "<dl>\n"
        for attr in function.attribute:
            s += f"<dt>{attr};<br/></dt>\n"
        s += "</dl>\n"

    return s
302

303

304
def support_level_str(level):  # type: (OpSchema.SupportType) -> Text
    """Return the "experimental" badge prefix for experimental schemas, else ""."""
    return "<sub>experimental</sub> " if level == OpSchema.SupportType.EXPERIMENTAL else ""
306

307

308
# def function_status_str(status=OperatorStatus.Value("EXPERIMENTAL")):  # type: ignore
309
#     return \
310
#         "<sub>experimental</sub> " if status == OperatorStatus.Value('EXPERIMENTAL') else ""  # type: ignore
311

312

313
def _collect_operator_schemas(domain_filter):
    # type: (Optional[Sequence[Text]]) -> List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
    """Group registered schemas as [(domain, [(support_level, [(name, latest schema, all versions)])])]."""
    # domain -> support level -> name -> [schema]
    index = defaultdict(
        lambda: defaultdict(lambda: defaultdict(list))
    )  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
    for schema in rtpy.get_all_operator_schema():
        index[schema.domain][int(schema.support_level)][schema.name].append(schema)

    operator_schemas = []
    # NOTE(review): de-duplication is keyed on the bare op name, so an op whose
    # name was already emitted for an earlier (alphabetically sorted) domain is
    # skipped. Confirm this cross-domain behaviour is intended.
    existing_ops = set()  # type: Set[Text]
    for domain, supportmap in sorted(index.items()):
        if not should_render_domain(domain, domain_filter):
            continue

        processed_supportmap = []
        for support, namemap in sorted(supportmap.items()):
            processed_namemap = []
            for name, unsorted_versions in sorted(namemap.items()):
                versions = sorted(unsorted_versions, key=lambda s: s.since_version)
                schema = versions[-1]
                if schema.name in existing_ops:
                    continue
                existing_ops.add(schema.name)
                processed_namemap.append((name, schema, versions))
            processed_supportmap.append((support, processed_namemap))
        operator_schemas.append((domain, processed_supportmap))
    return operator_schemas


def _write_table_of_contents(fout, operator_schemas):  # type: (Any, List[Any]) -> None
    """Write a bulleted per-domain list linking to each operator's section anchor."""
    for domain, supportmap in operator_schemas:
        fout.write(f"* {display_domain_short(domain)}\n")
        for _, namemap in supportmap:
            for name, schema, _ in namemap:
                qualified_name = format_name_with_domain(domain, name)
                fout.write(
                    f'  * {support_level_str(schema.support_level)}<a href="#{qualified_name}">{qualified_name}</a>\n'
                )


def _write_operator_sections(fout, operator_schemas):  # type: (Any, List[Any]) -> None
    """Write one heading per domain followed by the full documentation of each operator."""
    for domain, supportmap in operator_schemas:
        fout.write(f"## {display_domain_short(domain)}\n")
        for _, namemap in supportmap:
            for op_type, schema, versions in namemap:
                # Heading carries two anchors: the exact name and a lower-cased variant.
                s = (
                    '### {}<a name="{}"></a><a name="{}">**{}**'
                    + (" (deprecated)" if schema.deprecated else "")
                    + "</a>\n"
                ).format(
                    support_level_str(schema.support_level),
                    format_name_with_domain(domain, op_type),
                    format_name_with_domain(domain, op_type.lower()),
                    format_name_with_domain(domain, op_type),
                )
                s += display_schema(schema, versions)
                s += "\n\n"
                fout.write(s)


def main(output_path: str, domain_filter: Optional[Sequence[str]]):
    """Generate the contrib operator markdown document.

    Args:
        output_path: destination file. It is overwritten as UTF-8 with no newline
            translation.
        domain_filter: if non-empty, only these domains are rendered; ``None`` or
            empty renders every non-standard domain.
    """
    # Gather schemas before opening the output, so a failure here does not
    # leave a truncated file behind.
    operator_schemas = _collect_operator_schemas(domain_filter)

    with open(output_path, "w", newline="", encoding="utf-8") as fout:
        fout.write("## Contrib Operator Schemas\n")
        fout.write(
            "*This file is automatically generated from the registered contrib operator schemas by "
            "[this script](https://github.com/microsoft/onnxruntime/blob/main/tools/python/gen_contrib_doc.py).\n"
            "Do not modify directly.*\n"
        )
        fout.write("\n")

        _write_table_of_contents(fout, operator_schemas)
        fout.write("\n")

        _write_operator_sections(fout, operator_schemas)
392
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ONNX Runtime Contrib Operator Documentation Generator")
    parser.add_argument(
        "--domains",
        nargs="+",
        help="Filter to specified domains. e.g. `--domains com.microsoft com.microsoft.nchwc`",
    )
    # The output path is mandatory. A `default=` combined with `required=True`
    # can never take effect, so none is given.
    parser.add_argument(
        "--output_path",
        help="output markdown file path",
        type=pathlib.Path,
        required=True,
    )
    args = parser.parse_args()
    output_path = args.output_path.resolve()

    main(output_path, args.domains)
410

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.