onnxruntime
228 строк · 7.2 Кб
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4import { env, InferenceSession } from 'onnxruntime-common';
5
6import {
7OrtWasmMessage,
8SerializableInternalBuffer,
9SerializableSessionMetadata,
10SerializableTensorMetadata,
11TensorMetadata,
12} from './proxy-messages';
13import * as core from './wasm-core-impl';
14import { initializeWebAssembly } from './wasm-factory';
15import { importProxyWorker } from './wasm-utils-import';
16
/**
 * Returns true when calls should be forwarded to the proxy worker: the `env.wasm.proxy` flag must be set
 * and a `document` global must exist.
 */
const isProxy = (): boolean => {
  if (!env.wasm.proxy) {
    return false;
  }
  return typeof document !== 'undefined';
};

/** Proxy worker created by `initializeWebAssemblyAndOrtRuntime()` when running in proxy mode. */
let proxyWorker: Worker | undefined;

// Lifecycle flags of the one-shot WebAssembly / ORT runtime initialization.
let initializing = false;
let initialized = false;
let aborted = false;

/** Object URL of the worker script; revoked once the worker answers the `init-wasm` message. */
let temporaryObjectUrl: string | undefined;

/** A `[resolve, reject]` pair captured from a pending promise. */
type PromiseCallbacks<T = void> = [resolve: (result: T) => void, reject: (reason: unknown) => void];

/** Settles the promise returned by the proxy-mode initialization. */
let initWasmCallbacks: PromiseCallbacks;

/**
 * Pending callbacks keyed by request type. Each reply from the worker settles the oldest entry of its type,
 * so the code relies on the worker answering same-typed requests in the order they were posted.
 */
const queuedCallbacks = new Map<OrtWasmMessage['type'], Array<PromiseCallbacks<unknown>>>();
27
/**
 * Appends `callbacks` to the FIFO of pending callbacks for messages of the given `type`,
 * creating the FIFO on first use.
 */
const enqueueCallbacks = (type: OrtWasmMessage['type'], callbacks: PromiseCallbacks<unknown>): void => {
  let pending = queuedCallbacks.get(type);
  if (pending === undefined) {
    pending = [];
    queuedCallbacks.set(type, pending);
  }
  pending.push(callbacks);
};
36
/**
 * Guards proxied calls: throws `Error('worker not ready')` unless initialization has completed successfully
 * and a proxy worker exists.
 */
const ensureWorker = (): void => {
  const ready = initialized && !initializing && !aborted && proxyWorker !== undefined;
  if (!ready) {
    throw new Error('worker not ready');
  }
};
42
/**
 * `onmessage` handler of the proxy worker.
 *
 * - `init-wasm` settles the pending initialization, updates the lifecycle flags and revokes the worker
 *   script's object URL.
 * - Request replies (`init-ep`, `copy-from`, `create`, `release`, `run`, `end-profiling`) settle the oldest
 *   queued callbacks of the same type: they reject with `err` when present, otherwise resolve with `out`.
 * - Any other message type is ignored.
 */
const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
  const { data } = ev;

  if (data.type === 'init-wasm') {
    initializing = false;
    if (data.err) {
      aborted = true;
      initWasmCallbacks[1](data.err);
    } else {
      initialized = true;
      initWasmCallbacks[0]();
    }
    // The worker has answered, so the object URL of its script is no longer needed.
    if (temporaryObjectUrl) {
      URL.revokeObjectURL(temporaryObjectUrl);
      temporaryObjectUrl = undefined;
    }
    return;
  }

  switch (data.type) {
    case 'init-ep':
    case 'copy-from':
    case 'create':
    case 'release':
    case 'run':
    case 'end-profiling': {
      const [settleOk, settleErr] = queuedCallbacks.get(data.type)!.shift()!;
      if (data.err) {
        settleErr(data.err);
      } else {
        settleOk(data.out!);
      }
      break;
    }
    default:
      break;
  }
};
76
/**
 * Initializes the WebAssembly module and the ORT runtime once.
 *
 * In proxy mode a new proxy worker is created and sent an `init-wasm` message; the returned promise settles when
 * the worker answers, or rejects if the worker cannot be created or reports an error first. Otherwise the module
 * and runtime are initialized on the calling thread.
 *
 * Resolves immediately if a previous call succeeded. Rejects if another call is still in progress, or if a previous
 * call failed (failure is sticky in both modes).
 */
