// onnxruntime
// 631 lines · 20.2 KB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

4import { InferenceSession, Tensor } from 'onnxruntime-common';
5
6import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
7import { setRunOptions } from './run-options';
8import { setSessionOptions } from './session-options';
9import {
10dataLocationStringToEnum,
11tensorDataTypeEnumToString,
12tensorDataTypeStringToEnum,
13tensorTypeToTypedArrayConstructor,
14} from './wasm-common';
15import { prepareInputOutputTensor } from './wasm-core-impl';
16import { getInstance } from './wasm-factory';
17import { checkLastError } from './wasm-utils';
18
// Error thrown by every training entry point when the loaded WASM binary does not export the
// training functions (_OrtTraining*), i.e. it was built without training support.
const NO_TRAIN_FUNCS_MSG =
  "Built without training API's enabled. Use the onnxruntime-web/training import for training " +
  'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
  'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';
23
24/**
25* Runs the checkLastError function which will throw an error, if the provided error code matches the specified
26* pattern for an error code.
27* @param errCode number to evaluated for if it's an error
28* @param message message to pass into checkLastError
29* @param checkNeqZero when true, treats not equal to zero as an error.
30* When false, treats equal to zero as an error.
31*/
32const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => {
33if (checkNeqZero && errCode !== 0) {
34checkLastError(message);
35} else if (!checkNeqZero && errCode === 0) {
36checkLastError(message);
37}
38};
39
40export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => {
41const wasm = getInstance();
42
43const [checkpointDataOffset, checkpointDataLength] = checkpointData;
44let checkpointHandle = 0;
45
46try {
47if (wasm._OrtTrainingLoadCheckpoint) {
48checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
49} else {
50throw new Error(NO_TRAIN_FUNCS_MSG);
51}
52
53ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);
54return checkpointHandle;
55} catch (e) {
56if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
57wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
58}
59throw e;
60} finally {
61// free buffer from wasm heap
62wasm._OrtFree(checkpointData[0]);
63}
64};
65
66const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
67const wasm = getInstance();
68const stack = wasm.stackSave();
69try {
70const dataOffset = wasm.stackAlloc(8);
71if (wasm._OrtTrainingGetModelInputOutputCount) {
72const errorCode = wasm._OrtTrainingGetModelInputOutputCount(
73trainingSessionId,
74dataOffset,
75dataOffset + 4,
76isEvalModel,
77);
78ifErrCodeCheckLastError(errorCode, "Can't get session input/output count.");
79return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
80} else {
81throw new Error(NO_TRAIN_FUNCS_MSG);
82}
83} finally {
84wasm.stackRestore(stack);
85}
86};
87
88const getModelInputOutputNamesLoop = (
89trainingSessionId: number,
90count: number,
91isInput: boolean,
92isEvalModel: boolean,
93): string[] => {
94const names = [];
95const wasm = getInstance();
96
97for (let i = 0; i < count; i++) {
98if (wasm._OrtTrainingGetModelInputOutputName) {
99const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
100ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);
101
102names.push(wasm.UTF8ToString(name));
103wasm._free(name);
104} else {
105throw new Error(NO_TRAIN_FUNCS_MSG);
106}
107}
108return names;
109};
110
111export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
112let inputNames: string[] = [];
113let outputNames: string[] = [];
114
115const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);
116
117inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
118outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);
119
120return [inputNames, outputNames];
121};
122
/**
 * Creates an OrtTrainingSession from a checkpoint state and the serialized train/eval/optimizer
 * model bytes already resident on the WASM heap.
 *
 * @param checkpointHandle handle of a previously created checkpoint state
 * @param trainModelData [offset on the WASM heap, byte length] of the training model
 * @param evalModelData [offset, length] of the eval model
 * @param optimizerModelData [offset, length] of the optimizer model
 * @param options session options used to build the native session-options object
 * @returns handle of the created training session
 */
