onnxruntime
208 строк · 8.0 Кб
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4import { InferenceSession } from 'onnxruntime-common';5
6import { getInstance } from './wasm-factory';7import { allocWasmString, checkLastError, iterateExtraOptions } from './wasm-utils';8
/** Native graph optimization level for each accepted `graphOptimizationLevel` option string. */
const GRAPH_OPTIMIZATION_LEVELS: ReadonlyMap<string, number> = new Map([
  ['disabled', 0],
  ['basic', 1],
  ['extended', 2],
  ['all', 99],
]);

/**
 * Translate the `graphOptimizationLevel` session option into the integer handed to `_OrtCreateSessionOptions`.
 * (The misspelled name is kept as-is because callers refer to it.)
 *
 * @param graphOptimizationLevel - one of 'disabled', 'basic', 'extended' or 'all'.
 * @returns 0, 1, 2 or 99 respectively.
 * @throws Error for any other value, including non-string values.
 */
const getGraphOptimzationLevel = (graphOptimizationLevel: string | unknown): number => {
  const level =
    typeof graphOptimizationLevel === 'string' ? GRAPH_OPTIMIZATION_LEVELS.get(graphOptimizationLevel) : undefined;
  if (level === undefined) {
    throw new Error(`unsupported graph optimization level: ${graphOptimizationLevel}`);
  }
  return level;
};
/**
 * Translate the `executionMode` session option into the integer handed to `_OrtCreateSessionOptions`.
 *
 * @param executionMode - 'sequential' or 'parallel'.
 * @returns 0 for 'sequential', 1 for 'parallel'.
 * @throws Error for any other runtime value (possible when callers bypass the type).
 */
const getExecutionMode = (executionMode: 'sequential' | 'parallel'): number => {
  if (executionMode === 'sequential') {
    return 0;
  }
  if (executionMode === 'parallel') {
    return 1;
  }
  throw new Error(`unsupported execution mode: ${executionMode}`);
};
/**
 * Write backend defaults directly into `options` (the object passed in is mutated).
 *
 * - `extra` and `extra.session` are created when missing, and
 *   `extra.session.use_ort_model_bytes_directly` is set to '1' unless it already holds a truthy value.
 * - If any execution provider is named 'webgpu', `enableMemPattern` is forced to `false`.
 */
const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => {
  if (!options.extra) {
    options.extra = {};
  }
  const extra = options.extra;

  if (!extra.session) {
    extra.session = {};
  }
  const sessionEntries = extra.session as Record<string, string>;
  if (!sessionEntries.use_ort_model_bytes_directly) {
    // eslint-disable-next-line camelcase
    sessionEntries.use_ort_model_bytes_directly = '1';
  }

  // JSEP with WebGPU: the memory pattern option is always turned off.
  const providers = options.executionProviders;
  const webGpuRequested =
    !!providers && providers.some((provider) => (typeof provider === 'string' ? provider : provider.name) === 'webgpu');
  if (webGpuRequested) {
    options.enableMemPattern = false;
  }
};
/**
 * Append each requested execution provider to the native session options, in order.
 *
 * - 'wasm' and 'cpu' are skipped: no append call is made for them.
 * - 'webnn' is appended under the native name 'WEBNN'; an optional `deviceType` is forwarded first as a
 *   session config entry.
 * - 'webgpu' is appended under the native name 'JS'; an optional `preferredLayout` ('NCHW' or 'NHWC') is
 *   validated and forwarded first as a session config entry.
 * - Any other provider name throws.
 *
 * @param sessionOptionsHandle - handle produced by `_OrtCreateSessionOptions`.
 * @param executionProviders - provider names or provider option objects.
 * @param allocs - receives the offset of every WASM string allocated here, so the caller can free them.
 * @throws Error for an unsupported provider or an invalid `preferredLayout`. Native failures are reported via
 *   `checkLastError` (presumably throws — confirm in wasm-utils).
 */
