onnxruntime

Форк
0
/
wasm-core-impl.ts 
844 строки · 29.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
5
// WebNN API specification.
6
// https://github.com/webmachinelearning/webnn/issues/677
7
/// <reference path="jsep/webnn/webnn.d.ts" />
8

9
import { Env, InferenceSession, Tensor } from 'onnxruntime-common';
10

11
import {
12
  SerializableInternalBuffer,
13
  SerializableSessionMetadata,
14
  SerializableTensorMetadata,
15
  TensorMetadata,
16
} from './proxy-messages';
17
import { setRunOptions } from './run-options';
18
import { setSessionOptions } from './session-options';
19
import {
20
  calculateTensorSizeInBytes,
21
  dataLocationStringToEnum,
22
  isGpuBufferSupportedType,
23
  logLevelStringToEnum,
24
  tensorDataTypeEnumToString,
25
  tensorDataTypeStringToEnum,
26
  tensorTypeToTypedArrayConstructor,
27
} from './wasm-common';
28
import { getInstance } from './wasm-factory';
29
import { allocWasmString, checkLastError } from './wasm-utils';
30
import { loadFile } from './wasm-utils-load-file';
31

32
// #region Initializations
33

34
/**
35
 * There are 4 different "initialization" steps for ORT. They happen in different places and different time.
36
 *
37
 * 1. JavaScript initialization for onnxruntime-common and onnxruntime-web.
38
 *    This is the first initialization step. In this step, onnxruntime-web calls onnxruntime-common's registerBackend()
39
 * function multiple times to register all the available backends. The backend registration is very fast. It only
40
 * registers the backend name with the uninitialized backend object. No heavy initialization is done in this step.
41
 *    Refer to web/lib/index.ts for the backend registration.
42
 *
43
 * 2. WebAssembly artifact initialization.
44
 *    This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` or
45
 * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the following:
46
 *     - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled.
47
 *     - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated
48
 * JavaScript code to initialize the WebAssembly runtime.
49
 *         - if proxy is enabled, this step happens in the proxy worker using message 'init-wasm'.
50
 *         - downloading the 'ort-wasm{...}.wasm' file is done in this step.
51
 *         - if multi-thread is enabled, one or more webworker will be created to initialize the PThread threadpool.
52
 *
53
 * 3. ORT environment initialization.
54
 *    This happens after step 2. In this step, onnxruntime-web performs ONNX Runtime environment initialization.
55
 * Function `_OrtInit()` is called in this step.
56
 *     - if proxy is enabled, this step happens in the proxy worker using message 'init-ort'.
57
 *     - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step.
58
 *
59
 * 4. Session initialization.
60
 *    This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3
61
 * steps (they are only called once), this step will be done for each session. In this step, onnxruntime-web does the
62
 * following:
63
 *    If the parameter is a URL:
64
 *    - download the model data from the URL.
65
 *    - copy the model data to the WASM heap. (proxy: 'copy-from')
66
 *    - dereference the model buffer. This step allows the original ArrayBuffer to be garbage collected.
67
 *    - call `_OrtCreateSession()` to create the session. (proxy: 'create')
68
 *
69
 *    If the parameter is a Uint8Array object:
70
 *    - copy the model data to the WASM heap. (proxy: 'copy-from')
71
 *    - call `_OrtCreateSession()` to create the session. (proxy: 'create')
72
 *
73
 *
74
 */
75

76
/**
 * initialize ORT environment.
 *
 * @param numThreads SetGlobalIntraOpNumThreads(numThreads)
 * @param loggingLevel CreateEnv(static_cast<OrtLoggingLevel>(logging_level))
 */
const initOrt = (numThreads: number, loggingLevel: number): void => {
  // a non-zero return value from _OrtInit indicates failure; surface the native error message.
  if (getInstance()._OrtInit(numThreads, loggingLevel) !== 0) {
    checkLastError("Can't initialize onnxruntime.");
  }
};
88

89
/**
 * initialize runtime environment.
 * @param env passed in the environment config object.
 */
export const initRuntime = async (env: Env): Promise<void> => {
  // forward the user-configured thread count and log level to the native ORT environment.
  const numThreads = env.wasm.numThreads!;
  const loggingLevel = logLevelStringToEnum(env.logLevel);
  initOrt(numThreads, loggingLevel);
};
97

