onnxruntime

Форк
0
/
session.ts 
270 строк · 8.9 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { resolveBackend, SessionHandlerType } from './backend';
5
import { ExecutionPlan } from './execution-plan';
6
import { Graph } from './graph';
7
import { Profiler } from './instrument';
8
import { Model } from './model';
9
import { Operator } from './operators';
10
import { Tensor } from './tensor';
11

12
export declare namespace Session {
  /** Options accepted by the `Session` constructor. */
  export interface Config {
    /** Preferred backend name; passed to `resolveBackend()` when a model is loaded. */
    backendHint?: string;
    /** Profiler configuration; passed to `Profiler.create()` when the session is constructed. */
    profiler?: Profiler.Config;
  }

  /** State shared between a `Session` and the session handler created for it in `loadModel()`. */
  export interface Context {
    /** The session's profiler. */
    profiler: Readonly<Profiler>;
    /** Per-input tensor types cached by the session on its first `run()`. */
    graphInputTypes?: Tensor.DataType[];
    /** Per-input dims cached by the session on its first `run()`; later runs must match them exactly. */
    graphInputDims?: Array<readonly number[]>;
  }
}
24

25
/**
 * An inference session: loads a model, resolves a backend session handler for it, and executes the model's
 * execution plan on input tensors.
 */
export class Session {
  /**
   * @param config optional backend hint and profiler configuration. The session cannot `run()` until
   *     `loadModel()` has completed.
   */
  constructor(config: Session.Config = {}) {
    this._initialized = false;
    this.backendHint = config.backendHint;
    this.profiler = Profiler.create(config.profiler);
    // This context is handed to the session handler in loadModel(); its graph-input caches start empty and
    // are filled in by the first successful run().
    this.context = { profiler: this.profiler, graphInputTypes: [], graphInputDims: [] };
  }

  /** Names of the model's graph inputs (only available after `loadModel()` has completed). */
  get inputNames(): readonly string[] {
    return this._model.graph.getInputNames();
  }
  /** Names of the model's graph outputs (only available after `loadModel()` has completed). */
  get outputNames(): readonly string[] {
    return this._model.graph.getOutputNames();
  }

  /** Starts the session's profiler. */
  startProfiling() {
    this.profiler.start();
  }

  /** Stops the session's profiler. */
  endProfiling() {
    this.profiler.stop();
  }

  /**
   * Loads a model and initializes the session with it. A session can be initialized only once.
   *
   * - `uri`: a file path (Node.js, read with `fs/promises`) or a URL (browser, fetched with `fetch`).
   *   A `.ort` suffix marks the model as ORT format.
   * - `buffer: ArrayBuffer`: the model bytes, optionally restricted to `byteOffset`/`length`. When `length`
   *   is omitted (or 0), everything after `byteOffset` is used.
   * - `buffer: Uint8Array`: the model bytes.
   *
   * @throws Error('already initialized') if a model has already been loaded into this session.
   * @throws Error if the model cannot be fetched (browser, non-2xx HTTP status).
   */
  async loadModel(uri: string): Promise<void>;
  async loadModel(buffer: ArrayBuffer, byteOffset?: number, length?: number): Promise<void>;
  async loadModel(buffer: Uint8Array): Promise<void>;
  async loadModel(arg: string | ArrayBuffer | Uint8Array, byteOffset?: number, length?: number): Promise<void> {
    // Refuse a second load before touching any state: the code below replaces `sessionHandler` and `_model`,
    // which would corrupt an already-initialized session before initialize() got the chance to reject it.
    if (this._initialized) {
      throw new Error('already initialized');
    }

    await this.profiler.event('session', 'Session.loadModel', async () => {
      // resolve backend and session handler
      const backend = await resolveBackend(this.backendHint);
      this.sessionHandler = backend.createSessionHandler(this.context);

      this._model = new Model();
      if (typeof arg === 'string') {
        const isOrtFormat = arg.endsWith('.ort');
        if (typeof process !== 'undefined' && process.versions && process.versions.node) {
          // node
          const { readFile } = require('node:fs/promises');
          const buf = await readFile(arg);
          this.initialize(buf, isOrtFormat);
        } else {
          // browser
          const response = await fetch(arg);
          // fetch() resolves on HTTP errors too; don't hand an error page to the model parser.
          if (!response.ok) {
            throw new Error(`failed to fetch model from '${arg}': ${response.status} ${response.statusText}`);
          }
          const buf = await response.arrayBuffer();
          this.initialize(new Uint8Array(buf), isOrtFormat);
        }
      } else if (!ArrayBuffer.isView(arg)) {
        // load model from ArrayBuffer. Without an explicit length, view the remainder after the offset;
        // using the full `arg.byteLength` would overrun the buffer for any non-zero offset.
        const offset = byteOffset || 0;
        const arr = new Uint8Array(arg, offset, length || arg.byteLength - offset);
        this.initialize(arr);
      } else {
        // load model from Uint8Array
        this.initialize(arg);
      }
    });
  }