const setExecutionProviders = (
  sessionOptionsHandle: number,
  executionProviders: readonly InferenceSession.ExecutionProviderConfig[],
  allocs: number[],
): void => {
  // Forward one key/value pair to the native options as a session config entry.
  const addConfigEntry = (key: string, value: string): void => {
    const keyOffset = allocWasmString(key, allocs);
    const valueOffset = allocWasmString(value, allocs);
    if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyOffset, valueOffset) !== 0) {
      checkLastError(`Can't set a session config entry: '${key}' - ${value}.`);
    }
  };

  for (const provider of executionProviders) {
    const requestedName = typeof provider === 'string' ? provider : provider.name;
    let nativeName: string;

    if (requestedName === 'wasm' || requestedName === 'cpu') {
      // Nothing to append for these names.
      continue;
    } else if (requestedName === 'webnn') {
      nativeName = 'WEBNN';
      if (typeof provider !== 'string') {
        const webnnProvider = provider as InferenceSession.WebNNExecutionProviderOption;
        const deviceType = (webnnProvider as InferenceSession.WebNNContextOptions)?.deviceType;
        if (deviceType) {
          addConfigEntry('deviceType', deviceType);
        }
      }
    } else if (requestedName === 'webgpu') {
      nativeName = 'JS';
      if (typeof provider !== 'string') {
        const layout = (provider as InferenceSession.WebGpuExecutionProviderOption)?.preferredLayout;
        if (layout) {
          if (layout !== 'NCHW' && layout !== 'NHWC') {
            throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${layout}`);
          }
          addConfigEntry('preferredLayout', layout);
        }
      }
    } else {
      throw new Error(`not supported execution provider: ${requestedName}`);
    }

    const nativeNameOffset = allocWasmString(nativeName, allocs);
    if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, nativeNameOffset) !== 0) {
      checkLastError(`Can't append execution provider: ${nativeName}.`);
    }
  }
};
/**
 * Build a native (WebAssembly-side) session options object from the given JavaScript session options.
 *
 * Defaults are applied to a private copy of `options` (via `appendDefaultOptions`). The caller's object, including
 * its `extra` and `extra.session` objects, is left untouched and can safely be reused for another session.
 *
 * @param options - user-supplied session options; `undefined` is treated as an empty options object.
 * @returns `[sessionOptionsHandle, allocs]`: the handle returned by `_OrtCreateSessionOptions`, and the offsets of
 *   every WASM string allocated while populating it. On success neither is released here.
 * @throws Error when an option value is invalid. Native failures are reported through `checkLastError`
 *   (presumably throws — confirm in wasm-utils). On any throw, the handle (if created) is released and every
 *   allocated string is freed before rethrowing.
 */
export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => {
  const wasm = getInstance();
  let sessionOptionsHandle = 0;
  const allocs: number[] = [];

  // appendDefaultOptions() mutates its argument (it writes extra.session entries and enableMemPattern). Hand it a
  // copy of the top level and of the `extra` / `extra.session` objects it may write into, so those defaults never
  // leak back into the caller's object (e.g. `enableMemPattern = false` sticking around after a WebGPU session).
  const sessionOptions: InferenceSession.SessionOptions = { ...options };
  if (options?.extra) {
    const extra: Record<string, unknown> = { ...options.extra };
    if (extra.session && typeof extra.session === 'object') {
      extra.session = { ...(extra.session as Record<string, unknown>) };
    }
    sessionOptions.extra = extra;
  }
  appendDefaultOptions(sessionOptions);

  try {
    const graphOptimizationLevel = getGraphOptimzationLevel(sessionOptions.graphOptimizationLevel ?? 'all');
    const executionMode = getExecutionMode(sessionOptions.executionMode ?? 'sequential');
    const logIdDataOffset =
      typeof sessionOptions.logId === 'string' ? allocWasmString(sessionOptions.logId, allocs) : 0;

    const logSeverityLevel = sessionOptions.logSeverityLevel ?? 2; // Default to 2 - warning
    if (!Number.isInteger(logSeverityLevel) || logSeverityLevel < 0 || logSeverityLevel > 4) {
      throw new Error(`log severity level is not valid: ${logSeverityLevel}`);
    }

    const logVerbosityLevel = sessionOptions.logVerbosityLevel ?? 0; // Default to 0 - verbose
    if (!Number.isInteger(logVerbosityLevel) || logVerbosityLevel < 0 || logVerbosityLevel > 4) {
      throw new Error(`log verbosity level is not valid: ${logVerbosityLevel}`);
    }

    const optimizedModelFilePathOffset =
      typeof sessionOptions.optimizedModelFilePath === 'string'
        ? allocWasmString(sessionOptions.optimizedModelFilePath, allocs)
        : 0;

    sessionOptionsHandle = wasm._OrtCreateSessionOptions(
      graphOptimizationLevel,
      !!sessionOptions.enableCpuMemArena,
      !!sessionOptions.enableMemPattern,
      executionMode,
      !!sessionOptions.enableProfiling,
      0,
      logIdDataOffset,
      logSeverityLevel,
      logVerbosityLevel,
      optimizedModelFilePathOffset,
    );
    // A zero handle signals that native creation failed.
    if (sessionOptionsHandle === 0) {
      checkLastError("Can't create session options.");
    }

    if (sessionOptions.executionProviders) {
      setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs);
    }

    if (sessionOptions.enableGraphCapture !== undefined) {
      if (typeof sessionOptions.enableGraphCapture !== 'boolean') {
        throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`);
      }
      const keyDataOffset = allocWasmString('enableGraphCapture', allocs);
      const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs);
      if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
        checkLastError(
          `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`,
        );
      }
    }

    if (sessionOptions.freeDimensionOverrides) {
      // Object.entries() always yields string keys, so only the value needs validating.
      for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) {
        if (typeof value !== 'number' || !Number.isInteger(value) || value < 0) {
          throw new Error(`free dimension override value must be a non-negative integer: ${value}`);
        }
        const nameOffset = allocWasmString(name, allocs);
        if (wasm._OrtAddFreeDimensionOverride(sessionOptionsHandle, nameOffset, value) !== 0) {
          checkLastError(`Can't set a free dimension override: ${name} - ${value}.`);
        }
      }
    }

    if (sessionOptions.extra !== undefined) {
      // Every flattened key/value pair reported by iterateExtraOptions becomes a session config entry.
      iterateExtraOptions(sessionOptions.extra, '', new WeakSet<Record<string, unknown>>(), (key, value) => {
        const keyDataOffset = allocWasmString(key, allocs);
        const valueDataOffset = allocWasmString(value, allocs);

        if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
          checkLastError(`Can't set a session config entry: ${key} - ${value}.`);
        }
      });
    }

    return [sessionOptionsHandle, allocs];
  } catch (e) {
    // Undo everything done so far before propagating the error.
    if (sessionOptionsHandle !== 0) {
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
    }
    allocs.forEach((alloc) => wasm._free(alloc));
    throw e;
  }
};