// onnxruntime — 844 lines · 29.6 KB (extraction metadata; original listing header)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
// the WebNN API specification.
// https://github.com/webmachinelearning/webnn/issues/677
/// <reference path="jsep/webnn/webnn.d.ts" />
9import { Env, InferenceSession, Tensor } from 'onnxruntime-common';10
11import {12SerializableInternalBuffer,13SerializableSessionMetadata,14SerializableTensorMetadata,15TensorMetadata,16} from './proxy-messages';17import { setRunOptions } from './run-options';18import { setSessionOptions } from './session-options';19import {20calculateTensorSizeInBytes,21dataLocationStringToEnum,22isGpuBufferSupportedType,23logLevelStringToEnum,24tensorDataTypeEnumToString,25tensorDataTypeStringToEnum,26tensorTypeToTypedArrayConstructor,27} from './wasm-common';28import { getInstance } from './wasm-factory';29import { allocWasmString, checkLastError } from './wasm-utils';30import { loadFile } from './wasm-utils-load-file';31
// #region Initializations

/**
 * There are 4 different "initialization" steps for ORT. They happen in different places and at different times.
 *
 * 1. JavaScript initialization for onnxruntime-common and onnxruntime-web.
 *    This is the first initialization step. In this step, onnxruntime-web calls onnxruntime-common's registerBackend()
 *    function multiple times to register all the available backends. The backend registration is very fast. It only
 *    registers the backend name with the uninitialized backend object. No heavy initialization is done in this step.
 *    Refer to web/lib/index.ts for the backend registration.
 *
 * 2. WebAssembly artifact initialization.
 *    This happens when any registered wasm backend is used for the first time (i.e. `ort.InferenceSession.create()` or
 *    `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the following:
 *    - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled.
 *    - perform feature detection, locate the correct WebAssembly artifact path and call the Emscripten-generated
 *      JavaScript code to initialize the WebAssembly runtime.
 *      - if proxy is enabled, this step happens in the proxy worker using message 'init-wasm'.
 *      - downloading the 'ort-wasm{...}.wasm' file is done in this step.
 *      - if multi-thread is enabled, one or more web workers will be created to initialize the PThread threadpool.
 *
 * 3. ORT environment initialization.
 *    This happens after step 2. In this step, onnxruntime-web performs ONNX Runtime environment initialization.
 *    Function `_OrtInit()` is called in this step.
 *    - if proxy is enabled, this step happens in the proxy worker using message 'init-ort'.
 *    - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step.
 *
 * 4. Session initialization.
 *    This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3
 *    steps (they are only called once), this step will be done for each session. In this step, onnxruntime-web does the
 *    following:
 *    If the parameter is a URL:
 *    - download the model data from the URL.
 *    - copy the model data to the WASM heap. (proxy: 'copy-from')
 *    - dereference the model buffer. This step allows the original ArrayBuffer to be garbage collected.
 *    - call `_OrtCreateSession()` to create the session. (proxy: 'create')
 *
 *    If the parameter is a Uint8Array object:
 *    - copy the model data to the WASM heap. (proxy: 'copy-from')
 *    - call `_OrtCreateSession()` to create the session. (proxy: 'create')
 */

