onnxruntime
52 строки · 1.6 Кб
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4import { InferenceSession, InferenceSessionHandler, SessionHandler, Tensor } from 'onnxruntime-common';
5
6import { Session } from './session';
7import { Tensor as OnnxjsTensor } from './tensor';
8
/**
 * Wraps a {@link Session} so that it can be driven through the
 * `InferenceSessionHandler` interface from `onnxruntime-common`.
 *
 * Feeds are converted into this backend's own tensor class before the call,
 * and the session's results are converted back into `onnxruntime-common`
 * tensors afterwards.
 */
export class OnnxjsSessionHandler implements InferenceSessionHandler {
  /** Input names, taken from the wrapped session when the handler is built. */
  inputNames: readonly string[];

  /** Output names, taken from the wrapped session when the handler is built. */
  outputNames: readonly string[];

  /**
   * @param session The session that every call on this handler is forwarded to.
   */
  constructor(private session: Session) {
    this.inputNames = this.session.inputNames;
    this.outputNames = this.session.outputNames;
  }

  /** Intentionally empty: this handler does nothing on dispose. */
  async dispose(): Promise<void> {}

  /**
   * Runs the wrapped session on the given feeds.
   *
   * @param feeds Input values keyed by name; only own enumerable keys are used.
   * @param _fetches Not read by this implementation.
   * @param _options Not read by this implementation.
   * @returns An object mapping each name produced by the session to a new `Tensor`.
   */
  async run(
    feeds: SessionHandler.FeedsType,
    _fetches: SessionHandler.FetchesType,
    _options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType> {
    const sessionInputs = new Map<string, OnnxjsTensor>();
    for (const feedName of Object.keys(feeds)) {
      const source = feeds[feedName];
      // The third and fourth constructor arguments are deliberately left undefined;
      // the feed's data is handed over as the fifth argument.
      const converted = new OnnxjsTensor(
        source.dims,
        source.type as OnnxjsTensor.DataType,
        undefined,
        undefined,
        source.data as OnnxjsTensor.NumberType,
      );
      sessionInputs.set(feedName, converted);
    }

    const sessionOutputs = await this.session.run(sessionInputs);

    const results: SessionHandler.ReturnType = {};
    sessionOutputs.forEach((produced, outputName) => {
      results[outputName] = new Tensor(produced.type, produced.data, produced.dims);
    });
    return results;
  }

  /** Forwards to the wrapped session's `startProfiling`. */
  startProfiling(): void {
    this.session.startProfiling();
  }

  /** Forwards to the wrapped session's `endProfiling`. */
  endProfiling(): void {
    this.session.endProfiling();
  }
}
53