export const initializeWebAssemblyAndOrtRuntime = async (): Promise<void> => {
  if (initialized) {
    return;
  }
  if (initializing) {
    throw new Error("multiple calls to 'initWasm()' detected.");
  }
  if (aborted) {
    throw new Error("previous call to 'initWasm()' failed.");
  }

  initializing = true;

  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    return new Promise<void>((resolve, reject) => {
      // Fails the pending initialization from this side (worker import rejected, setup threw, or the worker raised
      // an error before answering `init-wasm`). The flags must be reset here because `onProxyWorkerMessage` only
      // runs when the worker replies; without this, `initializing` would stay true forever and every later call
      // would report "multiple calls" instead of "previous call failed", unlike the non-proxy branch.
      // Once the worker has replied, `initializing` is already false and this is a no-op (as `reject` was).
      const failInitialization = (reason: unknown): void => {
        if (!initializing) {
          return;
        }
        initializing = false;
        aborted = true;
        if (temporaryObjectUrl) {
          URL.revokeObjectURL(temporaryObjectUrl);
          temporaryObjectUrl = undefined;
        }
        reject(reason);
      };

      proxyWorker?.terminate();

      void importProxyWorker().then(([objectUrl, worker]) => {
        try {
          proxyWorker = worker;
          proxyWorker.onerror = (ev: ErrorEvent) => failInitialization(ev);
          proxyWorker.onmessage = onProxyWorkerMessage;
          initWasmCallbacks = [resolve, reject];
          // Record the URL before posting so it is still revoked if `postMessage` throws. The reply cannot arrive
          // before this callback returns, so the order relative to `postMessage` is otherwise irrelevant.
          temporaryObjectUrl = objectUrl;
          const message: OrtWasmMessage = { type: 'init-wasm', in: env };
          proxyWorker.postMessage(message);
        } catch (e) {
          failInitialization(e);
        }
      }, failInitialization);
    });
  } else {
    try {
      await initializeWebAssembly(env.wasm);
      await core.initRuntime(env);
      initialized = true;
    } catch (e) {
      aborted = true;
      throw e;
    } finally {
      initializing = false;
    }
  }
};
121
/**
 * Initializes the execution provider `epName`: in the proxy worker when proxy mode is active, otherwise on the
 * calling thread via `core.initEp`.
 *
 * @param epName - name of the execution provider to initialize.
 * Rejects in proxy mode if the worker is not ready (see `ensureWorker`) or if posting the request fails.
 */
export const initializeOrtEp = async (epName: string): Promise<void> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    ensureWorker();
    return new Promise<void>((resolve, reject) => {
      const message: OrtWasmMessage = { type: 'init-ep', in: { epName, env } };
      proxyWorker!.postMessage(message);
      // Register only after a successful post. If `postMessage` throws (e.g. DataCloneError), the executor rejects
      // and no orphaned entry is left behind to swallow the reply to a later 'init-ep' request. Registering
      // afterwards is race-free because the worker's reply is dispatched as a separate task.
      enqueueCallbacks('init-ep', [resolve, reject]);
    });
  } else {
    await core.initEp(env, epName);
  }
};
134
/**
 * Hands `buffer` to `core.copyFromExternalBuffer`, in the proxy worker when proxy mode is active or on the
 * calling thread otherwise.
 *
 * In proxy mode `buffer.buffer` is passed in the transfer list, so the caller's underlying ArrayBuffer is detached
 * after a successful post.
 *
 * @param buffer - bytes to copy.
 * @returns the `SerializableInternalBuffer` produced by the core implementation.
 */
export const copyFromExternalBuffer = async (buffer: Uint8Array): Promise<SerializableInternalBuffer> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    ensureWorker();
    return new Promise<SerializableInternalBuffer>((resolve, reject) => {
      const message: OrtWasmMessage = { type: 'copy-from', in: { buffer } };
      proxyWorker!.postMessage(message, [buffer.buffer]);
      // Register only after a successful post so a throwing `postMessage` (e.g. a non-transferable buffer) does not
      // leave a stale entry that would consume the reply meant for the next 'copy-from' request. The reply is
      // dispatched as a separate task, so it cannot arrive before this line runs.
      enqueueCallbacks('copy-from', [resolve, reject]);
    });
  } else {
    return core.copyFromExternalBuffer(buffer);
  }
};
147
/**
 * Creates an inference session from `model`, in the proxy worker when proxy mode is active or on the calling
 * thread otherwise.
 *
 * In proxy mode:
 * - the `preferredOutputLocation` option is rejected;
 * - `options` is shallow-copied into the message;
 * - a `Uint8Array` model has its underlying ArrayBuffer transferred, which detaches it on this thread.
 *
 * @param model - model bytes, or a buffer previously returned by `copyFromExternalBuffer`.
 * @param options - session options.
 * @returns metadata of the created session.
 */
