onnxruntime

Форк
0
/
proxy-wrapper.ts 
228 строк · 7.2 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { env, InferenceSession } from 'onnxruntime-common';
5

6
import {
7
  OrtWasmMessage,
8
  SerializableInternalBuffer,
9
  SerializableSessionMetadata,
10
  SerializableTensorMetadata,
11
  TensorMetadata,
12
} from './proxy-messages';
13
import * as core from './wasm-core-impl';
14
import { initializeWebAssembly } from './wasm-factory';
15
import { importProxyWorker } from './wasm-utils-import';
16

17
/**
 * Tell whether calls should be routed through the proxy worker.
 *
 * Proxying is used only when `env.wasm.proxy` is truthy and a global `document` exists.
 */
const isProxy = (): boolean => {
  const proxyRequested = Boolean(env.wasm.proxy);
  return proxyRequested && typeof document !== 'undefined';
};
18
/** Worker that hosts the proxied runtime; assigned by the proxy branch of the initialization routine. */
let proxyWorker: Worker | undefined;
/** Object URL handed back together with the worker; revoked once the 'init-wasm' response is handled. */
let temporaryObjectUrl: string | undefined;

/** True from the moment an initialization starts until it completes. */
let initializing = false;
/** True once an initialization has succeeded. */
let initialized = false;
/** True once an initialization has failed. */
let aborted = false;
23

24
/** The `[resolve, reject]` pair of a pending promise, kept as a labeled tuple. */
type PromiseCallbacks<T = void> = [resolve: (result: T) => void, reject: (reason: unknown) => void];

/** Callbacks of the promise returned by the proxy-mode initialization; settled by the 'init-wasm' response. */
let initWasmCallbacks: PromiseCallbacks;

/**
 * Pending callbacks per message type. Each list is used as a FIFO queue: requests append to the end and
 * worker responses consume from the front.
 */
const queuedCallbacks = new Map<OrtWasmMessage['type'], Array<PromiseCallbacks<unknown>>>();
27

28
/**
 * Append a `[resolve, reject]` pair to the queue of the given message type, creating the queue on first use.
 *
 * @param type - message type the callbacks are waiting on.
 * @param callbacks - pair to settle when the matching response arrives.
 */
const enqueueCallbacks = (type: OrtWasmMessage['type'], callbacks: PromiseCallbacks<unknown>): void => {
  const pending = queuedCallbacks.get(type);
  if (!pending) {
    // first request of this type: start a new queue
    queuedCallbacks.set(type, [callbacks]);
    return;
  }
  pending.push(callbacks);
};
36

37
/**
 * Assert that the proxy worker can accept requests.
 *
 * @throws Error('worker not ready') when initialization is still running, never succeeded, has failed, or no
 * worker instance exists.
 */
const ensureWorker = (): void => {
  const ready = !initializing && initialized && !aborted && Boolean(proxyWorker);
  if (ready) {
    return;
  }
  throw new Error('worker not ready');
};
42

43
/**
 * `onmessage` handler for the proxy worker: settles the promise that is waiting for the received response.
 *
 * - 'init-wasm' updates the initialization flags, settles the initialization promise and revokes the temporary
 *   object URL if one is recorded.
 * - All other handled types settle the oldest queued callbacks of the same type.
 * - Any other message type is ignored.
 */
const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
  const response = ev.data;
  switch (response.type) {
    case 'init-wasm': {
      initializing = false;
      if (!response.err) {
        initialized = true;
        initWasmCallbacks[0]();
      } else {
        aborted = true;
        initWasmCallbacks[1](response.err);
      }
      // the object URL is not needed any more once the worker has answered
      if (temporaryObjectUrl) {
        URL.revokeObjectURL(temporaryObjectUrl);
        temporaryObjectUrl = undefined;
      }
      break;
    }
    case 'init-ep':
    case 'copy-from':
    case 'create':
    case 'release':
    case 'run':
    case 'end-profiling': {
      // responses are matched to requests in FIFO order per message type
      const pending = queuedCallbacks.get(response.type)!;
      const [resolve, reject] = pending.shift()!;
      if (response.err) {
        reject(response.err);
      } else {
        resolve(response.out!);
      }
      break;
    }
    default:
    // unhandled message types are ignored
  }
};
76

77
/**
 * Initialize WebAssembly and the ORT runtime, either inside a proxy worker or in the current context.
 *
 * Resolves immediately when initialization already succeeded. In proxy mode a worker is imported, wired to
 * `onProxyWorkerMessage` and sent an 'init-wasm' message; the returned promise settles when the worker answers.
 * Otherwise `initializeWebAssembly` and `core.initRuntime` are awaited directly.
 *
 * Any failure is recorded in the module flags (`initializing` cleared, `aborted` set) in both modes, so a later
 * call reports "previous call to 'initWasm()' failed." instead of being stuck in the "initializing" state.
 *
 * @returns a promise that resolves once initialization has completed.
 * @throws (as rejection) Error when another initialization is in flight or a previous one failed; otherwise the
 * underlying initialization error (or the worker's `ErrorEvent`).
 */
