onnxruntime

Форк
0
/
session-options.ts 
208 строк · 8.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { InferenceSession } from 'onnxruntime-common';
5

6
import { getInstance } from './wasm-factory';
7
import { allocWasmString, checkLastError, iterateExtraOptions } from './wasm-utils';
8

9
/**
 * Numeric codes for each supported graph optimization level name, as passed to
 * `_OrtCreateSessionOptions` by `setSessionOptions`.
 */
const graphOptimizationLevelCodes: ReadonlyMap<string, number> = new Map([
  ['disabled', 0],
  ['basic', 1],
  ['extended', 2],
  ['all', 99],
]);

/**
 * Translates a graph optimization level name into its numeric code.
 *
 * (The identifier's spelling is kept as-is because other code refers to it by this name.)
 *
 * @param graphOptimizationLevel - one of 'disabled' | 'basic' | 'extended' | 'all'; any other value
 *   (including non-strings) is rejected.
 * @returns 0, 1, 2 or 99 respectively.
 * @throws Error when the value is not one of the supported names.
 */
const getGraphOptimzationLevel = (graphOptimizationLevel: string | unknown): number => {
  // Only exact string matches are accepted; a Map avoids hitting Object.prototype keys.
  const levelCode =
    typeof graphOptimizationLevel === 'string' ? graphOptimizationLevelCodes.get(graphOptimizationLevel) : undefined;
  if (levelCode === undefined) {
    throw new Error(`unsupported graph optimization level: ${graphOptimizationLevel}`);
  }
  return levelCode;
};
23

24
/**
 * Translates an execution mode name into the numeric code passed to `_OrtCreateSessionOptions`.
 *
 * @param executionMode - 'sequential' (code 0) or 'parallel' (code 1).
 * @returns the numeric execution mode code.
 * @throws Error for any other value (possible when called from untyped JS).
 */
const getExecutionMode = (executionMode: 'sequential' | 'parallel'): number => {
  if (executionMode === 'sequential') {
    return 0;
  }
  if (executionMode === 'parallel') {
    return 1;
  }
  throw new Error(`unsupported execution mode: ${executionMode}`);
};
34

35
/**
 * Fills in web-specific defaults on `options`, mutating it in place.
 *
 * - Ensures `options.extra.session` exists.
 * - Sets the session config key `use_ort_model_bytes_directly` to '1' unless it already holds a
 *   truthy value.
 * - When any execution provider is 'webgpu' (given as a string or via its `name`), forces
 *   `enableMemPattern` to false (memory pattern is always disabled for JSEP with WebGPU).
 *
 * @param options - the session options object to update.
 */
const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => {
  if (!options.extra) {
    options.extra = {};
  }
  const extraOptions = options.extra;
  if (!extraOptions.session) {
    extraOptions.session = {};
  }
  const sessionEntries = extraOptions.session as Record<string, string>;

  // Only fill the default when the caller has not already supplied a value.
  if (!sessionEntries.use_ort_model_bytes_directly) {
    // eslint-disable-next-line camelcase
    sessionEntries.use_ort_model_bytes_directly = '1';
  }

  const providers = options.executionProviders;
  const requestsWebGpu =
    !!providers && providers.some((provider) => (typeof provider === 'string' ? provider : provider.name) === 'webgpu');
  if (requestsWebGpu) {
    options.enableMemPattern = false;
  }
};
56

57
/**
 * Registers each requested execution provider (EP) on a native session-options handle, in order.
 *
 * - 'wasm' / 'cpu': skipped; no native append call is made for them.
 * - 'webnn': appended under the native name 'WEBNN'; an option object's `deviceType`, when set, is
 *   written as the session config entry 'deviceType'. NOTE(review): no other WebNN option (e.g. an
 *   MLContext) is forwarded here.
 * - 'webgpu': appended under the native name 'JS'; an option object's `preferredLayout`, when set,
 *   must be 'NCHW' or 'NHWC' and is written as the session config entry 'preferredLayout'.
 * - anything else: throws.
 *
 * @param sessionOptionsHandle - handle returned by `_OrtCreateSessionOptions`.
 * @param executionProviders - EP names or EP option objects.
 * @param allocs - receives every WASM string allocation made here so the caller can free them.
 * @throws Error for an unsupported EP or invalid `preferredLayout`; native failures are reported via
 *   `checkLastError`.
 */