export const createTrainingSessionHandle = (
  checkpointHandle: number,
  trainModelData: SerializableInternalBuffer,
  evalModelData: SerializableInternalBuffer,
  optimizerModelData: SerializableInternalBuffer,
  options: InferenceSession.SessionOptions,
): number => {
  const wasm = getInstance();

  let trainingSessionHandle = 0;
  let sessionOptionsHandle = 0;
  let allocs: number[] = [];

  try {
    // Builds the native session options; `allocs` collects heap allocations made for option strings.
    [sessionOptionsHandle, allocs] = setSessionOptions(options);
    if (wasm._OrtTrainingCreateSession) {
      trainingSessionHandle = wasm._OrtTrainingCreateSession(
        sessionOptionsHandle,
        checkpointHandle,
        trainModelData[0],
        trainModelData[1],
        evalModelData[0],
        evalModelData[1],
        optimizerModelData[0],
        optimizerModelData[1],
      );
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    // A zero handle signals failure (hence checkNeqZero = false).
    ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
    return trainingSessionHandle;
  } catch (e) {
    // Roll back a partially-created session before re-throwing.
    if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
      wasm._OrtTrainingReleaseSession(trainingSessionHandle);
    }
    throw e;
  } finally {
    // The serialized model bytes and the session options are not needed after creation, success or not.
    wasm._free(trainModelData[0]);
    wasm._free(evalModelData[0]);
    wasm._free(optimizerModelData[0]);

    if (sessionOptionsHandle !== 0) {
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
    }
    allocs.forEach((alloc) => wasm._free(alloc));
  }
};
171
172/**
173* Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the
174* WASM tensors.
175*
176* @param trainingSessionId
177* @param indices for each tensor, the index of the input or output name that the tensor corresponds with
178* @param tensors list of TensorMetaData
179* @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting
180* handles of the allocated tensors on the heap
181* @param inputOutputAllocs modified in-place by this method
182* @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor
183*/
184const createAndAllocateTensors = (
185trainingSessionId: number,
186indices: number[],
187tensors: Array<TensorMetadata | null>,
188tensorHandles: number[],
189inputOutputAllocs: number[],
190indexAdd: number,
191) => {
192const count = indices.length;
193
194// creates the tensors
195for (let i = 0; i < count; i++) {
196prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]);
197}
198
199// moves to heap
200const wasm = getInstance();
201const valuesOffset = wasm.stackAlloc(count * 4);
202let valuesIndex = valuesOffset / 4;
203for (let i = 0; i < count; i++) {
204wasm.HEAPU32[valuesIndex++] = tensorHandles[i];
205}
206
207return valuesOffset;
208};
209
/**
 * Retrieves the information from the output tensor handles, copies it into JS-side TensorMetadata,
 * and frees the WASM resources associated with each tensor handle.
 *
 * @param outputValuesOffset stack offset of the output tensor-handle array written by the run call
 * @param outputCount number of outputs
 * @param outputTensorHandles handles of the tensors that were passed in as outputs
 * @param outputTensors pre-allocated output tensors, or null entries for runtime-allocated outputs
 * @returns list of TensorMetadata retrieved from the output handles.
 */