76/**
77* initialize ORT environment.
78*
79* @param numThreads SetGlobalIntraOpNumThreads(numThreads)
80* @param loggingLevel CreateEnv(static_cast<OrtLoggingLevel>(logging_level))
81*/
82const initOrt = (numThreads: number, loggingLevel: number): void => {83const errorCode = getInstance()._OrtInit(numThreads, loggingLevel);84if (errorCode !== 0) {85checkLastError("Can't initialize onnxruntime.");86}87};88
89/**
90* initialize runtime environment.
91* @param env passed in the environment config object.
92*/
93export const initRuntime = async (env: Env): Promise<void> => {94// init ORT95initOrt(env.wasm.numThreads!, logLevelStringToEnum(env.logLevel));96};97
98/**
99* perform EP specific initialization.
100*
101* @param env
102* @param epName
103*/
104export const initEp = async (env: Env, epName: string): Promise<void> => {105if (!BUILD_DEFS.DISABLE_JSEP) {106// eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires107const initJsep = require('./jsep/init').init;108
109if (epName === 'webgpu') {110// perform WebGPU availability check111if (typeof navigator === 'undefined' || !navigator.gpu) {112throw new Error('WebGPU is not supported in current environment');113}114
115let adapter = env.webgpu.adapter as GPUAdapter | null;116if (!adapter) {117// if adapter is not set, request a new adapter.118const powerPreference = env.webgpu.powerPreference;119if (120powerPreference !== undefined &&121powerPreference !== 'low-power' &&122powerPreference !== 'high-performance'123) {124throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);125}126const forceFallbackAdapter = env.webgpu.forceFallbackAdapter;127if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {128throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);129}130adapter = await navigator.gpu.requestAdapter({ powerPreference, forceFallbackAdapter });131if (!adapter) {132throw new Error(133'Failed to get GPU adapter. ' +134'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.',135);136}137} else {138// if adapter is set, validate it.139if (140typeof adapter.limits !== 'object' ||141typeof adapter.features !== 'object' ||142typeof adapter.requestDevice !== 'function'143) {144throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.');145}146}147
148await initJsep('webgpu', getInstance(), env, adapter);149}150if (epName === 'webnn') {151// perform WebNN availability check152if (typeof navigator === 'undefined' || !(navigator as unknown as { ml: unknown }).ml) {153throw new Error('WebNN is not supported in current environment');154}155
156await initJsep('webnn', getInstance(), env);157}158}159};160
// #endregion Initializations

/**
 * valid data locations for input/output tensors.
 */
type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer';

type IOBindingState = {
  /**
   * the handle of IO binding.
   */
  readonly handle: number;

  /**
   * the preferred location for each output tensor.
   *
   * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'.
   */
  readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[];

  /**
   * enum value of the preferred location for each output tensor.
   *
   * parallel to `outputPreferredLocations`; values produced by dataLocationStringToEnum().
   */
  readonly outputPreferredLocationsEncoded: readonly number[];
};

/**
 * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState;
 * enableGraphCapture; inputOutputBound
 */
type SessionMetadata = [
  inferenceSessionId: number,
  inputNamesUTF8Encoded: number[],
  outputNamesUTF8Encoded: number[],
  bindingState: IOBindingState | null,
  enableGraphCapture: boolean,
  inputOutputBound: boolean,
];

// all live sessions, keyed by session handle. Entries are added by createSession() and removed by releaseSession().
const activeSessions = new Map<number, SessionMetadata>();
201/**
202* get the input/output count of the session.
203* @param sessionHandle the handle representing the session. should be non-zero.
204* @returns a tuple including 2 numbers, representing the input count and output count.
205*/
206const getSessionInputOutputCount = (sessionHandle: number): [number, number] => {207const wasm = getInstance();208const stack = wasm.stackSave();209try {210const dataOffset = wasm.stackAlloc(8);211const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4);212if (errorCode !== 0) {213checkLastError("Can't get session input/output count.");214}215return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];216} finally {217wasm.stackRestore(stack);218}219};220
221/**
222* allocate the memory and memcpy the external buffer.
223*
224* @param model - the external buffer containing the model data. Must not be the same buffer as the WASM heap.
225* @returns a 2-elements tuple - the pointer and size of the allocated buffer
226*/
227export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => {228const wasm = getInstance();229const modelDataOffset = wasm._malloc(model.byteLength);230if (modelDataOffset === 0) {231throw new Error(`Can't create a session. failed to allocate a buffer of size ${model.byteLength}.`);232}233wasm.HEAPU8.set(model, modelDataOffset);234return [modelDataOffset, model.byteLength];235};236
/**
 * create an inference session from a model data buffer.
 *
 * @param modelData - either a Uint8Array object representing the model data, or a 2-elements tuple containing the
 * pointer and size of the model data buffer (i.e. a buffer already copied onto the WASM heap).
 * @param options an optional session options object.
 * @returns a 3-elements tuple containing [session handle, input names, output names]
 */
