onnxruntime

Форк
0
/
session-handler-training.ts 
198 строк · 7.5 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common';
5

6
import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
7
import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference';
8
import { copyFromExternalBuffer } from './wasm-core-impl';
9
import {
10
  createCheckpointHandle,
11
  createTrainingSessionHandle,
12
  getContiguousParameters,
13
  getModelInputOutputNames,
14
  getParametersSize,
15
  lazyResetGrad,
16
  loadParametersBuffer,
17
  releaseTrainingSessionAndCheckpoint,
18
  runEvalStep,
19
  runOptimizerStep,
20
  runTrainStep,
21
} from './wasm-training-core-impl';
22

23
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
  private sessionId: number;
  private checkpointId: number;

  inputNames: string[];
  outputNames: string[];

  evalInputNames: string[] = [];
  evalOutputNames: string[] = [];

  /**
   * Obtains the bytes of a checkpoint or model and copies them into the WebAssembly heap.
   *
   * @param uriOrBuffer a URI to fetch, or the raw bytes themselves.
   * @returns the heap descriptor produced by `copyFromExternalBuffer`.
   * @throws Error if fetching the URI yields a non-2xx HTTP status.
   */
  async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise<SerializableInternalBuffer> {
    let buffer: Uint8Array;
    if (typeof uriOrBuffer === 'string') {
      const response = await fetch(uriOrBuffer);
      // fetch() only rejects on network failures. Without this check an HTTP error page (404, 500, ...)
      // would be copied into the heap and handed to the runtime as if it were a model.
      if (!response.ok) {
        throw new Error(`failed to fetch '${uriOrBuffer}': HTTP ${response.status} ${response.statusText}`);
      }
      buffer = new Uint8Array(await response.arrayBuffer());
    } else {
      buffer = uriOrBuffer;
    }
    return copyFromExternalBuffer(buffer);
  }

  /**
   * Loads the checkpoint and models into the heap, creates the native checkpoint and training session,
   * and records the model input/output names.
   *
   * @param checkpointStateUriOrBuffer checkpoint state, as a URI or raw bytes.
   * @param trainModelUriOrBuffer training model, as a URI or raw bytes.
   * @param evalModelUriOrBuffer eval model, as a URI or raw bytes; `''` means "no eval model".
   * @param optimizerModelUriOrBuffer optimizer model, as a URI or raw bytes; `''` means "no optimizer model".
   * @param options session options forwarded to `createTrainingSessionHandle`.
   */
  async createTrainingSession(
    checkpointStateUriOrBuffer: string | Uint8Array,
    trainModelUriOrBuffer: string | Uint8Array,
    evalModelUriOrBuffer: string | Uint8Array,
    optimizerModelUriOrBuffer: string | Uint8Array,
    options: InferenceSession.SessionOptions,
  ) {
    const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
    const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
    // An all-zero descriptor stands in for a null pointer, i.e. "this optional model was not provided".
    let evalModelData: SerializableInternalBuffer = [0, 0];
    let optimizerModelData: SerializableInternalBuffer = [0, 0];

    if (evalModelUriOrBuffer !== '') {
      evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
    }
    if (optimizerModelUriOrBuffer !== '') {
      optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
    }

    this.checkpointId = createCheckpointHandle(checkpointData);
    this.sessionId = createTrainingSessionHandle(
      this.checkpointId,
      trainModelData,
      evalModelData,
      optimizerModelData,
      options,
    );
    [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
    // Eval names are only queried when an eval model was supplied; otherwise they stay empty.
    if (evalModelUriOrBuffer !== '') {
      [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
    }
  }

  /**
   * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the
   * corresponding name as a number referring to the index in the list of names provided.
   *
   * Every key is validated before `mapFunc` runs, so an unknown name is always reported first.
   *
   * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType
   * @param names either inputNames or outputNames
   * @param mapFunc called once per entry with the value, its position in the returned arrays, and the matched name.
   * @returns a tuple of a list of values, a list of indices into `names`, and the list of mapped values.
   * @throws Error if a key of `feeds` is not contained in `names`.
   */
  convertMapIntoValuesArrayAndIndicesArray<T, U>(
    feeds: { [name: string]: T },
    names: string[],
    mapFunc: (val: T, index: number, name: string) => U,
  ): [T[], number[], U[]] {
    const values: T[] = [];
    const indices: number[] = [];
    const matchedNames: string[] = [];
    for (const [name, value] of Object.entries(feeds)) {
      const index = names.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid input '${name}'`);
      }
      values.push(value);
      indices.push(index);
      matchedNames.push(name);
    }

    // Passing the matched name lets callers build error messages without reading the indices array,
    // which is not yet assigned while this call is still running.
    const uList = values.map((value, i) => mapFunc(value, i, matchedNames[i]));
    return [values, indices, uList];
  }

  /**
   * Helper method that converts the TensorMetadata that the wasm-core functions return to the
   * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the
   * corresponding result.
   *
   * @param results used to populate the resultMap if there is no value for that outputName already
   * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results
   * @param outputIndices specifies which outputName the corresponding value for outputArray refers to.
   * @param outputNames the name list `outputIndices` indexes into; defaults to the training model's `outputNames`.
   * @returns a map of output names and OnnxValues.
   */
  convertTensorMetadataToReturnType(
    results: TensorMetadata[],
    outputArray: Array<Tensor | null>,
    outputIndices: number[],
    outputNames: string[] = this.outputNames,
  ): SessionHandler.ReturnType {
    const resultMap: SessionHandler.ReturnType = {};
    for (let i = 0; i < results.length; i++) {
      resultMap[outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
    }
    return resultMap;
  }

  /** Delegates to `lazyResetGrad` for this session. */
  async lazyResetGrad(): Promise<void> {
    await lazyResetGrad(this.sessionId);
  }

  /**
   * Runs one training step with the training model.
   *
   * @param feeds input tensors keyed by training-model input name.
   * @param fetches requested outputs keyed by training-model output name; a null value means "allocate for me".
   * @param options run options forwarded to the wasm layer.
   * @returns the outputs keyed by training-model output name.
   */
  async runTrainStep(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType> {
    const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
      feeds,
      this.inputNames,
      (t, _i, name): TensorMetadata => encodeTensorMetadata(t, () => `input "${name}"`),
    );

    const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray<
      Tensor | null,
      TensorMetadata | null
    >(fetches, this.outputNames, (t, _i, name): TensorMetadata | null =>
      t ? encodeTensorMetadata(t, () => `output "${name}"`) : null,
    );

    const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
    return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices, this.outputNames);
  }

  /** Delegates to `runOptimizerStep` for this session. */
  async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
    await runOptimizerStep(this.sessionId, options);
  }

  /**
   * Runs one evaluation step with the eval model.
   *
   * @param feeds input tensors keyed by eval-model input name.
   * @param fetches requested outputs keyed by eval-model output name; a null value means "allocate for me".
   * @param options run options forwarded to the wasm layer.
   * @returns the outputs keyed by eval-model output name.
   */
  async runEvalStep(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType> {
    const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
      feeds,
      this.evalInputNames,
      (t, _i, name): TensorMetadata => encodeTensorMetadata(t, () => `input "${name}"`),
    );

    const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray<
      Tensor | null,
      TensorMetadata | null
    >(fetches, this.evalOutputNames, (t, _i, name): TensorMetadata | null =>
      t ? encodeTensorMetadata(t, () => `output "${name}"`) : null,
    );

    const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
    // outputIndices index into evalOutputNames, so the result map must be keyed by those names.
    return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices, this.evalOutputNames);
  }

  /** Returns the value reported by `getParametersSize` for this session. */
  async getParametersSize(trainableOnly: boolean): Promise<number> {
    return getParametersSize(this.sessionId, trainableOnly);
  }

  /** Forwards `array` to `loadParametersBuffer` for this session. */
  async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void> {
    await loadParametersBuffer(this.sessionId, array, trainableOnly);
  }

  /** Fetches the session's parameters via `getContiguousParameters` and decodes them into an OnnxValue. */
  async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
    const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly);
    return decodeTensorMetadata(tensorResult);
  }

  /** Releases the native training session and checkpoint created by `createTrainingSession`. */
  async dispose(): Promise<void> {
    return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
  }
}
199

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.