// onnxruntime
// 212 lines · 6.6 KB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4import { Env } from 'onnxruntime-common';
5
6import type { OrtWasmModule } from './wasm-types';
7import { importWasmModule } from './wasm-utils-import';
8
// Lifecycle flags for the singleton WebAssembly module:
//   initialized  - `initializeWebAssembly()` succeeded and `wasm` holds the module (cleared again by `dispose()`).
//   initializing - an `initializeWebAssembly()` call is currently in progress.
//   aborted      - a previous initialization failed or the module was disposed; further initialization is refused.
let initialized = false;
let initializing = false;
let aborted = false;

// The instantiated module; assigned when initialization succeeds and reset to `undefined` by `dispose()`.
let wasm: OrtWasmModule | undefined;
13
/**
 * Detects whether WebAssembly threads can be used in the current environment.
 *
 * Three conditions must hold:
 * - `SharedArrayBuffer` exists.
 * - A `SharedArrayBuffer` can be posted through a `MessageChannel`. This is only checked when `MessageChannel`
 *   exists; some browsers expose the type but refuse to transfer it.
 * - The engine validates a tiny module that declares shared memory and uses an atomic instruction.
 *
 * @returns `true` when every check passes; `false` otherwise, including when any check throws.
 */
const isMultiThreadSupported = (): boolean => {
  // WebAssembly threads are built on SharedArrayBuffer; without it there is nothing else to probe.
  if (typeof SharedArrayBuffer === 'undefined') {
    return false;
  }

  try {
    // Test for transferability of SABs (for browsers; needed for Firefox). Posting throws when it is refused.
    // https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ
    const hasMessageChannel = typeof MessageChannel !== 'undefined';
    if (hasMessageChannel) {
      const probeChannel = new MessageChannel();
      probeChannel.port1.postMessage(new SharedArrayBuffer(1));
    }

    // Test for WebAssembly threads capability (for both browsers and Node.js).
    const threadProbeModule = new Uint8Array([
      // magic "\0asm" + version 1
      0, 97, 115, 109, 1, 0, 0, 0,
      // type section: one function type () -> ()
      1, 4, 1, 96, 0, 0,
      // function section: one function of type 0
      3, 2, 1, 0,
      // memory section: one shared memory (flags 3), min 1 page, max 1 page
      5, 4, 1, 3, 1, 1,
      // code section (11 bytes): one body of 9 bytes, no locals
      10, 11, 1, 9, 0,
      // i32.const 0
      65, 0,
      // i32.atomic.load (align 2, offset 0)
      254, 16, 2, 0,
      // drop ; end
      26, 11,
    ]);
    return WebAssembly.validate(threadProbeModule);
  } catch {
    return false;
  }
};
39
/**
 * Detects whether the engine supports WebAssembly SIMD, including the `i32x4.dot_i16x8_s` instruction.
 *
 * @returns `true` when a tiny SIMD-using module validates; `false` otherwise, including when validation throws.
 */
const isSimdSupported = (): boolean => {
  try {
    // Hand-assembled module generated by wat2wasm from (works for both browsers and Node.js):
    //
    //   (module
    //     (type $t0 (func))
    //     (func $f0 (type $t0)
    //       (drop
    //         (i32x4.dot_i16x8_s
    //           (i8x16.splat
    //             (i32.const 0))
    //           (v128.const i32x4 0x00000000 0x00000000 0x00000000 0x00000000)))))
    const simdProbeModule = new Uint8Array([
      // magic "\0asm" + version 1
      0, 97, 115, 109, 1, 0, 0, 0,
      // type section: one function type () -> ()
      1, 4, 1, 96, 0, 0,
      // function section: one function of type 0
      3, 2, 1, 0,
      // code section (30 bytes): one body of 28 bytes, no locals
      10, 30, 1, 28, 0,
      // i32.const 0 ; i8x16.splat
      65, 0, 253, 15,
      // v128.const followed by 16 zero bytes
      253, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      // i32x4.dot_i16x8_s ; drop ; end
      253, 186, 1, 26, 11,
    ]);
    return WebAssembly.validate(simdProbeModule);
  } catch {
    return false;
  }
};
66
/**
 * Loads and instantiates the ONNX Runtime WebAssembly module.
 *
 * Once a call has succeeded, later calls resolve immediately. The steps are:
 * 1. Verify WebAssembly SIMD support.
 * 2. Fall back to a single thread when WebAssembly threads are unavailable. This overwrites `flags.numThreads`
 *    with 1.
 * 3. Import the module factory through `importWasmModule()`.
 * 4. Run the factory, optionally racing it against `flags.initTimeout` milliseconds.
 *
 * Any failure before or during instantiation marks the attempt as aborted, so later calls report
 * "previous call ... failed". On timeout the state is left as "initializing", because the module may still finish
 * loading in the background.
 *
 * @param flags - WebAssembly flags. `initTimeout` and `numThreads` are read with non-null assertions, so the
 *   caller is expected to have populated them. `wasmPaths` and `wasmBinary` optionally override where the
 *   `.mjs` / `.wasm` files come from.
 * @throws Error if SIMD is unsupported, if initialization times out, if another call is still in progress, or if a
 *   previous call failed. Errors from importing or instantiating the module are re-thrown unchanged.
 */