export const createSession = async (
  model: SerializableInternalBuffer | Uint8Array,
  options?: InferenceSession.SessionOptions,
): Promise<SerializableSessionMetadata> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    // check unsupported options
    if (options?.preferredOutputLocation) {
      throw new Error('session option "preferredOutputLocation" is not supported for proxy.');
    }
    ensureWorker();
    return new Promise<SerializableSessionMetadata>((resolve, reject) => {
      const message: OrtWasmMessage = { type: 'create', in: { model, options: { ...options } } };
      const transferable: Transferable[] = [];
      if (model instanceof Uint8Array) {
        transferable.push(model.buffer);
      }
      proxyWorker!.postMessage(message, transferable);
      // Register only after a successful post. If cloning/transferring fails, the executor rejects and no stale
      // entry remains to be settled by the reply to a later 'create' request (which would then hang). The reply is
      // dispatched as a separate task, so registering here is race-free.
      enqueueCallbacks('create', [resolve, reject]);
    });
  } else {
    return core.createSession(model, options);
  }
};
171
/**
 * Releases the session `sessionId`, in the proxy worker when proxy mode is active or on the calling thread
 * otherwise.
 *
 * @param sessionId - id of the session to release.
 */
export const releaseSession = async (sessionId: number): Promise<void> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    ensureWorker();
    return new Promise<void>((resolve, reject) => {
      const message: OrtWasmMessage = { type: 'release', in: sessionId };
      proxyWorker!.postMessage(message);
      // Register only after a successful post so a throwing `postMessage` leaves no orphaned callbacks in the
      // FIFO. The reply is dispatched as a separate task, so it cannot arrive before this line.
      enqueueCallbacks('release', [resolve, reject]);
    });
  } else {
    core.releaseSession(sessionId);
  }
};
184
/**
 * Runs the session `sessionId`, in the proxy worker when proxy mode is active or via `core.run` otherwise.
 *
 * Proxy-mode restrictions:
 * - every input's `TensorMetadata` entry at index 3 must be `'cpu'` (presumably the data location), otherwise
 *   an error is thrown;
 * - every entry of `outputs` must be falsy (no pre-allocated outputs).
 *
 * The buffers returned by `core.extractTransferableBuffers(inputs)` are passed in the transfer list.
 *
 * @param sessionId - id of the session to run.
 * @param inputIndices - input indices forwarded with `inputs`.
 * @param inputs - input tensor metadata.
 * @param outputIndices - output indices to fetch.
 * @param outputs - pre-allocated outputs, or `null` entries (proxy mode requires all `null`).
 * @param options - run options forwarded to the runtime.
 * @returns output tensor metadata.
 */
export const run = async (
  sessionId: number,
  inputIndices: number[],
  inputs: TensorMetadata[],
  outputIndices: number[],
  outputs: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    // check inputs location
    if (inputs.some((t) => t[3] !== 'cpu')) {
      throw new Error('input tensor on GPU is not supported for proxy.');
    }
    // check outputs location
    if (outputs.some((t) => t)) {
      throw new Error('pre-allocated output tensor is not supported for proxy.');
    }
    ensureWorker();
    return new Promise<SerializableTensorMetadata[]>((resolve, reject) => {
      const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU.
      const message: OrtWasmMessage = {
        type: 'run',
        in: { sessionId, inputIndices, inputs: serializableInputs, outputIndices, options },
      };
      proxyWorker!.postMessage(message, core.extractTransferableBuffers(serializableInputs));
      // Register only after a successful post. A throwing `postMessage` (e.g. non-cloneable run options) must not
      // leave a stale entry that would swallow the reply to the next 'run' call and leave that caller hanging.
      // The reply is dispatched as a separate task, so registering here is race-free.
      enqueueCallbacks('run', [resolve, reject]);
    });
  } else {
    return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options);
  }
};
216
/**
 * Ends profiling for the session `sessionId`, in the proxy worker when proxy mode is active or on the calling
 * thread otherwise.
 *
 * @param sessionId - id of the session whose profiling should end.
 */
export const endProfiling = async (sessionId: number): Promise<void> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    ensureWorker();
    return new Promise<void>((resolve, reject) => {
      const message: OrtWasmMessage = { type: 'end-profiling', in: sessionId };
      proxyWorker!.postMessage(message);
      // Register only after a successful post so a throwing `postMessage` leaves no orphaned callbacks in the
      // FIFO. The reply is dispatched as a separate task, so it cannot arrive before this line.
      enqueueCallbacks('end-profiling', [resolve, reject]);
    });
  } else {
    core.endProfiling(sessionId);
  }
};
229