98
/**
99
 * perform EP specific initialization.
100
 *
101
 * @param env
102
 * @param epName
103
 */
104
export const initEp = async (env: Env, epName: string): Promise<void> => {
105
  if (!BUILD_DEFS.DISABLE_JSEP) {
106
    // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
107
    const initJsep = require('./jsep/init').init;
108

109
    if (epName === 'webgpu') {
110
      // perform WebGPU availability check
111
      if (typeof navigator === 'undefined' || !navigator.gpu) {
112
        throw new Error('WebGPU is not supported in current environment');
113
      }
114

115
      let adapter = env.webgpu.adapter as GPUAdapter | null;
116
      if (!adapter) {
117
        // if adapter is not set, request a new adapter.
118
        const powerPreference = env.webgpu.powerPreference;
119
        if (
120
          powerPreference !== undefined &&
121
          powerPreference !== 'low-power' &&
122
          powerPreference !== 'high-performance'
123
        ) {
124
          throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
125
        }
126
        const forceFallbackAdapter = env.webgpu.forceFallbackAdapter;
127
        if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
128
          throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
129
        }
130
        adapter = await navigator.gpu.requestAdapter({ powerPreference, forceFallbackAdapter });
131
        if (!adapter) {
132
          throw new Error(
133
            'Failed to get GPU adapter. ' +
134
              'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.',
135
          );
136
        }
137
      } else {
138
        // if adapter is set, validate it.
139
        if (
140
          typeof adapter.limits !== 'object' ||
141
          typeof adapter.features !== 'object' ||
142
          typeof adapter.requestDevice !== 'function'
143
        ) {
144
          throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.');
145
        }
146
      }
147

148
      await initJsep('webgpu', getInstance(), env, adapter);
149
    }
150
    if (epName === 'webnn') {
151
      // perform WebNN availability check
152
      if (typeof navigator === 'undefined' || !(navigator as unknown as { ml: unknown }).ml) {
153
        throw new Error('WebNN is not supported in current environment');
154
      }
155

156
      await initJsep('webnn', getInstance(), env);
157
    }
158
  }
159
};
160

161
// #endregion Initializations
162

163
/**
 * valid data locations for input/output tensors.
 */
type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer';

type IOBindingState = {
  /**
   * the handle of IO binding.
   */
  readonly handle: number;

  /**
   * the preferred location for each output tensor.
   *
   * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'.
   */
  readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[];

  /**
   * enum value of the preferred location for each output tensor.
   */
  readonly outputPreferredLocationsEncoded: readonly number[];
};

/**
 * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState;
 * enableGraphCapture flag; whether inputs/outputs have been bound to the IO binding for the current run cycle.
 */
type SessionMetadata = [
  inferenceSessionId: number,
  inputNamesUTF8Encoded: number[],
  outputNamesUTF8Encoded: number[],
  bindingState: IOBindingState | null,
  enableGraphCapture: boolean,
  inputOutputBound: boolean,
];

// registry of all live sessions, keyed by session handle; entries are added by createSession()
// and removed by releaseSession().
const activeSessions = new Map<number, SessionMetadata>();
200

201
/**
 * get the input/output count of the session.
 * @param sessionHandle the handle representing the session. should be non-zero.
 * @returns a tuple including 2 numbers, representing the input count and output count.
 */
const getSessionInputOutputCount = (sessionHandle: number): [number, number] => {
  const wasm = getInstance();
  const stack = wasm.stackSave();
  try {
    // the native side writes two consecutive 32-bit integers: [inputCount, outputCount].
    const dataOffset = wasm.stackAlloc(8);
    if (wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4) !== 0) {
      checkLastError("Can't get session input/output count.");
    }
    const base = dataOffset / 4;
    return [wasm.HEAP32[base], wasm.HEAP32[base + 1]];
  } finally {
    wasm.stackRestore(stack);
  }
};
220

221
/**
 * allocate the memory and memcpy the external buffer.
 *
 * @param model - the external buffer containing the model data. Must not be the same buffer as the WASM heap.
 * @returns a 2-elements tuple - the pointer and size of the allocated buffer
 */