export const initializeWebAssembly = async (flags: Env.WebAssemblyFlags): Promise<void> => {
  if (initialized) {
    return;
  }
  if (initializing) {
    throw new Error("multiple calls to 'initializeWebAssembly()' detected.");
  }
  if (aborted) {
    throw new Error("previous call to 'initializeWebAssembly()' failed.");
  }

  initializing = true;

  // Records that this attempt failed. Every failure path must go through here. Otherwise `initializing` stays
  // `true` forever and all later calls are rejected with the misleading "multiple calls" error.
  const markFailed = (): void => {
    initializing = false;
    aborted = true;
  };

  // wasm flags are already initialized
  const timeout = flags.initTimeout!;
  let numThreads = flags.numThreads!;

  // Optional file-location overrides. `wasmPaths` is either a prefix string or an object whose `mjs` / `wasm`
  // entries may be strings or URL objects (URLs are converted through `.href`).
  const wasmPaths = flags.wasmPaths;
  const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined;
  const mjsPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.mjs;
  const mjsPathOverride = (mjsPathOverrideFlag as URL)?.href ?? mjsPathOverrideFlag;
  const wasmPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.wasm;
  const wasmPathOverride = (wasmPathOverrideFlag as URL)?.href ?? wasmPathOverrideFlag;
  const wasmBinaryOverride = flags.wasmBinary;

  let imported: Awaited<ReturnType<typeof importWasmModule>>;
  try {
    // ensure SIMD is supported
    if (!isSimdSupported()) {
      throw new Error('WebAssembly SIMD is not supported in the current environment.');
    }

    // check if multi-threading is supported
    const multiThreadSupported = isMultiThreadSupported();
    if (numThreads > 1 && !multiThreadSupported) {
      if (typeof self !== 'undefined' && !self.crossOriginIsolated) {
        // eslint-disable-next-line no-console
        console.warn(
          'env.wasm.numThreads is set to ' +
            numThreads +
            ', but this will not work unless you enable crossOriginIsolated mode. ' +
            'See https://web.dev/cross-origin-isolation-guide/ for more info.',
        );
      }

      // eslint-disable-next-line no-console
      console.warn(
        'WebAssembly multi-threading is not supported in the current environment. ' + 'Falling back to single-threading.',
      );

      // set flags.numThreads to 1 so that OrtInit() will not create a global thread pool.
      flags.numThreads = numThreads = 1;
    }

    imported = await importWasmModule(mjsPathOverride, wasmPrefixOverride, numThreads > 1);
  } catch (e) {
    markFailed();
    throw e;
  }
  const [objectUrl, ortWasmFactory] = imported;

  // The object URL (if any) is no longer needed once instantiation has settled either way.
  const releaseObjectUrl = (): void => {
    if (objectUrl) {
      URL.revokeObjectURL(objectUrl);
    }
  };

  let isTimeout = false;
  let timeoutHandle: ReturnType<typeof setTimeout> | undefined;

  const tasks: Array<Promise<void>> = [];

  // promise for timeout
  if (timeout > 0) {
    tasks.push(
      new Promise((resolve) => {
        timeoutHandle = setTimeout(() => {
          isTimeout = true;
          resolve();
        }, timeout);
      }),
    );
  }

  // promise for module initialization
  tasks.push(
    new Promise((resolve, reject) => {
      const config: Partial<OrtWasmModule> = {
        /**
         * The number of threads. WebAssembly will create (Module.numThreads - 1) workers. If it is 1, no worker will be
         * created.
         */
        numThreads,
      };

      if (wasmBinaryOverride) {
        /**
         * Set a custom buffer which contains the WebAssembly binary. This will skip the wasm file fetching.
         */
        config.wasmBinary = wasmBinaryOverride;
      } else if (wasmPathOverride || wasmPrefixOverride) {
        /**
         * A callback function to locate the WebAssembly file. The function should return the full path of the file.
         *
         * Since Emscripten 3.1.58, this function is only called for the .wasm file.
         */
        config.locateFile = (fileName, scriptDirectory) =>
          wasmPathOverride ?? (wasmPrefixOverride ?? scriptDirectory) + fileName;
      }

      ortWasmFactory(config).then(
        // wasm module initialized successfully
        (module) => {
          initializing = false;
          initialized = true;
          wasm = module;
          resolve();
          releaseObjectUrl();
        },
        // wasm module failed to initialize
        (what) => {
          markFailed();
          reject(what);
          releaseObjectUrl();
        },
      );
    }),
  );

  try {
    await Promise.race(tasks);
  } finally {
    // Cancel the pending timeout (a no-op if it already fired or was never set). A live timer would otherwise
    // keep a Node.js event loop alive for up to `timeout` ms after initialization has already settled.
    clearTimeout(timeoutHandle);
  }

  if (isTimeout) {
    throw new Error(`WebAssembly backend initializing failed due to timeout: ${timeout}ms`);
  }
};
190
/**
 * Returns the WebAssembly module created by a successful `initializeWebAssembly()` call.
 *
 * @throws Error when initialization has not completed successfully, or after `dispose()` has run.
 */
export const getInstance = (): OrtWasmModule => {
  const ready = initialized && wasm;
  if (!initialized || !wasm) {
    throw new Error('WebAssembly is not initialized yet.');
  }
  void ready;
  return wasm;
};
198
/**
 * Drops the reference to the WebAssembly module after a successful initialization.
 *
 * Does nothing unless initialization completed successfully and no other initialization is in progress.
 * Afterwards the state is marked as aborted, so a later `initializeWebAssembly()` call throws instead of
 * re-initializing.
 */
export const dispose = (): void => {
  const canDispose = initialized && !initializing && !aborted;
  if (!canDispose) {
    return;
  }

  // TODO: currently "PThread.terminateAllThreads()" is not exposed in the wasm module.
  // And this function is not yet called by any code.
  // If it is needed in the future, we should expose it in the wasm module and uncomment the following line.

  // wasm?.PThread?.terminateAllThreads();
  wasm = undefined;

  aborted = true;
  initialized = false;
  initializing = false;
};
213