// onnxruntime
// 139 lines · 4.4 KB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {
  InferenceSession,
  InferenceSessionHandler,
  SessionHandler,
  Tensor,
  TRACE_FUNC_BEGIN,
  TRACE_FUNC_END,
} from 'onnxruntime-common';

import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
import { copyFromExternalBuffer, createSession, endProfiling, releaseSession, run } from './proxy-wrapper';
import { isGpuBufferSupportedType } from './wasm-common';
import { isNode } from './wasm-utils-env';
import { loadFile } from './wasm-utils-load-file';

/**
 * Encodes a tensor as a `[type, dims, payload, location]` tuple.
 *
 * For a `'cpu'` tensor the payload is `tensor.data`; for a `'gpu-buffer'`
 * tensor the payload is an object wrapping `tensor.gpuBuffer`.
 *
 * @param tensor - The tensor to encode.
 * @param getName - Produces a human-readable label for the tensor. It is only
 *   invoked when the error message has to be built.
 * @returns The encoded tuple.
 * @throws Error when `tensor.location` is neither `'cpu'` nor `'gpu-buffer'`.
 */
export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
  switch (tensor.location) {
    case 'cpu':
      return [tensor.type, tensor.dims, tensor.data, 'cpu'];
    case 'gpu-buffer':
      return [tensor.type, tensor.dims, { gpuBuffer: tensor.gpuBuffer }, 'gpu-buffer'];
    default:
      throw new Error(`invalid data location: ${tensor.location} for ${getName()}`);
  }
};

/**
 * Rebuilds a `Tensor` from a `[type, dims, payload, location]` tuple.
 *
 * - `'cpu'`: constructs a tensor from the type, payload and dims.
 * - `'gpu-buffer'`: checks the data type with `isGpuBufferSupportedType` and
 *   then builds the tensor via `Tensor.fromGpuBuffer`, forwarding the
 *   `download` and `dispose` members found on the payload.
 *
 * @param tensor - The encoded tuple.
 * @returns The reconstructed tensor.
 * @throws Error when the data type is rejected for a GPU buffer, or when the
 *   location (element 3) is neither `'cpu'` nor `'gpu-buffer'`.
 */
export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
  switch (tensor[3]) {
    case 'cpu':
      return new Tensor(tensor[0], tensor[2], tensor[1]);
    case 'gpu-buffer': {
      const dataType = tensor[0];
      if (!isGpuBufferSupportedType(dataType)) {
        throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`);
      }
      const { gpuBuffer, download, dispose } = tensor[2];
      return Tensor.fromGpuBuffer(gpuBuffer, { dataType, dims: tensor[1], download, dispose });
    }
    default:
      throw new Error(`invalid data location: ${tensor[3]}`);
  }
};

/**
 * `InferenceSessionHandler` implementation that delegates session creation,
 * execution, profiling and release to the functions imported from
 * `./proxy-wrapper`.
 */
export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHandler {
  /** Identifier returned by `createSession`; assigned in `loadModel`. */
  private sessionId: number;

  /** Input names reported by `createSession`; assigned in `loadModel`. */
  inputNames: string[];
  /** Output names reported by `createSession`; assigned in `loadModel`. */
  outputNames: string[];

  /**
   * Loads the file at `path` and hands its content to `copyFromExternalBuffer`.
   *
   * @param path - Location of the model, passed to `loadFile`.
   * @returns The buffer descriptor produced by `copyFromExternalBuffer`.
   */
  async fetchModelAndCopyToWasmMemory(path: string): Promise<SerializableInternalBuffer> {
    // fetch model from url and move to wasm heap.
    return copyFromExternalBuffer(await loadFile(path));
  }

  /**
   * Creates the underlying session and records its id and input/output names.
   *
   * @param pathOrBuffer - Either a path to the model or the model bytes.
   * @param options - Session options forwarded unchanged to `createSession`.
   */
  async loadModel(pathOrBuffer: string | Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
    TRACE_FUNC_BEGIN();
    let model: Parameters<typeof createSession>[0];

    if (typeof pathOrBuffer === 'string') {
      if (isNode) {
        // node
        model = await loadFile(pathOrBuffer);
      } else {
        // browser
        // fetch model and copy to wasm heap.
        model = await this.fetchModelAndCopyToWasmMemory(pathOrBuffer);
      }
    } else {
      model = pathOrBuffer;
    }

    [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options);
    TRACE_FUNC_END();
  }

  /** Releases the session identified by `sessionId`. */
  async dispose(): Promise<void> {
    return releaseSession(this.sessionId);
  }

  /**
   * Runs the session.
   *
   * Feed and fetch names are translated to their positions in `inputNames` /
   * `outputNames`, the tensors are encoded with `encodeTensorMetadata`, and the
   * results are keyed by output name.
   *
   * @param feeds - Input tensors keyed by input name.
   * @param fetches - Requested outputs keyed by output name; a value may be a
   *   caller-supplied tensor or `null`.
   * @param options - Run options forwarded unchanged to `run`.
   * @returns A map from output name to tensor. When the caller supplied a
   *   tensor for an output, that same tensor is returned instead of a decoded
   *   result.
   * @throws Error when a feed or fetch name is not a known input/output name.
   */
  async run(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType> {
    TRACE_FUNC_BEGIN();
    const inputArray: Tensor[] = [];
    const inputIndices: number[] = [];
    for (const [name, tensor] of Object.entries(feeds)) {
      const index = this.inputNames.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid input '${name}'`);
      }
      inputArray.push(tensor);
      inputIndices.push(index);
    }

    const outputArray: Array<Tensor | null> = [];
    const outputIndices: number[] = [];
    for (const [name, tensor] of Object.entries(fetches)) {
      const index = this.outputNames.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid output '${name}'`);
      }
      outputArray.push(tensor);
      outputIndices.push(index);
    }

    // The name callbacks are lazy: they are only evaluated if encoding fails.
    const inputs = inputArray.map((t, i) =>
      encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`),
    );
    const outputs = outputArray.map((t, i) =>
      t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null,
    );

    const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);

    const resultMap: SessionHandler.ReturnType = {};
    for (let i = 0; i < results.length; i++) {
      // Prefer the tensor the caller passed in; decode only when none was given.
      resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
    }
    TRACE_FUNC_END();
    return resultMap;
  }

  /** Not implemented yet. */
  startProfiling(): void {
    // TODO: implement profiling
  }

  /** Requests the end of profiling for this session without awaiting the result. */
  endProfiling(): void {
    void endProfiling(this.sessionId);
  }
}
140