onnxruntime

Форк
0
/
wasm-factory.ts 
212 строк · 6.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Env } from 'onnxruntime-common';
5

6
import type { OrtWasmModule } from './wasm-types';
7
import { importWasmModule } from './wasm-utils-import';
8

9
// Module-level lifecycle state of the WebAssembly backend, read and written by initializeWebAssembly(),
// getInstance() and dispose().

/** The instantiated module; assigned after a successful initialization and cleared again by dispose(). */
let wasm: OrtWasmModule | undefined;
/** `true` once the module has been instantiated successfully (reset to `false` by dispose()). */
let initialized = false;
/** Set when an initializeWebAssembly() call starts; cleared once module instantiation settles. */
let initializing = false;
/** Set once an initialization attempt fails or the module is disposed; blocks any further initialization. */
let aborted = false;
13

14
/**
 * Reports whether WebAssembly threads can be used in the current environment (browsers and Node.js).
 *
 * Requires a global `SharedArrayBuffer`, the ability to post one through a `MessageChannel` (when that API
 * exists), and a WebAssembly engine that validates a module using shared memory and an atomic instruction.
 *
 * @returns `true` if every check passes; `false` otherwise, including when any check throws.
 */
const isMultiThreadSupported = (): boolean => {
  // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work.
  if (typeof SharedArrayBuffer === 'undefined') {
    return false;
  }

  try {
    // Test for transferability of SABs (for browsers. needed for Firefox)
    // https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ
    if (typeof MessageChannel !== 'undefined') {
      new MessageChannel().port1.postMessage(new SharedArrayBuffer(1));
    }

    // Test for WebAssembly threads capability (for both browsers and Node.js).
    // The bytes below encode the following module, which declares a shared memory and performs an atomic load:
    //
    // (module
    //   (type $t0 (func))
    //   (func $f0 (type $t0)
    //     (drop
    //       (i32.atomic.load
    //         (i32.const 0))))
    //   (memory $M0 1 1 shared))
    return WebAssembly.validate(
      new Uint8Array([
        // magic "\0asm", version 1
        0, 97, 115, 109, 1, 0, 0, 0,
        // type section: a single `() -> ()` signature
        1, 4, 1, 96, 0, 0,
        // function section: one function using that signature
        3, 2, 1, 0,
        // memory section: one shared memory with min = max = 1 page (limits flag 3 = has-max | shared)
        5, 4, 1, 3, 1, 1,
        // code section header: one 9-byte body with no locals
        10, 11, 1, 9, 0,
        // i32.const 0, i32.atomic.load (alignment exponent 2, offset 0), drop, end
        65, 0, 254, 16, 2, 0, 26, 11,
      ]),
    );
  } catch {
    // Any failure (e.g. postMessage rejecting the SharedArrayBuffer) means threads are unusable.
    return false;
  }
};
39

40
/**
 * Reports whether the WebAssembly engine (browser or Node.js) accepts SIMD instructions.
 *
 * @returns `true` when a minimal SIMD module validates; `false` when it is rejected or validation throws.
 */
const isSimdSupported = (): boolean => {
  // Minimal module exercising SIMD opcodes, generated by wat2wasm from:
  //
  // (module
  //   (type $t0 (func))
  //   (func $f0 (type $t0)
  //     (drop
  //       (i32x4.dot_i16x8_s
  //         (i8x16.splat
  //           (i32.const 0))
  //         (v128.const i32x4 0x00000000 0x00000000 0x00000000 0x00000000)))))
  const simdProbe = new Uint8Array([
    // magic "\0asm", version 1
    0, 97, 115, 109, 1, 0, 0, 0,
    // type section: a single `() -> ()` signature
    1, 4, 1, 96, 0, 0,
    // function section: one function using that signature
    3, 2, 1, 0,
    // code section header: one 28-byte body with no locals
    10, 30, 1, 28, 0,
    // i32.const 0, i8x16.splat
    65, 0, 253, 15,
    // v128.const followed by its 16 zero bytes
    253, 12,
    0, 0, 0, 0,
    0, 0, 0, 0,
    0, 0, 0, 0,
    0, 0, 0, 0,
    // i32x4.dot_i16x8_s, drop, end
    253, 186, 1, 26, 11,
  ]);

  try {
    return WebAssembly.validate(simdProbe);
  } catch {
    return false;
  }
};
66

67
/**
 * Loads and instantiates the ONNX Runtime WebAssembly module and caches it for `getInstance()`.
 *
 * The call is one-shot per module lifetime:
 * - once an attempt has succeeded, later calls resolve immediately;
 * - a call made while another attempt is still in flight throws;
 * - once an attempt has failed, every later call throws.
 *
 * If `flags.numThreads > 1` but WebAssembly threads are unavailable, `flags.numThreads` is rewritten to 1.
 *
 * If `flags.initTimeout` (milliseconds, when > 0) elapses first, this call throws. Module instantiation keeps
 * running in the background and still updates the module-level state when it settles.
 *
 * @param flags - WebAssembly flags; `initTimeout` and `numThreads` must already be populated.
 * @throws Error when SIMD is unsupported, when called concurrently or after a failed attempt, or on timeout.
 *   Errors from loading or instantiating the module are rethrown as-is.
 */
