// onnxruntime — inference Session implementation (original file: 270 lines, 8.9 KB)
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4import { resolveBackend, SessionHandlerType } from './backend';5import { ExecutionPlan } from './execution-plan';6import { Graph } from './graph';7import { Profiler } from './instrument';8import { Model } from './model';9import { Operator } from './operators';10import { Tensor } from './tensor';11
/**
 * Type-only companion namespace for the `Session` class: construction
 * options and the per-session context shared with the backend handler.
 */
export declare namespace Session {
  /** Options accepted by the `Session` constructor. */
  export interface Config {
    // Preferred backend identifier, forwarded to resolveBackend() in
    // loadModel(); presumably e.g. 'webgl' / 'wasm' — confirm against backend registry.
    backendHint?: string;
    // Profiler configuration forwarded to Profiler.create().
    profiler?: Profiler.Config;
  }

  /** State handed to the backend's session handler at creation time. */
  export interface Context {
    profiler: Readonly<Profiler>;
    // Graph input element types, cached on the first run() so later runs
    // can validate inputs without re-reading the graph.
    graphInputTypes?: Tensor.DataType[];
    // Graph input shapes, cached on the first run() for the same reason.
    graphInputDims?: Array<readonly number[]>;
  }
}
24
25export class Session {26constructor(config: Session.Config = {}) {27this._initialized = false;28this.backendHint = config.backendHint;29this.profiler = Profiler.create(config.profiler);30this.context = { profiler: this.profiler, graphInputTypes: [], graphInputDims: [] };31}32
33get inputNames(): readonly string[] {34return this._model.graph.getInputNames();35}36get outputNames(): readonly string[] {37return this._model.graph.getOutputNames();38}39
40startProfiling() {41this.profiler.start();42}43
44endProfiling() {45this.profiler.stop();46}47
48async loadModel(uri: string): Promise<void>;49async loadModel(buffer: ArrayBuffer, byteOffset?: number, length?: number): Promise<void>;50async loadModel(buffer: Uint8Array): Promise<void>;51async loadModel(arg: string | ArrayBuffer | Uint8Array, byteOffset?: number, length?: number): Promise<void> {52await this.profiler.event('session', 'Session.loadModel', async () => {53// resolve backend and session handler54const backend = await resolveBackend(this.backendHint);55this.sessionHandler = backend.createSessionHandler(this.context);56
57this._model = new Model();58if (typeof arg === 'string') {59const isOrtFormat = arg.endsWith('.ort');60if (typeof process !== 'undefined' && process.versions && process.versions.node) {61// node62const { readFile } = require('node:fs/promises');63const buf = await readFile(arg);64this.initialize(buf, isOrtFormat);65} else {66// browser67const response = await fetch(arg);68const buf = await response.arrayBuffer();69this.initialize(new Uint8Array(buf), isOrtFormat);70}71} else if (!ArrayBuffer.isView(arg)) {72// load model from ArrayBuffer73const arr = new Uint8Array(arg, byteOffset || 0, length || arg.byteLength);74this.initialize(arr);75} else {76// load model from Uint8array77this.initialize(arg);78}79});80}81
82private initialize(modelProtoBlob: Uint8Array, isOrtFormat?: boolean): void {83if (this._initialized) {84throw new Error('already initialized');85}86
87this.profiler.event('session', 'Session.initialize', () => {88// load graph89const graphInitializer = this.sessionHandler.transformGraph90? (this.sessionHandler as Graph.Initializer)91: undefined;92this._model.load(modelProtoBlob, graphInitializer, isOrtFormat);93
94// graph is completely initialzied at this stage , let the interested handlers know95if (this.sessionHandler.onGraphInitialized) {96this.sessionHandler.onGraphInitialized(this._model.graph);97}98// initialize each operator in the graph99this.initializeOps(this._model.graph);100
101// instantiate an ExecutionPlan object to be used by the Session object102this._executionPlan = new ExecutionPlan(this._model.graph, this._ops, this.profiler);103});104
105this._initialized = true;106}107
108async run(inputs: Map<string, Tensor> | Tensor[]): Promise<Map<string, Tensor>> {109if (!this._initialized) {110throw new Error('session not initialized yet');111}112
113return this.profiler.event('session', 'Session.run', async () => {114const inputTensors = this.normalizeAndValidateInputs(inputs);115
116const outputTensors = await this._executionPlan.execute(this.sessionHandler, inputTensors);117
118return this.createOutput(outputTensors);119});120}121
122private normalizeAndValidateInputs(inputs: Map<string, Tensor> | Tensor[]): Tensor[] {123const modelInputNames = this._model.graph.getInputNames();124
125// normalize inputs126// inputs: Tensor[]127if (Array.isArray(inputs)) {128if (inputs.length !== modelInputNames.length) {129throw new Error(`incorrect input array length: expected ${modelInputNames.length} but got ${inputs.length}`);130}131}132// convert map to array133// inputs: Map<string, Tensor>134else {135if (inputs.size !== modelInputNames.length) {136throw new Error(`incorrect input map size: expected ${modelInputNames.length} but got ${inputs.size}`);137}138
139const sortedInputs = new Array<Tensor>(inputs.size);140let sortedInputsIndex = 0;141for (let i = 0; i < modelInputNames.length; ++i) {142const tensor = inputs.get(modelInputNames[i]);143if (!tensor) {144throw new Error(`missing input tensor for: '${name}'`);145}146sortedInputs[sortedInputsIndex++] = tensor;147}148
149inputs = sortedInputs;150}151
152// validate dims requirements153// First session run - graph input data is not cached for the session154if (155!this.context.graphInputTypes ||156this.context.graphInputTypes.length === 0 ||157!this.context.graphInputDims ||158this.context.graphInputDims.length === 0159) {160const modelInputIndices = this._model.graph.getInputIndices();161const modelValues = this._model.graph.getValues();162
163const graphInputDims = new Array<readonly number[]>(modelInputIndices.length);164
165for (let i = 0; i < modelInputIndices.length; ++i) {166const graphInput = modelValues[modelInputIndices[i]];167graphInputDims[i] = graphInput.type!.shape.dims;168
169// cached for second and subsequent runs.170// Some parts of the framework works on the assumption that the graph and types and shapes are static171this.context.graphInputTypes!.push(graphInput.type!.tensorType);172this.context.graphInputDims!.push(inputs[i].dims);173}174
175this.validateInputTensorDims(graphInputDims, inputs, true);176}177
178// Second and subsequent session runs - graph input data is cached for the session179else {180this.validateInputTensorDims(this.context.graphInputDims, inputs, false);181}182
183// validate types requirement184this.validateInputTensorTypes(this.context.graphInputTypes!, inputs);185
186return inputs;187}188
189private validateInputTensorTypes(graphInputTypes: Tensor.DataType[], givenInputs: Tensor[]) {190for (let i = 0; i < givenInputs.length; i++) {191const expectedType = graphInputTypes[i];192const actualType = givenInputs[i].type;193if (expectedType !== actualType) {194throw new Error(`input tensor[${i}] check failed: expected type '${expectedType}' but got ${actualType}`);195}196}197}198
199private validateInputTensorDims(200graphInputDims: Array<readonly number[]>,201givenInputs: Tensor[],202noneDimSupported: boolean,203) {204for (let i = 0; i < givenInputs.length; i++) {205const expectedDims = graphInputDims[i];206const actualDims = givenInputs[i].dims;207if (!this.compareTensorDims(expectedDims, actualDims, noneDimSupported)) {208throw new Error(209`input tensor[${i}] check failed: expected shape '[${expectedDims.join(',')}]' but got [${actualDims.join(210',',211)}]`,212);213}214}215}216
217private compareTensorDims(218expectedDims: readonly number[],219actualDims: readonly number[],220noneDimSupported: boolean,221): boolean {222if (expectedDims.length !== actualDims.length) {223return false;224}225
226for (let i = 0; i < expectedDims.length; ++i) {227if (expectedDims[i] !== actualDims[i] && (!noneDimSupported || expectedDims[i] !== 0)) {228// data shape mis-match AND not a 'None' dimension.229return false;230}231}232
233return true;234}235
236private createOutput(outputTensors: Tensor[]): Map<string, Tensor> {237const modelOutputNames = this._model.graph.getOutputNames();238if (outputTensors.length !== modelOutputNames.length) {239throw new Error('expected number of outputs do not match number of generated outputs');240}241
242const output = new Map<string, Tensor>();243for (let i = 0; i < modelOutputNames.length; ++i) {244output.set(modelOutputNames[i], outputTensors[i]);245}246
247return output;248}249
250private initializeOps(graph: Graph): void {251const nodes = graph.getNodes();252this._ops = new Array(nodes.length);253
254for (let i = 0; i < nodes.length; i++) {255this._ops[i] = this.sessionHandler.resolve(nodes[i], this._model.opsets, graph);256}257}258
259private _model: Model;260private _initialized: boolean;261
262private _ops: Operator[];263private _executionPlan: ExecutionPlan;264
265private backendHint?: string;266
267private sessionHandler: SessionHandlerType;268private context: Session.Context;269private profiler: Readonly<Profiler>;270}
271