onnxruntime

Форк
0
69 строк · 2.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Graph } from './graph';
5
import { OperatorImplementation, OperatorInitialization } from './operators';
6

7
export interface OpSet {
8
  domain: string;
9
  version: number;
10
}
11
export declare namespace OpSet {
12
  /**
13
   * Domain of an opset, it can be an empty string(default value, represent for ai.onnx), or 'ai.onnx.ml'
14
   */
15
  type Domain = '' | 'ai.onnx.ml' | 'com.microsoft';
16
  /**
17
   * A resolve rule consists of 4 or 5 items: opType, opSetDomain, versionSelector, operatorImplementation and
18
   * operatorInitialization (optional)
19
   */
20
  type ResolveRule =
21
    | [string, Domain, string, OperatorImplementation<Graph.Node>]
22
    | [string, Domain, string, OperatorImplementation<unknown>, OperatorInitialization<unknown>];
23
}
24

25
export function resolveOperator(node: Graph.Node, opsets: readonly OpSet[], rules: readonly OpSet.ResolveRule[]) {
26
  for (const rule of rules) {
27
    const opType = rule[0];
28
    const domain = rule[1];
29
    const versionSelector = rule[2];
30
    const opImpl = rule[3];
31
    const opInit = rule[4];
32

33
    if (node.opType === opType) {
34
      // operator type matches
35
      for (const opset of opsets) {
36
        // opset '' and 'ai.onnx' are considered the same.
37
        if (opset.domain === domain || (opset.domain === 'ai.onnx' && domain === '')) {
38
          // opset domain found
39
          if (matchSelector(opset.version, versionSelector)) {
40
            return { opImpl, opInit };
41
          }
42
        }
43
      }
44
    }
45
  }
46

47
  throw new TypeError(
48
    `cannot resolve operator '${node.opType}' with opsets: ${opsets
49
      .map((set) => `${set.domain || 'ai.onnx'} v${set.version}`)
50
      .join(', ')}`,
51
  );
52
}
53

54
function matchSelector(version: number, selector: string): boolean {
55
  if (selector.endsWith('+')) {
56
    // minimum version match ('7+' expects version>=7)
57
    const rangeStart = Number.parseInt(selector.substring(0, selector.length - 1), 10);
58
    return !isNaN(rangeStart) && rangeStart <= version;
59
  } else if (selector.split('-').length === 2) {
60
    // range match ('6-8' expects 6<=version<=8)
61
    const pair = selector.split('-');
62
    const rangeStart = Number.parseInt(pair[0], 10);
63
    const rangeEnd = Number.parseInt(pair[1], 10);
64
    return !isNaN(rangeStart) && !isNaN(rangeEnd) && rangeStart <= version && version <= rangeEnd;
65
  } else {
66
    // exact match ('7' expects version===7)
67
    return Number.parseInt(selector, 10) === version;
68
  }
69
}
70

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.