export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => {
  const wasm = getInstance();
  const size = model.byteLength;
  const ptr = wasm._malloc(size);
  if (ptr === 0) {
    throw new Error(`Can't create a session. failed to allocate a buffer of size ${size}.`);
  }
  // copy the model bytes into the freshly allocated WASM heap region.
  wasm.HEAPU8.set(model, ptr);
  return [ptr, size];
};
236

237
/**
 * create an inference session from a model data buffer.
 *
 * @param modelData - either a Uint8Array object representing the model data, or a 2-elements tuple containing the
 *     pointer and size of the model data buffer.
 * @param options an optional session options object.
 * @returns a 3-elements tuple containing [session handle, input names, output names]
 */
export const createSession = async (
  modelData: Uint8Array | SerializableInternalBuffer,
  options?: InferenceSession.SessionOptions,
): Promise<SerializableSessionMetadata> => {
  let modelDataOffset: number, modelDataLength: number;
  const wasm = getInstance();

  if (Array.isArray(modelData)) {
    // if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
    [modelDataOffset, modelDataLength] = modelData;
  } else if (modelData.buffer === wasm.HEAPU8.buffer) {
    // if model data uses the same buffer as the WASM heap, we don't need to copy it.
    [modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
  } else {
    // otherwise, copy the model data to the WASM heap.
    [modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
  }

  let sessionHandle = 0;
  let sessionOptionsHandle = 0;
  let ioBindingHandle = 0;
  let allocs: number[] = [];
  const inputNamesUTF8Encoded = [];
  const outputNamesUTF8Encoded = [];

  try {
    [sessionOptionsHandle, allocs] = setSessionOptions(options);

    // mount any external data files before session creation so native model loading can read them by path.
    if (options?.externalData && wasm.mountExternalData) {
      const loadingPromises = [];
      for (const file of options.externalData) {
        const path = typeof file === 'string' ? file : file.path;
        loadingPromises.push(
          loadFile(typeof file === 'string' ? file : file.data).then((data) => {
            wasm.mountExternalData!(path, data);
          }),
        );
      }

      // wait for all external data files to be loaded
      await Promise.all(loadingPromises);
    }

    // if the WebNN EP is requested, create (or adopt) an MLContext and stash it on the wasm module so
    // the native side can pick it up during session creation.
    for (const provider of options?.executionProviders ?? []) {
      const providerName = typeof provider === 'string' ? provider : provider.name;
      if (providerName === 'webnn') {
        if (wasm.currentContext) {
          throw new Error('WebNN execution provider is already set.');
        }
        if (typeof provider !== 'string') {
          const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption;
          const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
          const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice;
          const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
          const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
          const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
          if (context) {
            // caller supplied an existing MLContext; use it as-is.
            wasm.currentContext = context as MLContext;
          } else if (gpuDevice) {
            wasm.currentContext = await navigator.ml.createContext(gpuDevice);
          } else {
            wasm.currentContext = await navigator.ml.createContext({ deviceType, numThreads, powerPreference });
          }
        } else {
          // provider given as the plain string 'webnn': create a context with default options.
          wasm.currentContext = await navigator.ml.createContext();
        }
        break;
      }
    }

    sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
    if (sessionHandle === 0) {
      checkLastError("Can't create a session.");
    }

    // clear current MLContext after session creation
    if (wasm.currentContext) {
      wasm.currentContext = undefined;
    }

    const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);

    const enableGraphCapture = !!options?.enableGraphCapture;

    const inputNames = [];
    const outputNames = [];
    const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
    for (let i = 0; i < inputCount; i++) {
      const name = wasm._OrtGetInputName(sessionHandle, i);
      if (name === 0) {
        checkLastError("Can't get an input name.");
      }
      // keep the native UTF-8 pointer (freed at release time) alongside its JS string copy.
      inputNamesUTF8Encoded.push(name);
      inputNames.push(wasm.UTF8ToString(name));
    }
    for (let i = 0; i < outputCount; i++) {
      const name = wasm._OrtGetOutputName(sessionHandle, i);
      if (name === 0) {
        checkLastError("Can't get an output name.");
      }
      outputNamesUTF8Encoded.push(name);
      const nameString = wasm.UTF8ToString(name);
      outputNames.push(nameString);

      if (!BUILD_DEFS.DISABLE_JSEP) {
        // graph capture requires GPU-resident outputs; default to 'gpu-buffer' when the caller
        // expressed no preference at all.
        if (enableGraphCapture && options?.preferredOutputLocation === undefined) {
          outputPreferredLocations.push('gpu-buffer');
          continue;
        }
        const location =
          typeof options?.preferredOutputLocation === 'string'
            ? options.preferredOutputLocation
            : (options?.preferredOutputLocation?.[nameString] ?? 'cpu');
        if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
          throw new Error(`Not supported preferred output location: ${location}.`);
        }
        if (enableGraphCapture && location !== 'gpu-buffer') {
          throw new Error(
            `Not supported preferred output location: ${location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`,
          );
        }
        outputPreferredLocations.push(location);
      }
    }

    // use IO binding only when at least one output is preferred to be on GPU.
    let bindingState: IOBindingState | null = null;
    if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer')) {
      ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
      if (ioBindingHandle === 0) {
        checkLastError("Can't create IO binding.");
      }

      bindingState = {
        handle: ioBindingHandle,
        outputPreferredLocations,
        outputPreferredLocationsEncoded: outputPreferredLocations.map((l) => dataLocationStringToEnum(l)),
      };
    }

    activeSessions.set(sessionHandle, [
      sessionHandle,
      inputNamesUTF8Encoded,
      outputNamesUTF8Encoded,
      bindingState,
      enableGraphCapture,
      false, // inputOutputBound: inputs/outputs are bound lazily on the first run() call.
    ]);
    return [sessionHandle, inputNames, outputNames];
  } catch (e) {
    // creation failed part-way: free everything acquired so far, then rethrow.
    inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
    outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));

    if (ioBindingHandle !== 0) {
      wasm._OrtReleaseBinding(ioBindingHandle);
    }

    if (sessionHandle !== 0) {
      wasm._OrtReleaseSession(sessionHandle);
    }
    throw e;
  } finally {
    // NOTE(review): the model buffer is freed unconditionally here, including when the pointer/size
    // tuple was supplied by the caller or points into the WASM heap — confirm that ownership of the
    // pointer is intended to transfer to this function in all three branches above.
    wasm._free(modelDataOffset);
    if (sessionOptionsHandle !== 0) {
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
    }
    allocs.forEach((alloc) => wasm._free(alloc));

    // unmount external data if necessary
    wasm.unmountExternalData?.();
  }
};
417

