onnxruntime
84 строки · 2.6 Кб
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

4import { flatbuffers } from 'flatbuffers';5
6import { Graph } from './graph';7import { OpSet } from './opset';8import { onnxruntime } from './ort-schema/flatbuffers/ort-generated';9import { onnx } from './ort-schema/protobuf/onnx';10import { LongUtil } from './util';11
12import ortFbs = onnxruntime.experimental.fbs;13
export class Model {
  // empty model
  constructor() {}

  /**
   * Parses a serialized model and populates `graph` and `opsets`.
   *
   * @param buf - Serialized model bytes.
   * @param graphInitializer - Optional initializer forwarded to `Graph.from`.
   * @param isOrtFormat - `false`: parse as ONNX protobuf only; `true`: parse as ORT flatbuffers only;
   *   `undefined`: try ONNX protobuf first and fall back to ORT format.
   * @throws Error when the buffer cannot be loaded in the requested format. When both formats were
   *   attempted (`isOrtFormat === undefined`), the message contains both underlying errors.
   *   On failure, previously loaded `graph`/`opsets` are left unchanged.
   */
  load(buf: Uint8Array, graphInitializer?: Graph.Initializer, isOrtFormat?: boolean): void {
    // `unknown` because anything can be thrown; it is only interpolated into a message below.
    let onnxError: unknown;
    if (!isOrtFormat) {
      // isOrtFormat === false || isOrtFormat === undefined
      try {
        this.loadFromOnnxFormat(buf, graphInitializer);
        return;
      } catch (e) {
        if (isOrtFormat !== undefined) {
          // Caller explicitly asked for ONNX format: no fallback.
          throw e;
        }
        // Format not specified: remember the ONNX failure and try ORT format next.
        onnxError = e;
      }
    }

    try {
      this.loadFromOrtFormat(buf, graphInitializer);
    } catch (e) {
      if (isOrtFormat !== undefined) {
        throw e;
      }
      // Tried both formats and failed (when isOrtFormat === undefined)
      throw new Error(`Failed to load model as ONNX format: ${onnxError}\nas ORT format: ${e}`);
    }
  }

  /**
   * Decodes `buf` as an ONNX `ModelProto` and commits its opsets and graph.
   * State is only assigned after every step succeeds.
   */
  private loadFromOnnxFormat(buf: Uint8Array, graphInitializer?: Graph.Initializer): void {
    const modelProto = onnx.ModelProto.decode(buf);
    const irVersion = LongUtil.longToNumber(modelProto.irVersion);
    if (irVersion < 3) {
      throw new Error('only support ONNX model with IR_VERSION>=3');
    }
    const graphProto = modelProto.graph;
    if (!graphProto) {
      throw new Error('ONNX model does not contain a graph');
    }

    // A missing domain is normalized to '' (the default ONNX domain) so it compares consistently.
    const opsets: OpSet[] = modelProto.opsetImport.map((i) => ({
      domain: i.domain ?? '',
      version: LongUtil.longToNumber(i.version!),
    }));
    const graph = Graph.from(graphProto, graphInitializer);

    // Commit only after parsing fully succeeded, so a failure never leaves a half-loaded model.
    this._opsets = opsets;
    this._graph = graph;
  }

  /**
   * Reads `buf` as an ORT-format flatbuffer (`InferenceSession` root) and commits its opsets and graph.
   * State is only assigned after every step succeeds.
   */
  private loadFromOrtFormat(buf: Uint8Array, graphInitializer?: Graph.Initializer): void {
    const fb = new flatbuffers.ByteBuffer(buf);
    const ortModel = ortFbs.InferenceSession.getRootAsInferenceSession(fb).model();
    if (!ortModel) {
      throw new Error('ORT format buffer does not contain a model');
    }
    const irVersion = LongUtil.longToNumber(ortModel.irVersion());
    if (irVersion < 3) {
      throw new Error('only support ONNX model with IR_VERSION>=3');
    }
    const graphFbs = ortModel.graph();
    if (!graphFbs) {
      throw new Error('ORT format model does not contain a graph');
    }

    const opsets: OpSet[] = [];
    const opsetCount = ortModel.opsetImportLength();
    for (let i = 0; i < opsetCount; i++) {
      const opsetId = ortModel.opsetImport(i)!;
      // flatbuffers returns null for an absent string; treat it as the default ONNX domain ''.
      opsets.push({ domain: opsetId.domain() ?? '', version: LongUtil.longToNumber(opsetId.version()!) });
    }
    const graph = Graph.from(graphFbs, graphInitializer);

    // Commit only after parsing fully succeeded, so a failure never leaves a half-loaded model.
    this._opsets = opsets;
    this._graph = graph;
  }

  // Assigned by load(); undefined until a load succeeds.
  private _graph!: Graph;
  /** The model's graph, available after a successful `load()`. */
  get graph(): Graph {
    return this._graph;
  }

  // Assigned by load(); undefined until a load succeeds.
  private _opsets!: OpSet[];
  /** Operator sets imported by the model, available after a successful `load()`. */
  get opsets(): readonly OpSet[] {
    return this._opsets;
  }
}
85