export const initializeWebAssemblyAndOrtRuntime = async (): Promise<void> => {
  if (initialized) {
    return;
  }
  if (initializing) {
    throw new Error("multiple calls to 'initWasm()' detected.");
  }
  if (aborted) {
    throw new Error("previous call to 'initWasm()' failed.");
  }

  initializing = true;

  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    return new Promise<void>((resolve, reject) => {
      // Reject the initialization promise and, if the worker has not answered 'init-wasm' yet, record the
      // failure. The guard keeps worker errors raised after a completed initialization from touching the flags.
      const failInit = (reason: unknown): void => {
        if (initializing) {
          initializing = false;
          aborted = true;
        }
        reject(reason);
      };

      proxyWorker?.terminate();

      void importProxyWorker().then(([objectUrl, worker]) => {
        try {
          proxyWorker = worker;
          proxyWorker.onerror = (ev: ErrorEvent) => failInit(ev);
          proxyWorker.onmessage = onProxyWorkerMessage;
          initWasmCallbacks = [resolve, reject];
          const message: OrtWasmMessage = { type: 'init-wasm', in: env };
          proxyWorker.postMessage(message);
          // kept until the 'init-wasm' response arrives, where it gets revoked
          temporaryObjectUrl = objectUrl;
        } catch (e) {
          // the URL was never stored in `temporaryObjectUrl`, so nobody else would release it
          if (objectUrl) {
            URL.revokeObjectURL(objectUrl);
          }
          failInit(e);
        }
      }, failInit);
    });
  } else {
    try {
      await initializeWebAssembly(env.wasm);
      await core.initRuntime(env);
      initialized = true;
    } catch (e) {
      aborted = true;
      throw e;
    } finally {
      initializing = false;
    }
  }
};
121

122
/**
 * Initialize an execution provider, in the proxy worker when proxying is active and via `core.initEp` otherwise.
 *
 * @param epName - name of the execution provider, forwarded unchanged.
 * @returns a promise that settles with the worker's 'init-ep' response (proxy) or when `core.initEp` completes.
 * @throws (as rejection) Error('worker not ready') when proxying is active but the worker is not usable, or
 * whatever `postMessage` / `core.initEp` throws.
 */
export const initializeOrtEp = async (epName: string): Promise<void> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    ensureWorker();
    return new Promise<void>((resolve, reject) => {
      const message: OrtWasmMessage = { type: 'init-ep', in: { epName, env } };
      // Post before enqueueing: if postMessage throws (e.g. DataCloneError) the promise rejects and no stale
      // callbacks stay in the queue to swallow the next 'init-ep' response. The response is delivered in a
      // later task, so the callbacks are always registered in time.
      proxyWorker!.postMessage(message);
      enqueueCallbacks('init-ep', [resolve, reject]);
    });
  } else {
    await core.initEp(env, epName);
  }
};
134

135
/**
 * Copy an external buffer into runtime-owned memory, in the proxy worker when proxying is active and via
 * `core.copyFromExternalBuffer` otherwise.
 *
 * In proxy mode the underlying `buffer.buffer` is passed in the transfer list of `postMessage`, so the caller's
 * view is no longer usable afterwards.
 *
 * @param buffer - bytes to copy.
 * @returns a promise for the internal buffer descriptor produced by the runtime.
 * @throws (as rejection) Error('worker not ready') when proxying is active but the worker is not usable, or
 * whatever `postMessage` / `core.copyFromExternalBuffer` throws.
 */
export const copyFromExternalBuffer = async (buffer: Uint8Array): Promise<SerializableInternalBuffer> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    ensureWorker();
    return new Promise<SerializableInternalBuffer>((resolve, reject) => {
      const message: OrtWasmMessage = { type: 'copy-from', in: { buffer } };
      // Post before enqueueing: a failing transfer (detached or non-transferable buffer) rejects this promise
      // without leaving stale callbacks that would receive the next 'copy-from' response.
      proxyWorker!.postMessage(message, [buffer.buffer]);
      enqueueCallbacks('copy-from', [resolve, reject]);
    });
  } else {
    return core.copyFromExternalBuffer(buffer);
  }
};
147

148
/**
 * Create an inference session, in the proxy worker when proxying is active and via `core.createSession`
 * otherwise.
 *
 * In proxy mode a `Uint8Array` model has its underlying buffer transferred to the worker, and a shallow copy of
 * `options` is sent.
 *
 * @param model - model bytes, or a descriptor of a buffer already copied into the runtime.
 * @param options - optional session options; `preferredOutputLocation` is rejected in proxy mode.
 * @returns a promise for the metadata of the created session.
 * @throws (as rejection) Error when `preferredOutputLocation` is set in proxy mode, Error('worker not ready')
 * when the worker is not usable, or whatever `postMessage` / `core.createSession` throws.
 */