  /**
   * Parses the model bytes, notifies the session handler, resolves operators and builds the execution plan.
   *
   * @throws Error('already initialized') if called more than once.
   */
  private initialize(modelProtoBlob: Uint8Array, isOrtFormat?: boolean): void {
    if (this._initialized) {
      throw new Error('already initialized');
    }

    this.profiler.event('session', 'Session.initialize', () => {
      // load graph; a handler that implements transformGraph also acts as the graph initializer
      const graphInitializer = this.sessionHandler.transformGraph
        ? (this.sessionHandler as Graph.Initializer)
        : undefined;
      this._model.load(modelProtoBlob, graphInitializer, isOrtFormat);

      // graph is completely initialized at this stage, let the interested handlers know
      if (this.sessionHandler.onGraphInitialized) {
        this.sessionHandler.onGraphInitialized(this._model.graph);
      }
      // initialize each operator in the graph
      this.initializeOps(this._model.graph);

      // instantiate an ExecutionPlan object to be used by the Session object
      this._executionPlan = new ExecutionPlan(this._model.graph, this._ops, this.profiler);
    });

    this._initialized = true;
  }

  /**
   * Runs the model.
   *
   * @param inputs either an array ordered like the model's graph inputs, or a map keyed by input name.
   * @returns the output tensors keyed by model output name.
   * @throws Error if the session is not initialized or the inputs fail count/shape/type validation.
   */
  async run(inputs: Map<string, Tensor> | Tensor[]): Promise<Map<string, Tensor>> {
    if (!this._initialized) {
      throw new Error('session not initialized yet');
    }

    return this.profiler.event('session', 'Session.run', async () => {
      const inputTensors = this.normalizeAndValidateInputs(inputs);

      const outputTensors = await this._executionPlan.execute(this.sessionHandler, inputTensors);

      return this.createOutput(outputTensors);
    });
  }

  /**
   * Converts `inputs` into an array ordered like the model's graph inputs and validates counts, dims and types.
   *
   * The first successful call validates against the model's declared shapes (0 dims act as wildcards) and
   * then caches the concrete input dims and the model's input types in the shared context; subsequent calls
   * must match that cache exactly.
   */
  private normalizeAndValidateInputs(inputs: Map<string, Tensor> | Tensor[]): Tensor[] {
    const modelInputNames = this._model.graph.getInputNames();

    // normalize inputs
    if (Array.isArray(inputs)) {
      // inputs: Tensor[]
      if (inputs.length !== modelInputNames.length) {
        throw new Error(`incorrect input array length: expected ${modelInputNames.length} but got ${inputs.length}`);
      }
    } else {
      // inputs: Map<string, Tensor> - convert to an array in model input order
      if (inputs.size !== modelInputNames.length) {
        throw new Error(`incorrect input map size: expected ${modelInputNames.length} but got ${inputs.size}`);
      }

      const sortedInputs = new Array<Tensor>(inputs.size);
      for (let i = 0; i < modelInputNames.length; ++i) {
        const inputName = modelInputNames[i];
        const tensor = inputs.get(inputName);
        if (!tensor) {
          throw new Error(`missing input tensor for: '${inputName}'`);
        }
        sortedInputs[i] = tensor;
      }

      inputs = sortedInputs;
    }

    const cachedTypes = this.context.graphInputTypes;
    const cachedDims = this.context.graphInputDims;

    if (!cachedTypes || cachedTypes.length === 0 || !cachedDims || cachedDims.length === 0) {
      // First session run - graph input data is not cached for the session yet
      const modelInputIndices = this._model.graph.getInputIndices();
      const modelValues = this._model.graph.getValues();

      const modelDims = new Array<readonly number[]>(modelInputIndices.length);
      const modelTypes = new Array<Tensor.DataType>(modelInputIndices.length);
      const observedDims = new Array<readonly number[]>(modelInputIndices.length);
      for (let i = 0; i < modelInputIndices.length; ++i) {
        const graphInput = modelValues[modelInputIndices[i]];
        modelDims[i] = graphInput.type!.shape.dims;
        modelTypes[i] = graphInput.type!.tensorType;
        observedDims[i] = inputs[i].dims;
      }

      // validate dims requirements, then types requirement
      this.validateInputTensorDims(modelDims, inputs, true);
      this.validateInputTensorTypes(modelTypes, inputs);

      // Cache for second and subsequent runs only once validation has passed, so a rejected first run cannot
      // poison the cache. Some parts of the framework work on the assumption that the graph and its input
      // types and shapes are static. Contents are replaced in place because the context is shared with the
      // session handler.
      if (cachedTypes) {
        cachedTypes.splice(0, cachedTypes.length, ...modelTypes);
      } else {
        this.context.graphInputTypes = modelTypes;
      }
      if (cachedDims) {
        cachedDims.splice(0, cachedDims.length, ...observedDims);
      } else {
        this.context.graphInputDims = observedDims;
      }
    } else {
      // Second and subsequent session runs - graph input data is cached for the session
      this.validateInputTensorDims(cachedDims, inputs, false);
      this.validateInputTensorTypes(cachedTypes, inputs);
    }

    return inputs;
  }

