onnxruntime

Форк
0
84 строки · 2.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { flatbuffers } from 'flatbuffers';
5

6
import { Graph } from './graph';
7
import { OpSet } from './opset';
8
import { onnxruntime } from './ort-schema/flatbuffers/ort-generated';
9
import { onnx } from './ort-schema/protobuf/onnx';
10
import { LongUtil } from './util';
11

12
import ortFbs = onnxruntime.experimental.fbs;
13

14
export class Model {
15
  // empty model
16
  constructor() {}
17

18
  load(buf: Uint8Array, graphInitializer?: Graph.Initializer, isOrtFormat?: boolean): void {
19
    let onnxError: Error | undefined;
20
    if (!isOrtFormat) {
21
      // isOrtFormat === false || isOrtFormat === undefined
22
      try {
23
        this.loadFromOnnxFormat(buf, graphInitializer);
24
        return;
25
      } catch (e) {
26
        if (isOrtFormat !== undefined) {
27
          throw e;
28
        }
29
        onnxError = e;
30
      }
31
    }
32

33
    try {
34
      this.loadFromOrtFormat(buf, graphInitializer);
35
    } catch (e) {
36
      if (isOrtFormat !== undefined) {
37
        throw e;
38
      }
39
      // Tried both formats and failed (when isOrtFormat === undefined)
40
      throw new Error(`Failed to load model as ONNX format: ${onnxError}\nas ORT format: ${e}`);
41
    }
42
  }
43

44
  private loadFromOnnxFormat(buf: Uint8Array, graphInitializer?: Graph.Initializer): void {
45
    const modelProto = onnx.ModelProto.decode(buf);
46
    const irVersion = LongUtil.longToNumber(modelProto.irVersion);
47
    if (irVersion < 3) {
48
      throw new Error('only support ONNX model with IR_VERSION>=3');
49
    }
50

51
    this._opsets = modelProto.opsetImport.map((i) => ({
52
      domain: i.domain as string,
53
      version: LongUtil.longToNumber(i.version!),
54
    }));
55

56
    this._graph = Graph.from(modelProto.graph!, graphInitializer);
57
  }
58

59
  private loadFromOrtFormat(buf: Uint8Array, graphInitializer?: Graph.Initializer): void {
60
    const fb = new flatbuffers.ByteBuffer(buf);
61
    const ortModel = ortFbs.InferenceSession.getRootAsInferenceSession(fb).model()!;
62
    const irVersion = LongUtil.longToNumber(ortModel.irVersion());
63
    if (irVersion < 3) {
64
      throw new Error('only support ONNX model with IR_VERSION>=3');
65
    }
66
    this._opsets = [];
67
    for (let i = 0; i < ortModel.opsetImportLength(); i++) {
68
      const opsetId = ortModel.opsetImport(i)!;
69
      this._opsets.push({ domain: opsetId?.domain() as string, version: LongUtil.longToNumber(opsetId.version()!) });
70
    }
71

72
    this._graph = Graph.from(ortModel.graph()!, graphInitializer);
73
  }
74

75
  private _graph: Graph;
76
  get graph(): Graph {
77
    return this._graph;
78
  }
79

80
  private _opsets: OpSet[];
81
  get opsets(): readonly OpSet[] {
82
    return this._opsets;
83
  }
84
}
85

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.