const setExecutionProviders = (
  sessionOptionsHandle: number,
  executionProviders: readonly InferenceSession.ExecutionProviderConfig[],
  allocs: number[],
): void => {
  // Allocates key and value strings, then adds them as one session config entry.
  const addConfigEntry = (key: string, value: string): void => {
    const keyOffset = allocWasmString(key, allocs);
    const valueOffset = allocWasmString(value, allocs);
    if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyOffset, valueOffset) !== 0) {
      checkLastError(`Can't set a session config entry: '${key}' - ${value}.`);
    }
  };

  for (const provider of executionProviders) {
    const requestedName = typeof provider === 'string' ? provider : provider.name;
    let nativeName: string;

    if (requestedName === 'wasm' || requestedName === 'cpu') {
      continue;
    } else if (requestedName === 'webnn') {
      nativeName = 'WEBNN';
      if (typeof provider !== 'string') {
        const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption;
        const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions).deviceType;
        if (deviceType) {
          addConfigEntry('deviceType', deviceType);
        }
      }
    } else if (requestedName === 'webgpu') {
      nativeName = 'JS';
      if (typeof provider !== 'string') {
        const { preferredLayout } = provider as InferenceSession.WebGpuExecutionProviderOption;
        if (preferredLayout) {
          if (preferredLayout !== 'NCHW' && preferredLayout !== 'NHWC') {
            throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${preferredLayout}`);
          }
          addConfigEntry('preferredLayout', preferredLayout);
        }
      }
    } else {
      throw new Error(`not supported execution provider: ${requestedName}`);
    }

    const nativeNameOffset = allocWasmString(nativeName, allocs);
    if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, nativeNameOffset) !== 0) {
      checkLastError(`Can't append execution provider: ${nativeName}.`);
    }
  }
};
111

112
/**
 * Creates a native ORT session-options handle from JS session options.
 *
 * Defaults applied when a field is absent: graph optimization level 'all', execution mode
 * 'sequential', log severity level 2 (warning), log verbosity level 0 (verbose). When `options` is
 * provided it is mutated in place by `appendDefaultOptions`.
 *
 * @param options - caller-supplied session options; `undefined` is treated as `{}`.
 * @returns `[sessionOptionsHandle, allocs]` — the native handle and every WASM string allocation made
 *   while building it. NOTE(review): presumably the caller frees `allocs` and releases the handle
 *   once it is no longer needed — confirm against the caller.
 * @throws Error when an option is invalid or a native call fails. Before rethrowing, the handle (if
 *   it was created) is released and every allocation made here is freed.
 */
export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => {
  const wasm = getInstance();
  let sessionOptionsHandle = 0;
  const allocs: number[] = [];

  const sessionOptions: InferenceSession.SessionOptions = options || {};
  appendDefaultOptions(sessionOptions);

  // Adds one key/value session config entry to the current handle. Both strings are tracked in
  // `allocs` so they are freed on failure or handed back to the caller on success.
  const addSessionConfigEntry = (key: string, value: string, errorMessage: string): void => {
    const keyDataOffset = allocWasmString(key, allocs);
    const valueDataOffset = allocWasmString(value, allocs);
    if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
      checkLastError(errorMessage);
    }
  };

  try {
    const graphOptimizationLevel = getGraphOptimzationLevel(sessionOptions.graphOptimizationLevel ?? 'all');
    const executionMode = getExecutionMode(sessionOptions.executionMode ?? 'sequential');
    const logIdDataOffset =
      typeof sessionOptions.logId === 'string' ? allocWasmString(sessionOptions.logId, allocs) : 0;

    const logSeverityLevel = sessionOptions.logSeverityLevel ?? 2; // Default to 2 - warning
    if (!Number.isInteger(logSeverityLevel) || logSeverityLevel < 0 || logSeverityLevel > 4) {
      throw new Error(`log severity level is not valid: ${logSeverityLevel}`);
    }

    const logVerbosityLevel = sessionOptions.logVerbosityLevel ?? 0; // Default to 0 - verbose
    if (!Number.isInteger(logVerbosityLevel) || logVerbosityLevel < 0 || logVerbosityLevel > 4) {
      throw new Error(`log verbosity level is not valid: ${logVerbosityLevel}`);
    }

    const optimizedModelFilePathOffset =
      typeof sessionOptions.optimizedModelFilePath === 'string'
        ? allocWasmString(sessionOptions.optimizedModelFilePath, allocs)
        : 0;

    sessionOptionsHandle = wasm._OrtCreateSessionOptions(
      graphOptimizationLevel,
      !!sessionOptions.enableCpuMemArena,
      !!sessionOptions.enableMemPattern,
      executionMode,
      !!sessionOptions.enableProfiling,
      0, // NOTE(review): meaning of this literal argument is not visible in this file — confirm.
      logIdDataOffset,
      logSeverityLevel,
      logVerbosityLevel,
      optimizedModelFilePathOffset,
    );
    if (sessionOptionsHandle === 0) {
      checkLastError("Can't create session options.");
    }

    if (sessionOptions.executionProviders) {
      setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs);
    }

    if (sessionOptions.enableGraphCapture !== undefined) {
      if (typeof sessionOptions.enableGraphCapture !== 'boolean') {
        throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`);
      }
      addSessionConfigEntry(
        'enableGraphCapture',
        sessionOptions.enableGraphCapture.toString(),
        `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`,
      );
    }

    if (sessionOptions.freeDimensionOverrides) {
      // Object.entries only ever yields string keys, so only the value needs validating.
      for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) {
        if (typeof value !== 'number' || !Number.isInteger(value) || value < 0) {
          throw new Error(`free dimension override value must be a non-negative integer: ${value}`);
        }
        const nameOffset = allocWasmString(name, allocs);
        if (wasm._OrtAddFreeDimensionOverride(sessionOptionsHandle, nameOffset, value) !== 0) {
          checkLastError(`Can't set a free dimension override: ${name} - ${value}.`);
        }
      }
    }

    if (sessionOptions.extra !== undefined) {
      iterateExtraOptions(sessionOptions.extra, '', new WeakSet<Record<string, unknown>>(), (key, value) => {
        addSessionConfigEntry(key, value, `Can't set a session config entry: ${key} - ${value}.`);
      });
    }

    return [sessionOptionsHandle, allocs];
  } catch (e) {
    try {
      if (sessionOptionsHandle !== 0) {
        wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
      }
    } finally {
      // Free the string allocations even if releasing the handle itself fails.
      allocs.forEach((alloc) => wasm._free(alloc));
    }
    throw e;
  }
};
209

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.