  /** Throws if any given input's type differs from the expected type at the same position. */
  private validateInputTensorTypes(graphInputTypes: Tensor.DataType[], givenInputs: Tensor[]) {
    for (let i = 0; i < givenInputs.length; i++) {
      const expectedType = graphInputTypes[i];
      const actualType = givenInputs[i].type;
      if (expectedType !== actualType) {
        throw new Error(`input tensor[${i}] check failed: expected type '${expectedType}' but got ${actualType}`);
      }
    }
  }

  /**
   * Throws if any given input's dims do not match the expected dims at the same position.
   *
   * @param noneDimSupported when true, an expected dimension of 0 matches any actual size.
   */
  private validateInputTensorDims(
    graphInputDims: Array<readonly number[]>,
    givenInputs: Tensor[],
    noneDimSupported: boolean,
  ) {
    for (let i = 0; i < givenInputs.length; i++) {
      const expectedDims = graphInputDims[i];
      const actualDims = givenInputs[i].dims;
      if (!this.compareTensorDims(expectedDims, actualDims, noneDimSupported)) {
        throw new Error(
          `input tensor[${i}] check failed: expected shape '[${expectedDims.join(',')}]' but got [${actualDims.join(
            ',',
          )}]`,
        );
      }
    }
  }

  /**
   * Returns true when both shapes have the same rank and every dimension matches. If `noneDimSupported`
   * is set, an expected dimension of 0 ('None') matches any actual size.
   */
  private compareTensorDims(
    expectedDims: readonly number[],
    actualDims: readonly number[],
    noneDimSupported: boolean,
  ): boolean {
    if (expectedDims.length !== actualDims.length) {
      return false;
    }

    for (let i = 0; i < expectedDims.length; ++i) {
      if (expectedDims[i] !== actualDims[i] && (!noneDimSupported || expectedDims[i] !== 0)) {
        // data shape mis-match AND not a 'None' dimension.
        return false;
      }
    }

    return true;
  }

  /**
   * Pairs output tensors positionally with the model's output names.
   *
   * @throws Error if the number of tensors differs from the number of model outputs.
   */
  private createOutput(outputTensors: Tensor[]): Map<string, Tensor> {
    const modelOutputNames = this._model.graph.getOutputNames();
    if (outputTensors.length !== modelOutputNames.length) {
      throw new Error('expected number of outputs do not match number of generated outputs');
    }

    const output = new Map<string, Tensor>();
    for (let i = 0; i < modelOutputNames.length; ++i) {
      output.set(modelOutputNames[i], outputTensors[i]);
    }

    return output;
  }

  /** Resolves every graph node, in node order, into an operator through the session handler. */
  private initializeOps(graph: Graph): void {
    const nodes = graph.getNodes();
    this._ops = new Array(nodes.length);

    for (let i = 0; i < nodes.length; i++) {
      this._ops[i] = this.sessionHandler.resolve(nodes[i], this._model.opsets, graph);
    }
  }

  private _model: Model;
  private _initialized: boolean;

  private _ops: Operator[];
  private _executionPlan: ExecutionPlan;

  private backendHint?: string;

  private sessionHandler: SessionHandlerType;
  private context: Session.Context;
  private profiler: Readonly<Profiler>;
}
271

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.