onnxruntime
69 строк · 2.4 Кб
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4import { Graph } from './graph';
5import { OperatorImplementation, OperatorInitialization } from './operators';
6
/**
 * An operator set (opset), identified by its domain and version.
 */
export interface OpSet {
  /** Opset domain. `resolveOperator` treats '' and 'ai.onnx' as the same (default ONNX) domain. */
  domain: string;
  /** Opset version number. */
  version: number;
}
export declare namespace OpSet {
  /**
   * Domain of an opset: an empty string (the default value, which stands for 'ai.onnx'), 'ai.onnx.ml', or
   * 'com.microsoft'.
   */
  type Domain = '' | 'ai.onnx.ml' | 'com.microsoft';
  /**
   * A resolve rule consists of 4 or 5 items: opType, opSetDomain, versionSelector, operatorImplementation and
   * operatorInitialization (optional).
   *
   * versionSelector is evaluated by `matchSelector`: 'N' (exact version), 'N+' (version N or later) or
   * 'N-M' (inclusive range).
   */
  type ResolveRule =
    | [string, Domain, string, OperatorImplementation<Graph.Node>]
    | [string, Domain, string, OperatorImplementation<unknown>, OperatorInitialization<unknown>];
}
24
/**
 * Finds the operator implementation for `node` among `rules`.
 *
 * Rules are checked in array order and the first one that matches wins. A rule matches when:
 * - its opType equals `node.opType`,
 * - its domain equals the domain of one of `opsets` (a rule domain of '' also matches an opset whose
 *   domain is 'ai.onnx'), and
 * - its version selector accepts that opset's version.
 *
 * @param node - graph node whose operator should be resolved
 * @param opsets - opsets to match rule domains/versions against
 * @param rules - candidate resolve rules; the first matching rule is used
 * @returns the matching rule's implementation (`opImpl`) and its initializer (`opInit`), which is
 *   undefined for 4-item rules
 * @throws TypeError if no rule matches; the message lists the opsets (an empty domain is shown as 'ai.onnx')
 */
export function resolveOperator(node: Graph.Node, opsets: readonly OpSet[], rules: readonly OpSet.ResolveRule[]) {
  for (const rule of rules) {
    // Cheapest check first: ignore rules for other operator types.
    if (rule[0] !== node.opType) {
      continue;
    }

    const domain = rule[1];
    const versionSelector = rule[2];
    for (const opset of opsets) {
      // opset '' and 'ai.onnx' are considered the same.
      const domainMatches = opset.domain === domain || (opset.domain === 'ai.onnx' && domain === '');
      if (domainMatches && matchSelector(opset.version, versionSelector)) {
        return { opImpl: rule[3], opInit: rule[4] };
      }
    }
  }

  const available = opsets.map((set) => `${set.domain || 'ai.onnx'} v${set.version}`).join(', ');
  throw new TypeError(`cannot resolve operator '${node.opType}' with opsets: ${available}`);
}
53
/**
 * Tests whether an opset version satisfies a resolve rule's version selector.
 *
 * Selector forms:
 * - 'N+'  : matches version >= N        (e.g. '7+')
 * - 'N-M' : matches N <= version <= M   (e.g. '6-8')
 * - anything else is treated as an exact version 'N' (e.g. '7')
 *
 * Numbers are read with `Number.parseInt(..., 10)`, so parsing is lenient (e.g. '7x' reads as 7). A selector
 * whose bound(s) do not parse never matches.
 *
 * @param version - opset version to test
 * @param selector - version selector string from a resolve rule
 * @returns true if `version` is accepted by `selector`
 */
function matchSelector(version: number, selector: string): boolean {
  if (selector.endsWith('+')) {
    // minimum version match ('7+' expects version>=7)
    const rangeStart = Number.parseInt(selector.slice(0, -1), 10);
    return !Number.isNaN(rangeStart) && rangeStart <= version;
  }

  const parts = selector.split('-');
  if (parts.length === 2) {
    // range match ('6-8' expects 6<=version<=8)
    const rangeStart = Number.parseInt(parts[0], 10);
    const rangeEnd = Number.parseInt(parts[1], 10);
    return !Number.isNaN(rangeStart) && !Number.isNaN(rangeEnd) && rangeStart <= version && version <= rangeEnd;
  }

  // exact match ('7' expects version===7); an unparsable selector yields NaN, which equals no version
  return Number.parseInt(selector, 10) === version;
}
70