418
/**
 * release a session previously created by createSession(): frees the IO binding (if any), the
 * native name strings, the native session handle, and removes the registry entry.
 *
 * @param sessionId the session id returned by createSession().
 */
export const releaseSession = (sessionId: number): void => {
  const wasm = getInstance();
  const metadata = activeSessions.get(sessionId);
  if (metadata === undefined) {
    throw new Error(`cannot release session. invalid session id: ${sessionId}`);
  }
  const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = metadata;

  if (ioBindingState) {
    // when graph capture was on, bound outputs are still attached and must be cleared first.
    if (enableGraphCapture) {
      wasm._OrtClearBoundOutputs(ioBindingState.handle);
    }
    wasm._OrtReleaseBinding(ioBindingState.handle);
  }

  // let the JSEP backend clean up any per-session state.
  wasm.jsepOnReleaseSession?.(sessionId);

  // free the UTF-8 name buffers allocated during session creation, then the session itself.
  for (const buf of inputNamesUTF8Encoded) {
    wasm._OrtFree(buf);
  }
  for (const buf of outputNamesUTF8Encoded) {
    wasm._OrtFree(buf);
  }
  wasm._OrtReleaseSession(sessionHandle);
  activeSessions.delete(sessionId);
};
440

441
export const prepareInputOutputTensor = (
442
  tensor: TensorMetadata | null,
443
  tensorHandles: number[],
444
  allocs: number[],
445
  sessionId: number,
446
  index: number,
447
  enableGraphCapture = false,
448
): void => {
449
  if (!tensor) {
450
    tensorHandles.push(0);
451
    return;
452
  }
453

454
  const wasm = getInstance();
455

456
  const dataType = tensor[0];
457
  const dims = tensor[1];
458
  const location = tensor[3];
459

460
  let rawData: number;
461
  let dataByteLength: number;
462

463
  if (dataType === 'string' && location === 'gpu-buffer') {
464
    throw new Error('String tensor is not supported on GPU.');
465
  }
466

467
  if (enableGraphCapture && location !== 'gpu-buffer') {
468
    throw new Error(
469
      `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`,
470
    );
471
  }
472

473
  if (location === 'gpu-buffer') {
474
    const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
475
    dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
476

477
    const registerBuffer = wasm.jsepRegisterBuffer;
478
    if (!registerBuffer) {
479
      throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');
480
    }
481
    rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength);
482
  } else {
483
    const data = tensor[2];
484

485
    if (Array.isArray(data)) {
486
      // string tensor
487
      dataByteLength = 4 * data.length;
488
      rawData = wasm._malloc(dataByteLength);
489
      allocs.push(rawData);
490
      let dataIndex = rawData / 4;
491
      for (let i = 0; i < data.length; i++) {
492
        if (typeof data[i] !== 'string') {
493
          throw new TypeError(`tensor data at index ${i} is not a string`);
494
        }
495
        wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
496
      }
497
    } else {
498
      dataByteLength = data.byteLength;
499
      rawData = wasm._malloc(dataByteLength);
500
      allocs.push(rawData);
501
      wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
502
    }
503
  }