export const createSession = async (
  modelData: Uint8Array | SerializableInternalBuffer,
  options?: InferenceSession.SessionOptions,
): Promise<SerializableSessionMetadata> => {
  let modelDataOffset: number, modelDataLength: number;
  const wasm = getInstance();

  if (Array.isArray(modelData)) {
    // if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
    [modelDataOffset, modelDataLength] = modelData;
  } else if (modelData.buffer === wasm.HEAPU8.buffer) {
    // if model data uses the same buffer as the WASM heap, we don't need to copy it.
    [modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
  } else {
    // otherwise, copy the model data to the WASM heap.
    [modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
  }

  let sessionHandle = 0;
  let sessionOptionsHandle = 0;
  let ioBindingHandle = 0;
  let allocs: number[] = [];
  const inputNamesUTF8Encoded = [];
  const outputNamesUTF8Encoded = [];

  try {
    [sessionOptionsHandle, allocs] = setSessionOptions(options);

    // mount external data files (model weights stored outside the .onnx file) before creating the session.
    if (options?.externalData && wasm.mountExternalData) {
      const loadingPromises = [];
      for (const file of options.externalData) {
        const path = typeof file === 'string' ? file : file.path;
        loadingPromises.push(
          loadFile(typeof file === 'string' ? file : file.data).then((data) => {
            wasm.mountExternalData!(path, data);
          }),
        );
      }

      // wait for all external data files to be loaded
      await Promise.all(loadingPromises);
    }

    // if the WebNN EP is requested, create (or adopt) an MLContext and stash it on the wasm module so the native
    // side can pick it up during _OrtCreateSession. Only one WebNN context may be pending at a time.
    for (const provider of options?.executionProviders ?? []) {
      const providerName = typeof provider === 'string' ? provider : provider.name;
      if (providerName === 'webnn') {
        if (wasm.currentContext) {
          throw new Error('WebNN execution provider is already set.');
        }
        if (typeof provider !== 'string') {
          const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption;
          const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
          const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice;
          const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
          const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
          const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
          if (context) {
            // user-provided MLContext takes precedence over context-creation options.
            wasm.currentContext = context as MLContext;
          } else if (gpuDevice) {
            wasm.currentContext = await navigator.ml.createContext(gpuDevice);
          } else {
            wasm.currentContext = await navigator.ml.createContext({ deviceType, numThreads, powerPreference });
          }
        } else {
          wasm.currentContext = await navigator.ml.createContext();
        }
        break;
      }
    }

    sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
    if (sessionHandle === 0) {
      checkLastError("Can't create a session.");
    }

    // clear current MLContext after session creation
    if (wasm.currentContext) {
      wasm.currentContext = undefined;
    }

    const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);

    const enableGraphCapture = !!options?.enableGraphCapture;

    // collect input/output names. The UTF-8 encoded name pointers are owned by us and must be freed on failure
    // (catch block below) or at releaseSession().
    const inputNames = [];
    const outputNames = [];
    const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
    for (let i = 0; i < inputCount; i++) {
      const name = wasm._OrtGetInputName(sessionHandle, i);
      if (name === 0) {
        checkLastError("Can't get an input name.");
      }
      inputNamesUTF8Encoded.push(name);
      inputNames.push(wasm.UTF8ToString(name));
    }
    for (let i = 0; i < outputCount; i++) {
      const name = wasm._OrtGetOutputName(sessionHandle, i);
      if (name === 0) {
        checkLastError("Can't get an output name.");
      }
      outputNamesUTF8Encoded.push(name);
      const nameString = wasm.UTF8ToString(name);
      outputNames.push(nameString);

      // resolve the preferred data location of each output (JSEP builds only).
      if (!BUILD_DEFS.DISABLE_JSEP) {
        if (enableGraphCapture && options?.preferredOutputLocation === undefined) {
          // graph capture defaults every output to GPU when no explicit preference is given.
          outputPreferredLocations.push('gpu-buffer');
          continue;
        }
        const location =
          typeof options?.preferredOutputLocation === 'string'
            ? options.preferredOutputLocation
            : (options?.preferredOutputLocation?.[nameString] ?? 'cpu');
        if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
          throw new Error(`Not supported preferred output location: ${location}.`);
        }
        if (enableGraphCapture && location !== 'gpu-buffer') {
          throw new Error(
            `Not supported preferred output location: ${location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`,
          );
        }
        outputPreferredLocations.push(location);
      }
    }

    // use IO binding only when at least one output is preferred to be on GPU.
    let bindingState: IOBindingState | null = null;
    if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer')) {
      ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
      if (ioBindingHandle === 0) {
        checkLastError("Can't create IO binding.");
      }

      bindingState = {
        handle: ioBindingHandle,
        outputPreferredLocations,
        outputPreferredLocationsEncoded: outputPreferredLocations.map((l) => dataLocationStringToEnum(l)),
      };
    }

    activeSessions.set(sessionHandle, [
      sessionHandle,
      inputNamesUTF8Encoded,
      outputNamesUTF8Encoded,
      bindingState,
      enableGraphCapture,
      false,
    ]);
    return [sessionHandle, inputNames, outputNames];
  } catch (e) {
    // release everything acquired before the failure, in reverse order of acquisition.
    inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
    outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));

    if (ioBindingHandle !== 0) {
      wasm._OrtReleaseBinding(ioBindingHandle);
    }

    if (sessionHandle !== 0) {
      wasm._OrtReleaseSession(sessionHandle);
    }
    throw e;
  } finally {
    // NOTE(review): modelDataOffset is freed unconditionally, including the branch where modelData already lived on
    // the WASM heap (byteOffset of the caller's view) — presumably callers always pass a _malloc'd region there;
    // confirm against callers.
    wasm._free(modelDataOffset);
    if (sessionOptionsHandle !== 0) {
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
    }
    allocs.forEach((alloc) => wasm._free(alloc));

    // unmount external data if necessary
    wasm.unmountExternalData?.();
  }
};
418export const releaseSession = (sessionId: number): void => {419const wasm = getInstance();420const session = activeSessions.get(sessionId);421if (!session) {422throw new Error(`cannot release session. invalid session id: ${sessionId}`);423}424const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session;425
426if (ioBindingState) {427if (enableGraphCapture) {428wasm._OrtClearBoundOutputs(ioBindingState.handle);429}430wasm._OrtReleaseBinding(ioBindingState.handle);431}432
433wasm.jsepOnReleaseSession?.(sessionId);434
435inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));436outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));437wasm._OrtReleaseSession(sessionHandle);438activeSessions.delete(sessionId);439};440
441export const prepareInputOutputTensor = (442tensor: TensorMetadata | null,443tensorHandles: number[],444allocs: number[],445sessionId: number,446index: number,447enableGraphCapture = false,448): void => {449if (!tensor) {450tensorHandles.push(0);451return;452}453
454const wasm = getInstance();455
456const dataType = tensor[0];457const dims = tensor[1];458const location = tensor[3];459
460let rawData: number;461let dataByteLength: number;462
463if (dataType === 'string' && location === 'gpu-buffer') {464throw new Error('String tensor is not supported on GPU.');465}466
467if (enableGraphCapture && location !== 'gpu-buffer') {468throw new Error(469`External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`,470);471}472
473if (location === 'gpu-buffer') {474const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;475dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;476
477const registerBuffer = wasm.jsepRegisterBuffer;478if (!registerBuffer) {479throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');480}481rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength);482} else {483const data = tensor[2];484
485if (Array.isArray(data)) {486// string tensor487dataByteLength = 4 * data.length;488rawData = wasm._malloc(dataByteLength);489allocs.push(rawData);490let dataIndex = rawData / 4;491for (let i = 0; i < data.length; i++) {492if (typeof data[i] !== 'string') {493throw new TypeError(`tensor data at index ${i} is not a string`);494}495wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);496}497} else {498dataByteLength = data.byteLength;499rawData = wasm._malloc(dataByteLength);500allocs.push(rawData);501wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);502}503}504
505const stack = wasm.stackSave();506const dimsOffset = wasm.stackAlloc(4 * dims.length);507try {508let dimIndex = dimsOffset / 4;509dims.forEach((d) => (wasm.HEAP32[dimIndex++] = d));510const tensor = wasm._OrtCreateTensor(511tensorDataTypeStringToEnum(dataType),512rawData,513dataByteLength,514dimsOffset,515dims.length,516dataLocationStringToEnum(location),517);518if (tensor === 0) {519checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);520}521tensorHandles.push(tensor);522} finally {523wasm.stackRestore(stack);524}525};526
/**
 * perform inference run
 *
 * @param sessionId the session handle returned by createSession().
 * @param inputIndices model input indices corresponding to each entry of `inputTensors`.
 * @param inputTensors input tensor metadata, parallel to `inputIndices`.
 * @param outputIndices model output indices corresponding to each entry of `outputTensors`.
 * @param outputTensors pre-allocated output tensor metadata, or null entries for outputs to be allocated by ORT.
 * @param options run options.
 * @returns metadata for every requested output, either the pre-allocated tensor or a newly read-back one.
 */
export const run = async (
  sessionId: number,
  inputIndices: number[],
  inputTensors: TensorMetadata[],
  outputIndices: number[],
  outputTensors: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  const wasm = getInstance();
  const session = activeSessions.get(sessionId);
  if (!session) {
    throw new Error(`cannot run inference. invalid session id: ${sessionId}`);
  }
  const sessionHandle = session[0];
  const inputNamesUTF8Encoded = session[1];
  const outputNamesUTF8Encoded = session[2];
  const ioBindingState = session[3];
  const enableGraphCapture = session[4];
  const inputOutputBound = session[5];

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  // stack-allocate the 4 pointer arrays (names/values for inputs/outputs) passed to _OrtRun.
  const beforeRunStack = wasm.stackSave();
  const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
  const inputNamesOffset = wasm.stackAlloc(inputCount * 4);
  const outputValuesOffset = wasm.stackAlloc(outputCount * 4);
  const outputNamesOffset = wasm.stackAlloc(outputCount * 4);

  try {
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // create input tensors
    for (let i = 0; i < inputCount; i++) {
      prepareInputOutputTensor(
        inputTensors[i],
        inputTensorHandles,
        inputOutputAllocs,
        sessionId,
        inputIndices[i],
        enableGraphCapture,
      );
    }

    // create output tensors
    for (let i = 0; i < outputCount; i++) {
      prepareInputOutputTensor(
        outputTensors[i],
        outputTensorHandles,
        inputOutputAllocs,
        sessionId,
        // output indices are offset by inputCount so GPU buffer registration keys don't collide with inputs.
        inputCount + outputIndices[i],
        enableGraphCapture,
      );
    }

    // fill the stack-allocated pointer arrays with tensor handles and name pointers.
    let inputValuesIndex = inputValuesOffset / 4;
    let inputNamesIndex = inputNamesOffset / 4;
    let outputValuesIndex = outputValuesOffset / 4;
    let outputNamesIndex = outputNamesOffset / 4;
    for (let i = 0; i < inputCount; i++) {
      wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i];
      wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]];
    }
    for (let i = 0; i < outputCount; i++) {
      wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i];
      wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
    }

    // IO-binding path: bind inputs/outputs once per session (re-bound only after a non-graph-capture run resets
    // the inputOutputBound flag at the bottom of this function).
    if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) {
      const { handle, outputPreferredLocations, outputPreferredLocationsEncoded } = ioBindingState;

      if (inputNamesUTF8Encoded.length !== inputCount) {
        throw new Error(
          `input count from feeds (${inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`,
        );
      }

      // process inputs
      for (let i = 0; i < inputCount; i++) {
        const index = inputIndices[i];
        const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]);
        if (errorCode !== 0) {
          checkLastError(`Can't bind input[${i}] for session=${sessionId}.`);
        }
      }

      // process pre-allocated outputs
      for (let i = 0; i < outputCount; i++) {
        const index = outputIndices[i];
        const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated.

        if (location) {
          // output is pre-allocated. bind the tensor.
          const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0);
          if (errorCode !== 0) {
            checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`);
          }
        } else {
          // output is not pre-allocated. reset preferred location.
          const errorCode = wasm._OrtBindOutput(
            handle,
            outputNamesUTF8Encoded[index],
            0,
            outputPreferredLocationsEncoded[index],
          );
          if (errorCode !== 0) {
            checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`);
          }
        }
      }
      // remember that IO is bound so subsequent runs skip re-binding.
      activeSessions.set(sessionId, [
        sessionHandle,
        inputNamesUTF8Encoded,
        outputNamesUTF8Encoded,
        ioBindingState,
        enableGraphCapture,
        true,
      ]);
    }

    wasm.jsepOnRunStart?.(sessionHandle);
    let errorCode: number;
    if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) {
      errorCode = await wasm._OrtRunWithBinding(
        sessionHandle,
        ioBindingState.handle,
        outputCount,
        outputValuesOffset,
        runOptionsHandle,
      );
    } else {
      errorCode = await wasm._OrtRun(
        sessionHandle,
        inputNamesOffset,
        inputValuesOffset,
        inputCount,
        outputNamesOffset,
        outputCount,
        outputValuesOffset,
        runOptionsHandle,
      );
    }

    if (errorCode !== 0) {
      checkLastError('failed to call OrtRun().');
    }

    const output: TensorMetadata[] = [];

    // read back each output tensor's type/data/dims from the native side.
    for (let i = 0; i < outputCount; i++) {
      const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
      if (tensor === outputTensorHandles[i]) {
        // output tensor is pre-allocated. no need to copy data.
        output.push(outputTensors[i]!);
        continue;
      }

      const beforeGetTensorDataStack = wasm.stackSave();
      // stack allocate 4 pointer value
      const tensorDataOffset = wasm.stackAlloc(4 * 4);

      let keepOutputTensor = false;
      let type: Tensor.Type | undefined,
        dataOffset = 0;
      try {
        const errorCode = wasm._OrtGetTensorData(
          tensor,
          tensorDataOffset,
          tensorDataOffset + 4,
          tensorDataOffset + 8,
          tensorDataOffset + 12,
        );
        if (errorCode !== 0) {
          checkLastError(`Can't access output tensor data on index ${i}.`);
        }
        let tensorDataIndex = tensorDataOffset / 4;
        const dataType = wasm.HEAPU32[tensorDataIndex++];
        dataOffset = wasm.HEAPU32[tensorDataIndex++];
        const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
        const dimsLength = wasm.HEAPU32[tensorDataIndex++];
        const dims = [];
        // NOTE(review): the loop variables `i` below shadow the outer output-loop `i`; intentional but easy to misread.
        for (let i = 0; i < dimsLength; i++) {
          dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
        }
        wasm._OrtFree(dimsOffset);

        const size = dims.reduce((a, b) => a * b, 1);
        type = tensorDataTypeEnumToString(dataType);

        const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]];

        if (type === 'string') {
          if (preferredLocation === 'gpu-buffer') {
            throw new Error('String tensor is not supported on GPU.');
          }
          // decode the string table: an array of UTF-8 pointers; each string's length is derived from the
          // next pointer, except the last one which is read to its NUL terminator.
          const stringData: string[] = [];
          let dataIndex = dataOffset / 4;
          for (let i = 0; i < size; i++) {
            const offset = wasm.HEAPU32[dataIndex++];
            const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
            stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
          }
          output.push([type, dims, stringData, 'cpu']);
        } else {
          // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU
          // tensor for it. There is no mapping GPU buffer for an empty tensor.
          if (preferredLocation === 'gpu-buffer' && size > 0) {
            const getBuffer = wasm.jsepGetBuffer;
            if (!getBuffer) {
              throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.');
            }
            const gpuBuffer = getBuffer(dataOffset);
            const bufferSize = calculateTensorSizeInBytes(dataType, size);
            if (bufferSize === undefined || !isGpuBufferSupportedType(type)) {
              throw new Error(`Unsupported data type: ${type}`);
            }

            // do not release the tensor right now. it will be released when user calls tensor.dispose().
            keepOutputTensor = true;

            output.push([
              type,
              dims,
              {
                gpuBuffer,
                download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type),
                dispose: () => {
                  wasm._OrtReleaseTensor(tensor);
                },
              },
              'gpu-buffer',
            ]);
          } else {
            // CPU output: copy the bytes out of the WASM heap into a fresh typed array.
            const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
            const data = new typedArrayConstructor(size);
            new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
              wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength),
            );
            output.push([type, dims, data, 'cpu']);
          }
        }
      } finally {
        wasm.stackRestore(beforeGetTensorDataStack);
        // string tensor data was malloc'd by the native side; free it after decoding.
        if (type === 'string' && dataOffset) {
          wasm._free(dataOffset);
        }
        if (!keepOutputTensor) {
          wasm._OrtReleaseTensor(tensor);
        }
      }
    }

    // without graph capture, bound outputs are cleared after each run and IO must be re-bound next time.
    if (ioBindingState && !enableGraphCapture) {
      wasm._OrtClearBoundOutputs(ioBindingState.handle);
      activeSessions.set(sessionId, [
        sessionHandle,
        inputNamesUTF8Encoded,
        outputNamesUTF8Encoded,
        ioBindingState,
        enableGraphCapture,
        false,
      ]);
    }
    return output;
  } finally {
    wasm.stackRestore(beforeRunStack);

    inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach((p) => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};
816/**
817* end profiling
818*/
819export const endProfiling = (sessionId: number): void => {820const wasm = getInstance();821const session = activeSessions.get(sessionId);822if (!session) {823throw new Error('invalid session id');824}825const sessionHandle = session[0];826
827// profile file name is not used yet, but it must be freed.828const profileFileName = wasm._OrtEndProfiling(sessionHandle);829if (profileFileName === 0) {830checkLastError("Can't get an profile file name.");831}832wasm._OrtFree(profileFileName);833};834
835export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => {836const buffers: ArrayBufferLike[] = [];837for (const tensor of tensors) {838const data = tensor[2];839if (!Array.isArray(data) && 'buffer' in data) {840buffers.push(data.buffer);841}842}843return buffers;844};845