export const initializeWebAssembly = async (flags: Env.WebAssemblyFlags): Promise<void> => {
  if (initialized) {
    return;
  }
  if (initializing) {
    throw new Error("multiple calls to 'initializeWebAssembly()' detected.");
  }
  if (aborted) {
    throw new Error("previous call to 'initializeWebAssembly()' failed.");
  }

  initializing = true;

  // Every failure path must go through here. Throwing while `initializing` is still true would make every later
  // call report "multiple calls ... detected" instead of "previous call ... failed".
  const markAborted = (): void => {
    initializing = false;
    aborted = true;
  };

  // wasm flags are already initialized
  const timeout = flags.initTimeout!;
  let numThreads = flags.numThreads!;

  // ensure SIMD is supported
  if (!isSimdSupported()) {
    markAborted();
    throw new Error('WebAssembly SIMD is not supported in the current environment.');
  }

  // check if multi-threading is supported
  const multiThreadSupported = isMultiThreadSupported();
  if (numThreads > 1 && !multiThreadSupported) {
    if (typeof self !== 'undefined' && !self.crossOriginIsolated) {
      // eslint-disable-next-line no-console
      console.warn(
        'env.wasm.numThreads is set to ' +
          numThreads +
          ', but this will not work unless you enable crossOriginIsolated mode. ' +
          'See https://web.dev/cross-origin-isolation-guide/ for more info.',
      );
    }

    // eslint-disable-next-line no-console
    console.warn(
      'WebAssembly multi-threading is not supported in the current environment. ' + 'Falling back to single-threading.',
    );

    // set flags.numThreads to 1 so that OrtInit() will not create a global thread pool.
    flags.numThreads = numThreads = 1;
  }

  // `wasmPaths` is either a prefix string or an object of per-file overrides (each a string or a URL).
  const wasmPaths = flags.wasmPaths;
  const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined;
  const mjsPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.mjs;
  const mjsPathOverride = (mjsPathOverrideFlag as URL)?.href ?? mjsPathOverrideFlag;
  const wasmPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.wasm;
  const wasmPathOverride = (wasmPathOverrideFlag as URL)?.href ?? wasmPathOverrideFlag;
  const wasmBinaryOverride = flags.wasmBinary;

  let objectUrl: string | undefined;
  // The object URL returned by importWasmModule() (if any) is only needed while the module loads.
  // Release it on success and failure alike.
  const releaseObjectUrl = (): void => {
    if (objectUrl) {
      URL.revokeObjectURL(objectUrl);
      objectUrl = undefined;
    }
  };

  let isTimeout = false;
  let timeoutHandle: ReturnType<typeof setTimeout> | undefined;
  const tasks: Array<Promise<void>> = [];

  try {
    const [moduleObjectUrl, ortWasmFactory] = await importWasmModule(
      mjsPathOverride,
      wasmPrefixOverride,
      numThreads > 1,
    );
    objectUrl = moduleObjectUrl;

    // promise for timeout
    if (timeout > 0) {
      tasks.push(
        new Promise((resolve) => {
          timeoutHandle = setTimeout(() => {
            isTimeout = true;
            resolve();
          }, timeout);
        }),
      );
    }

    const config: Partial<OrtWasmModule> = {
      /**
       * The number of threads. WebAssembly will create (Module.numThreads - 1) workers. If it is 1, no worker will be
       * created.
       */
      numThreads,
    };

    if (wasmBinaryOverride) {
      /**
       * Set a custom buffer which contains the WebAssembly binary. This will skip the wasm file fetching.
       */
      config.wasmBinary = wasmBinaryOverride;
    } else if (wasmPathOverride || wasmPrefixOverride) {
      /**
       * A callback function to locate the WebAssembly file. The function should return the full path of the file.
       *
       * Since Emscripten 3.1.58, this function is only called for the .wasm file.
       */
      config.locateFile = (fileName, scriptDirectory) =>
        wasmPathOverride ?? (wasmPrefixOverride ?? scriptDirectory) + fileName;
    }

    // promise for module initialization. It updates the module-level state itself, so the state is correct even
    // if the timeout wins the race below and this function has already thrown.
    tasks.push(
      ortWasmFactory(config).then(
        // wasm module initialized successfully
        (module) => {
          initializing = false;
          initialized = true;
          wasm = module;
          releaseObjectUrl();
        },
        // wasm module failed to initialize
        (what) => {
          markAborted();
          releaseObjectUrl();
          throw what;
        },
      ),
    );
  } catch (e) {
    // importWasmModule() rejected, or ortWasmFactory() threw synchronously.
    if (timeoutHandle !== undefined) {
      clearTimeout(timeoutHandle);
    }
    markAborted();
    releaseObjectUrl();
    throw e;
  }

  try {
    await Promise.race(tasks);
  } finally {
    // A timer still pending after the race would otherwise keep Node.js alive until it fires.
    if (timeoutHandle !== undefined) {
      clearTimeout(timeoutHandle);
    }
  }

  if (isTimeout) {
    throw new Error(`WebAssembly backend initializing failed due to timeout: ${timeout}ms`);
  }
};
190

191
/**
 * Returns the WebAssembly module cached by `initializeWebAssembly()`.
 *
 * @throws Error if initialization has not completed successfully, or the module was disposed.
 */
export const getInstance = (): OrtWasmModule => {
  if (!initialized || !wasm) {
    throw new Error('WebAssembly is not initialized yet.');
  }
  return wasm;
};
198

199
/**
 * Drops the cached WebAssembly module. Does nothing unless initialization completed successfully.
 *
 * Afterwards the backend is marked as aborted. `getInstance()` then throws, and any further
 * `initializeWebAssembly()` call is rejected.
 */
export const dispose = (): void => {
  if (!initialized || initializing || aborted) {
    return;
  }

  // TODO: currently "PThread.terminateAllThreads()" is not exposed in the wasm module.
  //       And this function is not yet called by any code.
  //       If it is needed in the future, we should expose it in the wasm module and uncomment the following line.
  // wasm?.PThread?.terminateAllThreads();

  wasm = undefined;
  aborted = true;
  initialized = false;
  initializing = false;
};
213

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.