504

505
  const stack = wasm.stackSave();
506
  const dimsOffset = wasm.stackAlloc(4 * dims.length);
507
  try {
508
    let dimIndex = dimsOffset / 4;
509
    dims.forEach((d) => (wasm.HEAP32[dimIndex++] = d));
510
    const tensor = wasm._OrtCreateTensor(
511
      tensorDataTypeStringToEnum(dataType),
512
      rawData,
513
      dataByteLength,
514
      dimsOffset,
515
      dims.length,
516
      dataLocationStringToEnum(location),
517
    );
518
    if (tensor === 0) {
519
      checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
520
    }
521
    tensorHandles.push(tensor);
522
  } finally {
523
    wasm.stackRestore(stack);
524
  }
525
};
526

527
/**
 * perform inference run.
 *
 * @param sessionId the session id returned by createSession().
 * @param inputIndices model input indices corresponding to each entry of `inputTensors`.
 * @param inputTensors the input tensors.
 * @param outputIndices model output indices corresponding to each entry of `outputTensors`.
 * @param outputTensors pre-allocated output tensors; a null entry means "let ORT allocate it".
 * @param options run options.
 * @returns the output tensors, one per entry of `outputIndices`.
 */
export const run = async (
  sessionId: number,
  inputIndices: number[],
  inputTensors: TensorMetadata[],
  outputIndices: number[],
  outputTensors: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  const wasm = getInstance();
  const session = activeSessions.get(sessionId);
  if (!session) {
    throw new Error(`cannot run inference. invalid session id: ${sessionId}`);
  }
  const sessionHandle = session[0];
  const inputNamesUTF8Encoded = session[1];
  const outputNamesUTF8Encoded = session[2];
  const ioBindingState = session[3];
  const enableGraphCapture = session[4];
  const inputOutputBound = session[5];

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  // stack-allocate four parallel pointer arrays: input values/names and output values/names.
  const beforeRunStack = wasm.stackSave();
  const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
  const inputNamesOffset = wasm.stackAlloc(inputCount * 4);
  const outputValuesOffset = wasm.stackAlloc(outputCount * 4);
  const outputNamesOffset = wasm.stackAlloc(outputCount * 4);

  try {
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // create input tensors
    for (let i = 0; i < inputCount; i++) {
      prepareInputOutputTensor(
        inputTensors[i],
        inputTensorHandles,
        inputOutputAllocs,
        sessionId,
        inputIndices[i],
        enableGraphCapture,
      );
    }

    // create output tensors
    for (let i = 0; i < outputCount; i++) {
      prepareInputOutputTensor(
        outputTensors[i],
        outputTensorHandles,
        inputOutputAllocs,
        sessionId,
        // offset output indices past the inputs so input/output GPU buffer registrations don't collide.
        inputCount + outputIndices[i],
        enableGraphCapture,
      );
    }

    // fill the four pointer arrays with tensor handles and the native name-string pointers.
    let inputValuesIndex = inputValuesOffset / 4;
    let inputNamesIndex = inputNamesOffset / 4;
    let outputValuesIndex = outputValuesOffset / 4;
    let outputNamesIndex = outputNamesOffset / 4;
    for (let i = 0; i < inputCount; i++) {
      wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i];
      wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]];
    }
    for (let i = 0; i < outputCount; i++) {
      wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i];
      wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
    }

    // bind inputs/outputs to the IO binding if one exists and binding wasn't already done
    // (inputOutputBound stays true between runs when graph capture is enabled).
    if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) {
      const { handle, outputPreferredLocations, outputPreferredLocationsEncoded } = ioBindingState;

      if (inputNamesUTF8Encoded.length !== inputCount) {
        throw new Error(
          `input count from feeds (${inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`,
        );
      }

      // process inputs
      for (let i = 0; i < inputCount; i++) {
        const index = inputIndices[i];
        const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]);
        if (errorCode !== 0) {
          checkLastError(`Can't bind input[${i}] for session=${sessionId}.`);
        }
      }

      // process pre-allocated outputs
      for (let i = 0; i < outputCount; i++) {
        const index = outputIndices[i];
        const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated.

        if (location) {
          // output is pre-allocated. bind the tensor.
          const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0);
          if (errorCode !== 0) {
            checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`);
          }
        } else {
          // output is not pre-allocated. reset preferred location.
          const errorCode = wasm._OrtBindOutput(
            handle,
            outputNamesUTF8Encoded[index],
            0,
            outputPreferredLocationsEncoded[index],
          );
          if (errorCode !== 0) {
            // NOTE(review): the message indexes outputPreferredLocations with `i` while the bind call
            // uses `index` — harmless when outputIndices is the identity mapping, but verify.
            checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`);
          }
        }
      }
      // record that binding is done so graph-capture runs can skip re-binding.
      activeSessions.set(sessionId, [
        sessionHandle,
        inputNamesUTF8Encoded,
        outputNamesUTF8Encoded,
        ioBindingState,
        enableGraphCapture,
        true,
      ]);
    }

    wasm.jsepOnRunStart?.(sessionHandle);
    let errorCode: number;
    if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) {
      errorCode = await wasm._OrtRunWithBinding(
        sessionHandle,
        ioBindingState.handle,
        outputCount,
        outputValuesOffset,
        runOptionsHandle,
      );
    } else {
      errorCode = await wasm._OrtRun(
        sessionHandle,
        inputNamesOffset,
        inputValuesOffset,
        inputCount,
        outputNamesOffset,
        outputCount,
        outputValuesOffset,
        runOptionsHandle,
      );
    }

    if (errorCode !== 0) {
      checkLastError('failed to call OrtRun().');
    }

    const output: TensorMetadata[] = [];

    for (let i = 0; i < outputCount; i++) {
      const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
      if (tensor === outputTensorHandles[i]) {
        // output tensor is pre-allocated. no need to copy data.
        output.push(outputTensors[i]!);
        continue;
      }

      const beforeGetTensorDataStack = wasm.stackSave();
      // stack allocate 4 pointer value
      const tensorDataOffset = wasm.stackAlloc(4 * 4);

      let keepOutputTensor = false;
      let type: Tensor.Type | undefined,
        dataOffset = 0;
      try {
        const errorCode = wasm._OrtGetTensorData(
          tensor,
          tensorDataOffset,
          tensorDataOffset + 4,
          tensorDataOffset + 8,
          tensorDataOffset + 12,
        );
        if (errorCode !== 0) {
          checkLastError(`Can't access output tensor data on index ${i}.`);
        }
        // unpack the four 32-bit values written by _OrtGetTensorData: type, data ptr, dims ptr, dims length.
        let tensorDataIndex = tensorDataOffset / 4;
        const dataType = wasm.HEAPU32[tensorDataIndex++];
        dataOffset = wasm.HEAPU32[tensorDataIndex++];
        const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
        const dimsLength = wasm.HEAPU32[tensorDataIndex++];
        const dims = [];
        for (let i = 0; i < dimsLength; i++) {
          dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
        }
        wasm._OrtFree(dimsOffset);

        const size = dims.reduce((a, b) => a * b, 1);
        type = tensorDataTypeEnumToString(dataType);

        const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]];

        if (type === 'string') {
          if (preferredLocation === 'gpu-buffer') {
            throw new Error('String tensor is not supported on GPU.');
          }
          // decode each string from its offset; the length of each element (except the last) is
          // derived from the distance to the next element's offset.
          const stringData: string[] = [];
          let dataIndex = dataOffset / 4;
          for (let i = 0; i < size; i++) {
            const offset = wasm.HEAPU32[dataIndex++];
            const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
            stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
          }
          output.push([type, dims, stringData, 'cpu']);
        } else {
          // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU
          // tensor for it. There is no mapping GPU buffer for an empty tensor.
          if (preferredLocation === 'gpu-buffer' && size > 0) {
            const getBuffer = wasm.jsepGetBuffer;
            if (!getBuffer) {
              throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.');
            }
            const gpuBuffer = getBuffer(dataOffset);
            const bufferSize = calculateTensorSizeInBytes(dataType, size);
            if (bufferSize === undefined || !isGpuBufferSupportedType(type)) {
              throw new Error(`Unsupported data type: ${type}`);
            }

            // do not release the tensor right now. it will be released when user calls tensor.dispose().
            keepOutputTensor = true;

            output.push([
              type,
              dims,
              {
                gpuBuffer,
                download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type),
                dispose: () => {
                  wasm._OrtReleaseTensor(tensor);
                },
              },
              'gpu-buffer',
            ]);
          } else {
            // CPU output: copy the bytes out of the WASM heap into a fresh typed array.
            const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
            const data = new typedArrayConstructor(size);
            new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
              wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength),
            );
            output.push([type, dims, data, 'cpu']);
          }
        }
      } finally {
        wasm.stackRestore(beforeGetTensorDataStack);
        // string tensor data was copied into JS strings above; free the native offset table.
        if (type === 'string' && dataOffset) {
          wasm._free(dataOffset);
        }
        if (!keepOutputTensor) {
          wasm._OrtReleaseTensor(tensor);
        }
      }
    }

    // when graph capture is off, unbind outputs so the next run starts clean and re-binds.
    if (ioBindingState && !enableGraphCapture) {
      wasm._OrtClearBoundOutputs(ioBindingState.handle);
      activeSessions.set(sessionId, [
        sessionHandle,
        inputNamesUTF8Encoded,
        outputNamesUTF8Encoded,
        ioBindingState,
        enableGraphCapture,
        false,
      ]);
    }
    return output;
  } finally {
    wasm.stackRestore(beforeRunStack);

    inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach((p) => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};
815

816
/**
 * end profiling for the given session.
 *
 * @param sessionId the session id returned by createSession().
 * @throws if the session id is unknown or the native call fails.
 */
export const endProfiling = (sessionId: number): void => {
  const wasm = getInstance();
  const session = activeSessions.get(sessionId);
  if (!session) {
    throw new Error('invalid session id');
  }
  const sessionHandle = session[0];

  // profile file name is not used yet, but it must be freed.
  const profileFileName = wasm._OrtEndProfiling(sessionHandle);
  if (profileFileName === 0) {
    // fixed grammar in the error message ("an profile" -> "a profile").
    checkLastError("Can't get a profile file name.");
  }
  wasm._OrtFree(profileFileName);
};
834

835
/**
 * collect the underlying ArrayBuffers of all typed-array tensor data, so they can be passed as
 * transferables to a worker. String tensors (data stored as a JS array) contribute nothing.
 */
export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => {
  const result: ArrayBufferLike[] = [];
  for (const [, , data] of tensors) {
    // typed-array data exposes its backing store via `.buffer`; plain arrays (string tensors) do not.
    if (!Array.isArray(data) && 'buffer' in data) {
      result.push(data.buffer);
    }
  }
  return result;
};
845

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.