onnxruntime
198 строк · 7.5 Кб
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

4import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common';5
6import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';7import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference';8import { copyFromExternalBuffer } from './wasm-core-impl';9import {10createCheckpointHandle,11createTrainingSessionHandle,12getContiguousParameters,13getModelInputOutputNames,14getParametersSize,15lazyResetGrad,16loadParametersBuffer,17releaseTrainingSessionAndCheckpoint,18runEvalStep,19runOptimizerStep,20runTrainStep,21} from './wasm-training-core-impl';22
/**
 * WebAssembly implementation of `TrainingSessionHandler`.
 *
 * Holds the numeric handles of one training session and its checkpoint, and forwards every operation to the
 * functions in `./wasm-training-core-impl`. It translates between the name-keyed feeds/fetches used by the public
 * API and the index-based arrays those functions take.
 */
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
  // Handles returned by createTrainingSessionHandle / createCheckpointHandle; assigned in createTrainingSession().
  private sessionId!: number;
  private checkpointId!: number;

  // Input/output names of the training model, as reported by getModelInputOutputNames(sessionId, false).
  inputNames: string[] = [];
  outputNames: string[] = [];

  // Input/output names of the eval model; stay empty when no eval model was supplied.
  evalInputNames: string[] = [];
  evalOutputNames: string[] = [];

  /**
   * Copies a model or checkpoint into the WebAssembly heap.
   *
   * @param uriOrBuffer a URL to fetch, or the raw bytes.
   * @returns the heap buffer descriptor produced by copyFromExternalBuffer.
   * @throws Error if fetching the URL yields a non-OK HTTP response.
   */
  async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise<SerializableInternalBuffer> {
    let buffer: Uint8Array;
    if (typeof uriOrBuffer === 'string') {
      const response = await fetch(uriOrBuffer);
      // Without this check an HTTP error page would be copied to the heap as if it were model/checkpoint bytes.
      if (!response.ok) {
        throw new Error(`failed to fetch '${uriOrBuffer}': ${response.status} ${response.statusText}`);
      }
      buffer = new Uint8Array(await response.arrayBuffer());
    } else {
      buffer = uriOrBuffer;
    }
    return copyFromExternalBuffer(buffer);
  }

  /**
   * Loads the checkpoint and models, creates the checkpoint and training-session handles, and caches the models'
   * input/output names.
   *
   * @param checkpointStateUriOrBuffer checkpoint state, as a URL or bytes.
   * @param trainModelUriOrBuffer training model, as a URL or bytes.
   * @param evalModelUriOrBuffer eval model, as a URL or bytes; the empty string means "no eval model".
   * @param optimizerModelUriOrBuffer optimizer model, as a URL or bytes; the empty string means "no optimizer model".
   * @param options session options forwarded to createTrainingSessionHandle.
   */
  async createTrainingSession(
    checkpointStateUriOrBuffer: string | Uint8Array,
    trainModelUriOrBuffer: string | Uint8Array,
    evalModelUriOrBuffer: string | Uint8Array,
    optimizerModelUriOrBuffer: string | Uint8Array,
    options: InferenceSession.SessionOptions,
  ): Promise<void> {
    const hasEvalModel = evalModelUriOrBuffer !== '';
    const hasOptimizerModel = optimizerModelUriOrBuffer !== '';

    const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
    const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
    // [0, 0] marks an absent model: 0 is used as the null pointer.
    const evalModelData: SerializableInternalBuffer = hasEvalModel
      ? await this.uriOrBufferToHeap(evalModelUriOrBuffer)
      : [0, 0];
    const optimizerModelData: SerializableInternalBuffer = hasOptimizerModel
      ? await this.uriOrBufferToHeap(optimizerModelUriOrBuffer)
      : [0, 0];

    this.checkpointId = createCheckpointHandle(checkpointData);
    // NOTE(review): if createTrainingSessionHandle throws, the checkpoint handle created above is not released here.
    this.sessionId = createTrainingSessionHandle(
      this.checkpointId,
      trainModelData,
      evalModelData,
      optimizerModelData,
      options,
    );
    [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
    if (hasEvalModel) {
      [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
    }
  }

  /**
   * Helper method that converts a feeds or fetches map into parallel arrays: the values, and for each value the
   * index of its name within `names`.
   *
   * Every key is validated before `mapFunc` is invoked for any value.
   *
   * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType
   * @param names either inputNames or outputNames
   * @param mapFunc applied to every value. It receives the value, its position in the returned arrays, and the
   *     name it was keyed by.
   * @returns a tuple of the values, their name indices, and the results of mapFunc.
   * @throws Error if a key of `feeds` is not present in `names`.
   */
  convertMapIntoValuesArrayAndIndicesArray<T, U>(
    feeds: { [name: string]: T },
    names: readonly string[],
    mapFunc: (val: T, index: number, name: string) => U,
  ): [T[], number[], U[]] {
    const entries = Object.entries(feeds);
    const indices = entries.map(([name]) => {
      const index = names.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid input '${name}'`);
      }
      return index;
    });
    const values = entries.map(([, value]) => value);
    const mapped = entries.map(([name, value], position) => mapFunc(value, position, name));
    return [values, indices, mapped];
  }

  /**
   * Helper method that converts the TensorMetadata returned by the wasm-core functions into a
   * SessionHandler.ReturnType keyed by output name.
   *
   * @param results metadata of each requested output, in the same order as outputIndices.
   * @param outputArray caller-supplied output tensors. A null/undefined entry is replaced by the decoded result.
   * @param outputIndices index into outputNames of each entry of results/outputArray.
   * @param outputNames the names that outputIndices refer to. Defaults to the training model's outputs; pass
   *     evalOutputNames for eval-step results.
   * @returns a map of output names and OnnxValues.
   */
  convertTensorMetadataToReturnType(
    results: TensorMetadata[],
    outputArray: Array<Tensor | null>,
    outputIndices: number[],
    outputNames: readonly string[] = this.outputNames,
  ): SessionHandler.ReturnType {
    const resultMap: SessionHandler.ReturnType = {};
    results.forEach((metadata, i) => {
      resultMap[outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(metadata);
    });
    return resultMap;
  }

  /**
   * Encodes feeds and fetches into the index/metadata arrays that the core run functions take.
   *
   * Encoding errors name the offending feed/fetch.
   */
  private encodeRunArguments(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    inputNames: readonly string[],
    outputNames: readonly string[],
  ): {
    inputIndices: number[];
    inputs: TensorMetadata[];
    outputArray: Array<Tensor | null>;
    outputIndices: number[];
    outputs: Array<TensorMetadata | null>;
  } {
    // The name is passed into the callback: the index arrays being built are not yet initialised at this point.
    const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
      feeds,
      inputNames,
      (tensor, _position, name): TensorMetadata => encodeTensorMetadata(tensor, () => `input "${name}"`),
    );

    const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray<
      Tensor | null,
      TensorMetadata | null
    >(fetches, outputNames, (tensor, _position, name): TensorMetadata | null =>
      tensor ? encodeTensorMetadata(tensor, () => `output "${name}"`) : null,
    );

    return { inputIndices, inputs, outputArray, outputIndices, outputs };
  }

  /** Forwards to the core `lazyResetGrad` for this session. */
  async lazyResetGrad(): Promise<void> {
    await lazyResetGrad(this.sessionId);
  }

  /**
   * Runs one training step.
   *
   * @param feeds input tensors keyed by training-model input name.
   * @param fetches requested outputs keyed by training-model output name. A null value lets the result be decoded
   *     from the core output.
   * @param options run options forwarded to the core call.
   * @returns the outputs keyed by training-model output name.
   */
  async runTrainStep(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType> {
    const { inputIndices, inputs, outputArray, outputIndices, outputs } = this.encodeRunArguments(
      feeds,
      fetches,
      this.inputNames,
      this.outputNames,
    );
    const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
    return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices, this.outputNames);
  }

  /** Forwards to the core `runOptimizerStep` for this session. */
  async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
    await runOptimizerStep(this.sessionId, options);
  }

  /**
   * Runs one eval step.
   *
   * @param feeds input tensors keyed by eval-model input name.
   * @param fetches requested outputs keyed by eval-model output name. A null value lets the result be decoded from
   *     the core output.
   * @param options run options forwarded to the core call.
   * @returns the outputs keyed by eval-model output name.
   */
  async runEvalStep(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType> {
    const { inputIndices, inputs, outputArray, outputIndices, outputs } = this.encodeRunArguments(
      feeds,
      fetches,
      this.evalInputNames,
      this.evalOutputNames,
    );
    const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
    // Indices refer to the eval model's outputs, so key the results by evalOutputNames.
    return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices, this.evalOutputNames);
  }

  /** Returns the core-reported parameter size for this session (trainable parameters only if requested). */
  async getParametersSize(trainableOnly: boolean): Promise<number> {
    return getParametersSize(this.sessionId, trainableOnly);
  }

  /** Forwards `array` to the core `loadParametersBuffer` for this session. */
  async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void> {
    await loadParametersBuffer(this.sessionId, array, trainableOnly);
  }

  /** Returns the session's parameters, decoded from the TensorMetadata returned by the core call. */
  async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
    const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly);
    return decodeTensorMetadata(tensorResult);
  }

  /** Releases the training session and its checkpoint via the core implementation. */
  async dispose(): Promise<void> {
    return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
  }
}
199