onnxruntime

Форк
0
/
session-handler-inference.ts 
139 строк · 4.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import {
5
  InferenceSession,
6
  InferenceSessionHandler,
7
  SessionHandler,
8
  Tensor,
9
  TRACE_FUNC_BEGIN,
10
  TRACE_FUNC_END,
11
} from 'onnxruntime-common';
12

13
import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
14
import { copyFromExternalBuffer, createSession, endProfiling, releaseSession, run } from './proxy-wrapper';
15
import { isGpuBufferSupportedType } from './wasm-common';
16
import { isNode } from './wasm-utils-env';
17
import { loadFile } from './wasm-utils-load-file';
18

19
export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
20
  switch (tensor.location) {
21
    case 'cpu':
22
      return [tensor.type, tensor.dims, tensor.data, 'cpu'];
23
    case 'gpu-buffer':
24
      return [tensor.type, tensor.dims, { gpuBuffer: tensor.gpuBuffer }, 'gpu-buffer'];
25
    default:
26
      throw new Error(`invalid data location: ${tensor.location} for ${getName()}`);
27
  }
28
};
29

30
export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
31
  switch (tensor[3]) {
32
    case 'cpu':
33
      return new Tensor(tensor[0], tensor[2], tensor[1]);
34
    case 'gpu-buffer': {
35
      const dataType = tensor[0];
36
      if (!isGpuBufferSupportedType(dataType)) {
37
        throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`);
38
      }
39
      const { gpuBuffer, download, dispose } = tensor[2];
40
      return Tensor.fromGpuBuffer(gpuBuffer, { dataType, dims: tensor[1], download, dispose });
41
    }
42
    default:
43
      throw new Error(`invalid data location: ${tensor[3]}`);
44
  }
45
};
46

47
export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHandler {
48
  private sessionId: number;
49

50
  inputNames: string[];
51
  outputNames: string[];
52

53
  async fetchModelAndCopyToWasmMemory(path: string): Promise<SerializableInternalBuffer> {
54
    // fetch model from url and move to wasm heap.
55
    return copyFromExternalBuffer(await loadFile(path));
56
  }
57

58
  async loadModel(pathOrBuffer: string | Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
59
    TRACE_FUNC_BEGIN();
60
    let model: Parameters<typeof createSession>[0];
61

62
    if (typeof pathOrBuffer === 'string') {
63
      if (isNode) {
64
        // node
65
        model = await loadFile(pathOrBuffer);
66
      } else {
67
        // browser
68
        // fetch model and copy to wasm heap.
69
        model = await this.fetchModelAndCopyToWasmMemory(pathOrBuffer);
70
      }
71
    } else {
72
      model = pathOrBuffer;
73
    }
74

75
    [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options);
76
    TRACE_FUNC_END();
77
  }
78

79
  async dispose(): Promise<void> {
80
    return releaseSession(this.sessionId);
81
  }
82

83
  async run(
84
    feeds: SessionHandler.FeedsType,
85
    fetches: SessionHandler.FetchesType,
86
    options: InferenceSession.RunOptions,
87
  ): Promise<SessionHandler.ReturnType> {
88
    TRACE_FUNC_BEGIN();
89
    const inputArray: Tensor[] = [];
90
    const inputIndices: number[] = [];
91
    Object.entries(feeds).forEach((kvp) => {
92
      const name = kvp[0];
93
      const tensor = kvp[1];
94
      const index = this.inputNames.indexOf(name);
95
      if (index === -1) {
96
        throw new Error(`invalid input '${name}'`);
97
      }
98
      inputArray.push(tensor);
99
      inputIndices.push(index);
100
    });
101

102
    const outputArray: Array<Tensor | null> = [];
103
    const outputIndices: number[] = [];
104
    Object.entries(fetches).forEach((kvp) => {
105
      const name = kvp[0];
106
      const tensor = kvp[1];
107
      const index = this.outputNames.indexOf(name);
108
      if (index === -1) {
109
        throw new Error(`invalid output '${name}'`);
110
      }
111
      outputArray.push(tensor);
112
      outputIndices.push(index);
113
    });
114

115
    const inputs = inputArray.map((t, i) =>
116
      encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`),
117
    );
118
    const outputs = outputArray.map((t, i) =>
119
      t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null,
120
    );
121

122
    const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
123

124
    const resultMap: SessionHandler.ReturnType = {};
125
    for (let i = 0; i < results.length; i++) {
126
      resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
127
    }
128
    TRACE_FUNC_END();
129
    return resultMap;
130
  }
131

132
  startProfiling(): void {
133
    // TODO: implement profiling
134
  }
135

136
  endProfiling(): void {
137
    void endProfiling(this.sessionId);
138
  }
139
}
140

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.