const moveOutputToTensorMetadataArr = (
  outputValuesOffset: number,
  outputCount: number,
  outputTensorHandles: number[],
  outputTensors: Array<TensorMetadata | null>,
) => {
  const wasm = getInstance();
  const output: TensorMetadata[] = [];

  for (let i = 0; i < outputCount; i++) {
    const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
    if (tensor === outputTensorHandles[i]) {
      // output tensor is pre-allocated. no need to copy data.
      output.push(outputTensors[i]!);
      continue;
    }

    const beforeGetTensorDataStack = wasm.stackSave();
    // stack allocate 4 pointer value
    const tensorDataOffset = wasm.stackAlloc(4 * 4);

    let type: Tensor.Type | undefined,
      dataOffset = 0;
    try {
      // Out-parameters: data type enum, data pointer, dims pointer, dims length.
      const errorCode = wasm._OrtGetTensorData(
        tensor,
        tensorDataOffset,
        tensorDataOffset + 4,
        tensorDataOffset + 8,
        tensorDataOffset + 12,
      );
      ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`);

      let tensorDataIndex = tensorDataOffset / 4;
      const dataType = wasm.HEAPU32[tensorDataIndex++];
      dataOffset = wasm.HEAPU32[tensorDataIndex++];
      const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
      const dimsLength = wasm.HEAPU32[tensorDataIndex++];
      // Copy the dims out of the WASM heap, then free the dims buffer via OrtFree.
      const dims = [];
      for (let i = 0; i < dimsLength; i++) {
        // NOTE: this `i` intentionally shadows the outer loop index.
        dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
      }
      wasm._OrtFree(dimsOffset);

      // Total element count = product of dims (1 for a scalar).
      const size = dims.reduce((a, b) => a * b, 1);
      type = tensorDataTypeEnumToString(dataType);

      if (type === 'string') {
        // String tensors store an array of char* offsets; decode each one. The byte length of
        // each string is derived from the next offset; the last one reads up to its NUL terminator.
        const stringData: string[] = [];
        let dataIndex = dataOffset / 4;
        for (let i = 0; i < size; i++) {
          const offset = wasm.HEAPU32[dataIndex++];
          const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
          stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
        }
        output.push([type, dims, stringData, 'cpu']);
      } else {
        // Numeric tensors: bulk-copy the raw bytes out of the WASM heap into a fresh typed array.
        const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
        const data = new typedArrayConstructor(size);
        new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
          wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength),
        );
        output.push([type, dims, data, 'cpu']);
      }
    } finally {
      wasm.stackRestore(beforeGetTensorDataStack);
      // String-tensor data buffers are heap allocations that must be freed explicitly.
      if (type === 'string' && dataOffset) {
        wasm._free(dataOffset);
      }
      wasm._OrtReleaseTensor(tensor);
    }
  }

  return output;
};
293
294export const lazyResetGrad = async (trainingSessionId: number): Promise<void> => {
295const wasm = getInstance();
296
297if (wasm._OrtTrainingLazyResetGrad) {
298const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId);
299ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad.");
300} else {
301throw new Error(NO_TRAIN_FUNCS_MSG);
302}
303};
304
/**
 * Runs a single train step: executes the training model with the given inputs and returns the
 * requested outputs.
 *
 * @param trainingSessionId handle of the training session
 * @param inputIndices for each input tensor, the index of its input name
 * @param inputTensors metadata of the input tensors
 * @param outputIndices for each output tensor, the index of its output name
 * @param outputTensors pre-allocated output tensors, or null entries to let the runtime allocate them
 * @param options run options for this step
 * @returns metadata of the output tensors
 */
export const runTrainStep = async (
  trainingSessionId: number,
  inputIndices: number[],
  inputTensors: TensorMetadata[],
  outputIndices: number[],
  outputTensors: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  // Filled in-place by createAndAllocateTensors; released in the finally block.
  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      inputIndices,
      inputTensors,
      inputTensorHandles,
      inputOutputAllocs,
      0,
    );
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      outputIndices,
      outputTensors,
      outputTensorHandles,
      inputOutputAllocs,
      inputCount,
    );

    if (wasm._OrtTrainingRunTrainStep) {
      const errorCode = wasm._OrtTrainingRunTrainStep(
        trainingSessionId,
        inputValuesOffset,
        inputCount,
        outputValuesOffset,
        outputCount,
        runOptionsHandle,
      );
      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    // Copies the outputs to JS and releases the runtime-allocated output tensors.
    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    // Release every tensor and heap allocation created for this run, then the run options.
    inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach((p) => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};
379
380export const runOptimizerStep = async (
381trainingSessionId: number,
382options: InferenceSession.RunOptions,
383): Promise<void> => {
384const wasm = getInstance();
385
386let runOptionsHandle = 0;
387let runOptionsAllocs: number[] = [];
388
389try {
390[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
391
392if (wasm._OrtTrainingOptimizerStep) {
393const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
394ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
395} else {
396throw new Error(NO_TRAIN_FUNCS_MSG);
397}
398} finally {
399if (runOptionsHandle !== 0) {
400wasm._OrtReleaseRunOptions(runOptionsHandle);
401}
402runOptionsAllocs.forEach((p) => wasm._free(p));
403}
404};
405
/**
 * Runs a single eval step: executes the eval model with the given inputs and returns the
 * requested outputs.
 *
 * @param trainingSessionId handle of the training session
 * @param inputIndices for each input tensor, the index of its input name
 * @param inputTensors metadata of the input tensors
 * @param outputIndices for each output tensor, the index of its output name
 * @param outputTensors pre-allocated output tensors, or null entries to let the runtime allocate them
 * @param options run options for this step
 * @returns metadata of the output tensors
 */
export const runEvalStep = async (
  trainingSessionId: number,
  inputIndices: number[],
  inputTensors: TensorMetadata[],
  outputIndices: number[],
  outputTensors: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  // Filled in-place by createAndAllocateTensors; released in the finally block.
  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      inputIndices,
      inputTensors,
      inputTensorHandles,
      inputOutputAllocs,
      0,
    );
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      outputIndices,
      outputTensors,
      outputTensorHandles,
      inputOutputAllocs,
      inputCount,
    );

    if (wasm._OrtTrainingEvalStep) {
      const errorCode = wasm._OrtTrainingEvalStep(
        trainingSessionId,
        inputValuesOffset,
        inputCount,
        outputValuesOffset,
        outputCount,
        runOptionsHandle,
      );

      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    // Copies the outputs to JS and releases the runtime-allocated output tensors.
    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    // Release every tensor and heap allocation created for this run, then the run options.
    inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach((p) => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};
481
482export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
483const wasm = getInstance();
484const stack = wasm.stackSave();
485
486try {
487const sizeOffset = wasm.stackAlloc(4);
488if (wasm._OrtTrainingGetParametersSize) {
489const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
490ifErrCodeCheckLastError(errorCode, "Can't get parameters size");
491
492return wasm.HEAP32[sizeOffset / 4];
493} else {
494throw new Error(NO_TRAIN_FUNCS_MSG);
495}
496} finally {
497wasm.stackRestore(stack);
498}
499};
500
501export const getContiguousParameters = async (
502trainingSessionId: number,
503trainableOnly: boolean,
504): Promise<TensorMetadata> => {
505const wasm = getInstance();
506const stack = wasm.stackSave();
507
508const tensorTypeAsString = 'float32';
509const locationAsString = 'cpu';
510
511const parametersSize = getParametersSize(trainingSessionId, trainableOnly);
512let tensor = 0;
513
514// allocates a buffer of the correct size on the WASM heap
515const paramsByteLength = 4 * parametersSize;
516const paramsOffset = wasm._malloc(paramsByteLength);
517
518// handles the dimensions-related createTensor parameters
519const dims = [parametersSize];
520
521const dimsOffset = wasm.stackAlloc(4);
522const dimsIndex = dimsOffset / 4;
523wasm.HEAP32[dimsIndex] = parametersSize;
524
525try {
526// wraps allocated array in a tensor
527tensor = wasm._OrtCreateTensor(
528tensorDataTypeStringToEnum(tensorTypeAsString),
529paramsOffset,
530paramsByteLength,
531dimsOffset,
532dims.length,
533dataLocationStringToEnum(locationAsString),
534);
535ifErrCodeCheckLastError(
536tensor,
537`Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`,
538false,
539);
540
541if (wasm._OrtTrainingCopyParametersToBuffer) {
542const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
543ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters.");
544} else {
545throw new Error(NO_TRAIN_FUNCS_MSG);
546}
547
548// copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object
549const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString);
550const data = new typedArrayConstructor(parametersSize);
551const output: TensorMetadata[] = [];
552new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
553wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength),
554);
555output.push([tensorTypeAsString, dims, data, locationAsString]);
556if (output.length !== 1) {
557throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of
558one, got ${output.length}`);
559} else {
560return output[0];
561}
562} finally {
563if (tensor !== 0) {
564wasm._OrtReleaseTensor(tensor);
565}
566wasm._free(paramsOffset);
567wasm._free(dimsOffset);
568wasm.stackRestore(stack);
569}
570};
571
572export const loadParametersBuffer = async (
573trainingSessionId: number,
574buffer: Uint8Array,
575trainableOnly: boolean,
576): Promise<void> => {
577const wasm = getInstance();
578const stack = wasm.stackSave();
579
580const tensorTypeAsString = 'float32';
581const locationAsString = 'cpu';
582
583// allocates & copies JavaScript buffer to WASM heap
584const bufferByteLength = buffer.length;
585const bufferCount = bufferByteLength / 4;
586const bufferOffset = wasm._malloc(bufferByteLength);
587wasm.HEAPU8.set(buffer, bufferOffset);
588
589// allocates and handles moving dimensions information to WASM memory
590const dimsOffset = wasm.stackAlloc(4);
591wasm.HEAP32[dimsOffset / 4] = bufferCount;
592const dimsLength = 1;
593let tensor = 0;
594
595try {
596tensor = wasm._OrtCreateTensor(
597tensorDataTypeStringToEnum(tensorTypeAsString),
598bufferOffset,
599bufferByteLength,
600dimsOffset,
601dimsLength,
602dataLocationStringToEnum(locationAsString),
603);
604ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false);
605
606if (wasm._OrtTrainingCopyParametersFromBuffer) {
607const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
608ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters.");
609} else {
610throw new Error(NO_TRAIN_FUNCS_MSG);
611}
612} finally {
613if (tensor !== 0) {
614wasm._OrtReleaseTensor(tensor);
615}
616wasm.stackRestore(stack);
617wasm._free(bufferOffset);
618wasm._free(dimsOffset);
619}
620};
621
622export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
623const wasm = getInstance();
624
625if (wasm._OrtTrainingReleaseSession) {
626wasm._OrtTrainingReleaseSession(sessionId);
627}
628if (wasm._OrtTrainingReleaseCheckpoint) {
629wasm._OrtTrainingReleaseCheckpoint(checkpointId);
630}
631};
632