export const createSession = async (
  model: SerializableInternalBuffer | Uint8Array,
  options?: InferenceSession.SessionOptions,
): Promise<SerializableSessionMetadata> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    // check unsupported options
    if (options?.preferredOutputLocation) {
      throw new Error('session option "preferredOutputLocation" is not supported for proxy.');
    }
    ensureWorker();
    return new Promise<SerializableSessionMetadata>((resolve, reject) => {
      const message: OrtWasmMessage = { type: 'create', in: { model, options: { ...options } } };
      const transferable: Transferable[] = [];
      if (model instanceof Uint8Array) {
        transferable.push(model.buffer);
      }
      // Post before enqueueing: options may hold values that cannot be structured-cloned, in which case
      // postMessage throws. Enqueueing afterwards keeps a failed request from leaving stale callbacks that
      // would receive the next 'create' response.
      proxyWorker!.postMessage(message, transferable);
      enqueueCallbacks('create', [resolve, reject]);
    });
  } else {
    return core.createSession(model, options);
  }
};
171

172
/**
 * Release a session, in the proxy worker when proxying is active and via `core.releaseSession` otherwise.
 *
 * @param sessionId - identifier of the session to release.
 * @returns a promise that settles with the worker's 'release' response (proxy) or right after the direct call.
 */
export const releaseSession = async (sessionId: number): Promise<void> => {
  // direct path: proxy support compiled out or not active
  if (BUILD_DEFS.DISABLE_WASM_PROXY || !isProxy()) {
    core.releaseSession(sessionId);
    return;
  }

  ensureWorker();
  return new Promise<void>((onDone, onFail) => {
    enqueueCallbacks('release', [onDone, onFail]);
    const request: OrtWasmMessage = { type: 'release', in: sessionId };
    proxyWorker!.postMessage(request);
  });
};
184

185
/**
 * Run inference on a session, in the proxy worker when proxying is active and via `core.run` otherwise.
 *
 * Proxy mode restrictions: every input's location field (tuple index 3) must be 'cpu', and no pre-allocated
 * output may be supplied (every entry of `outputs` must be falsy). Input buffers reported by
 * `core.extractTransferableBuffers` are passed in the transfer list of `postMessage`.
 *
 * @param sessionId - identifier of the session to run.
 * @param inputIndices - indices of the inputs being fed.
 * @param inputs - metadata of the input tensors.
 * @param outputIndices - indices of the requested outputs.
 * @param outputs - pre-allocated outputs (or null); only used on the direct path.
 * @param options - run options, forwarded unchanged.
 * @returns a promise for the metadata of the output tensors.
 * @throws (as rejection) Error for the proxy restrictions above, Error('worker not ready') when the worker is
 * not usable, or whatever `postMessage` / `core.run` throws.
 */
export const run = async (
  sessionId: number,
  inputIndices: number[],
  inputs: TensorMetadata[],
  outputIndices: number[],
  outputs: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    // check inputs location
    if (inputs.some((t) => t[3] !== 'cpu')) {
      throw new Error('input tensor on GPU is not supported for proxy.');
    }
    // check outputs location
    if (outputs.some((t) => t)) {
      throw new Error('pre-allocated output tensor is not supported for proxy.');
    }
    ensureWorker();
    return new Promise<SerializableTensorMetadata[]>((resolve, reject) => {
      const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU.
      const message: OrtWasmMessage = {
        type: 'run',
        in: { sessionId, inputIndices, inputs: serializableInputs, outputIndices, options },
      };
      // Post before enqueueing: if collecting the transfer list or postMessage throws, the promise rejects
      // and no stale callbacks remain to receive the next 'run' response.
      proxyWorker!.postMessage(message, core.extractTransferableBuffers(serializableInputs));
      enqueueCallbacks('run', [resolve, reject]);
    });
  } else {
    return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options);
  }
};
216

217
/**
 * End profiling for a session, in the proxy worker when proxying is active and via `core.endProfiling`
 * otherwise.
 *
 * @param sessionId - identifier of the session whose profiling should end.
 * @returns a promise that settles with the worker's 'end-profiling' response (proxy) or right after the direct
 * call.
 */
export const endProfiling = async (sessionId: number): Promise<void> => {
  // direct path: proxy support compiled out or not active
  if (BUILD_DEFS.DISABLE_WASM_PROXY || !isProxy()) {
    core.endProfiling(sessionId);
    return;
  }

  ensureWorker();
  return new Promise<void>((onDone, onFail) => {
    enqueueCallbacks('end-profiling', [onDone, onFail]);
    const request: OrtWasmMessage = { type: 'end-profiling', in: sessionId };
    proxyWorker!.postMessage(request);
  });
};
229

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.