// Scraped page header: onnxruntime (fork) — inference-handler.ts, 357 lines, 13.8 KB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { InferenceHandler } from '../../backend';
import { Logger } from '../../instrument';
import { Tensor } from '../../tensor';
import { ShapeUtil } from '../../util';

import { createPackProgramInfoLoader } from './ops/pack';
import { createPackedReshape3DProgramInfoLoader, isReshapeCheap, processDims3D } from './ops/reshape-packed';
import { encodeAsUint8 } from './ops/uint8-encode';
import { createUnpackProgramInfoLoader } from './ops/unpack';
import { WebGLSessionHandler } from './session-handler';
import { EncoderUsage } from './texture-data-encoder';
import {
  calculateTextureWidthAndHeight,
  createTextureLayoutFromShape,
  createTextureLayoutFromTextureType,
} from './texture-layout';
import { Artifact, ProgramInfo, ProgramInfoLoader, TextureData, TextureLayout, TextureType } from './types';

/**
 * Builds the key under which a compiled program artifact is cached by the program manager.
 *
 * The key combines the program name, its optional cache hint, and, for every input texture,
 * its unpacked shape and texture size. Example: `Pack[hint]:2,3;4x5_6;1x6`.
 *
 * @param programInfo the program (or lazy program loader) being executed
 * @param inputTextureDatas texture data bound to the program inputs, in input order
 * @returns a string identifying the program / input-layout combination
 */
const getProgramInfoUniqueKey = (
  programInfo: ProgramInfo | ProgramInfoLoader,
  inputTextureDatas: TextureData[],
): string => {
  const inputs = inputTextureDatas
    .map((texture) => `${texture.unpackedShape.join(',')};${texture.width}x${texture.height}`)
    .join('_');
  let key = programInfo.name;
  // A missing or empty cache hint adds nothing to the key.
  if (programInfo.cacheHint) {
    key += '[' + programInfo.cacheHint + ']';
  }
  key += ':' + inputs;
  return key;
};

/**
 * WebGL implementation of `InferenceHandler`.
 *
 * Turns tensors into GPU texture data (packed or unpacked), builds/caches shader programs through the
 * session's program manager, runs them, and reads results back.
 * Texture data of non-initializer tensors is cached on this handler in two maps keyed by tensor id
 * (one for packed, one for unpacked layouts); texture data of initializer tensors is stored on the session.
 */
export class WebGLInferenceHandler implements InferenceHandler {
  /** Packed texture data of non-initializer tensors, keyed by tensor id. */
  private packedTextureDataCache: Map<Tensor.Id, TextureData>;
  /** Unpacked texture data of non-initializer tensors, keyed by tensor id. */
  private unpackedTextureDataCache: Map<Tensor.Id, TextureData>;

  constructor(public session: WebGLSessionHandler) {
    this.packedTextureDataCache = new Map();
    this.unpackedTextureDataCache = new Map();
  }

  /**
   * Computes the texture size for a shape using the session's layout strategy.
   * @param shape the tensor dimensions
   * @param textureType the texture type the size is computed for
   * @returns [width, height]
   */
  calculateTextureWidthAndHeight(shape: readonly number[], textureType: TextureType): [number, number] {
    return calculateTextureWidthAndHeight(this.session.layoutStrategy, shape, textureType);
  }

  /**
   * Runs a WebGL program on the given inputs and returns the texture data holding the output.
   *
   * Each input is converted to the texture type the program declares for it. The compiled artifact is
   * looked up in the session's program manager under a key derived from the program and the input
   * layouts; on a miss the program is built and stored under that key.
   *
   * @param program the program, or a loader that produces it via `get()`
   * @param inputs input tensors; at least `program.inputNames.length` are required
   * @returns the newly created output texture data
   * @throws Error if there are too few inputs, or if `inputNames` and `inputTypes` differ in length
   */
  executeProgram(program: ProgramInfo | ProgramInfoLoader, inputs: readonly Tensor[]): TextureData {
    if (inputs.length < program.inputNames.length) {
      throw new Error(`Input size mustn't be less than ${program.inputNames.length}.`);
    }
    if (program.inputNames.length !== program.inputTypes.length) {
      throw new Error('input names size does not match input types');
    }

    // create texture info for input
    const inputTextureDatas: TextureData[] = [];
    for (let i = 0; i < program.inputNames.length; ++i) {
      inputTextureDatas[i] = this.getOrCreateTextureData(inputs[i], program.inputTypes[i]);
    }

    const key = getProgramInfoUniqueKey(program, inputTextureDatas);
    let artifact = this.session.programManager.getArtifact(key);
    // Prefer the cached artifact's program info; otherwise resolve a loader, or use the program as given.
    const programInfo = artifact
      ? artifact.programInfo
      : typeof (program as ProgramInfoLoader).get === 'function'
        ? (program as ProgramInfoLoader).get()
        : (program as ProgramInfo);

    // create texture info for output
    const outputTextureLayout = createTextureLayoutFromTextureType(
      this.session.layoutStrategy,
      programInfo.output.dims,
      programInfo.output.textureType,
    );
    const outputTextureData = this.createTextureData(outputTextureLayout, programInfo.output.type);

    if (!artifact) {
      artifact = this.session.programManager.build(programInfo, inputTextureDatas, outputTextureData);
      this.session.programManager.setArtifact(key, artifact);
    }

    this.runProgram(artifact, inputTextureDatas, outputTextureData);
    return outputTextureData;
  }

  /**
   * Executes a program and returns its output as a tensor.
   * @param program loader of the program to run
   * @param inputs input tensors
   * @returns the tensor bound to the output texture data
   */
  run(program: ProgramInfoLoader, inputs: readonly Tensor[]): Tensor {
    const outputTextureData = this.executeProgram(program, inputs);
    return outputTextureData.tensor;
  }

  /**
   * Checks that the packed/unpacked state of inputs and output matches what the artifact's program
   * declares, then runs the artifact through the program manager.
   * @throws Error if any input or the output has an inconsistent packed state
   */
  private runProgram(artifact: Artifact, inputs: TextureData[], output: TextureData): void {
    // input should match
    for (let i = 0; i < inputs.length; ++i) {
      if (!!inputs[i].isPacked !== (artifact.programInfo.inputTypes[i] === TextureType.packed)) {
        throw new Error(`input[${i}] property packed inconsistent`);
      }
    }

    // output should match
    if (!!output.isPacked !== (artifact.programInfo.output.textureType === TextureType.packed)) {
      throw new Error('output property packed inconsistent');
    }

    this.session.programManager.run(artifact, inputs, output);
  }

  /**
   * Create a TextureData object from a tensor.
   * Usage = EncoderUsage.UploadOnly.
   * If a related texture data is found in cache, returns it;
   * If texture data of the other packing kind is cached, converts it with `pack`/`unpack`;
   * Otherwise:
   *   Creates a new texture layout;
   *   Creates WebGLTexture with the layout;
   *   Upload tensor data to the texture;
   *   Creates a texture data object associated with the given tensor.
   * @param tensor the tensor with data to upload
   * @param textureType the texture type the caller needs
   * @returns texture data for `tensor` in the requested texture type
   */
  private getOrCreateTextureData(tensor: Tensor, textureType: TextureType): TextureData {
    // NOTE(review): every non-`packed` type (including packedLastDimension) is looked up in the unpacked cache.
    let td = this.getTextureData(tensor.dataId, textureType === TextureType.packed);

    if (!td) {
      // check if we have texture data in different type
      td = this.getTextureData(tensor.dataId, textureType !== TextureType.packed);
      if (td) {
        if (textureType === TextureType.packed) {
          return this.pack(td);
        } else {
          return this.unpack(td);
        }
      }
    }

    if (!td) {
      const layout = createTextureLayoutFromTextureType(this.session.layoutStrategy, tensor.dims, textureType);

      if (textureType === TextureType.packedLastDimension) {
        const group = 1;
        const channels = 4;
        const shape = tensor.dims;
        if (shape.length === 4) {
          // pre-processing for kernel data of Conv.
          //
          // TODO: currently this is a hacking to overwrite Conv's weight. The correct way to do this should be:
          // 1. implement texture based const-folding
          // 2. create a WebGL program "preprocessConvWeight" to do the same work as below
          // 3. run the program before dotProduct.
          //
          const adjustedKernelShape = [shape[0], Math.ceil((shape[1] * shape[2] * shape[3]) / channels)];
          const adjustedLayout = createTextureLayoutFromTextureType(
            this.session.layoutStrategy,
            adjustedKernelShape,
            textureType,
          );
          let buffer = tensor.numberData;
          if ((shape[1] * shape[2] * shape[3]) % channels !== 0) {
            // Each row (one per shape[0] entry) is padded up to a multiple of `channels`;
            // the padding stays 0 because Float32Array is zero-initialized.
            const numFeatureMaps = shape[0];
            const oldRowSize = shape[1] * shape[2] * shape[3];
            const newRowSize = Math.ceil((oldRowSize * group) / channels) * channels;
            const newSize = numFeatureMaps * newRowSize;
            buffer = new Float32Array(newSize);
            for (let f = 0; f < numFeatureMaps; ++f) {
              const oldOffset = f * oldRowSize;
              // With group === 1, (f % group) is always 0, so each row starts at f * newRowSize.
              const newOffset = f * newRowSize + (f % group) * oldRowSize;
              buffer.set(tensor.numberData.subarray(oldOffset, oldOffset + oldRowSize), newOffset);
            }
          }
          return this.createTextureData(adjustedLayout, tensor.type, buffer, tensor, EncoderUsage.UploadOnly);
        }
      }

      if (textureType === TextureType.packed) {
        // Packed textures are produced by uploading an unpacked texture and running the pack program on it.
        const unpackedTextureLayout = createTextureLayoutFromShape(this.session.layoutStrategy, tensor.dims, 1, [], {
          reverseWH: true,
        });
        const unpackedTextureData = this.createTextureData(
          unpackedTextureLayout,
          tensor.type,
          tensor.numberData,
          tensor,
          EncoderUsage.UploadOnly,
        );
        td = this.pack(unpackedTextureData);
      } else {
        td = this.createTextureData(layout, tensor.type, tensor.numberData, tensor, EncoderUsage.UploadOnly);
      }
    }
    return td;
  }

  /**
   * Create a TextureData object using the given data and bind to the given tensor.
   * Usage = EncoderUsage.UploadOnly.
   * NOTE: this function is a hack for Conv implementation. should remove this function, after rewriting Conv
   * implementation by Graph.Transformer
   * @param layout the texture layout to create
   * @param dataType the tensor data type
   * @param data the actual data to upload
   * @param tensor the tensor to bind. tensor's data is ignored.
   * @returns the created texture data
   */
  createTextureDataFromLayoutBindTensor(
    layout: TextureLayout,
    dataType: Tensor.DataType,
    data: Tensor.NumberType,
    tensor: Tensor,
  ): TextureData {
    return this.createTextureData(layout, dataType, data, tensor, EncoderUsage.UploadOnly);
  }

  /**
   * Allocates a WebGL texture for `layout` through the texture manager, uploading `data` if given,
   * and wraps it in a cached TextureData (bound to `tensor`, or to a new tensor when omitted).
   */
  private createTextureData(
    layout: TextureLayout,
    dataType: Tensor.DataType,
    data?: Tensor.NumberType,
    tensor?: Tensor,
    usage?: EncoderUsage,
  ): TextureData {
    Logger.verbose('InferenceHandler', `Creating TextureData: layout:[${JSON.stringify(layout)}]`);
    const texture = this.session.textureManager.createTextureFromLayout(dataType, layout, data, usage);
    return this.createTextureDataFromTexture(layout, dataType, texture, tensor);
  }

  /**
   * Reshapes an unpacked tensor without running a program: the result shares the input's WebGL texture
   * and texture size; only shape/strides metadata change.
   * @param input the tensor to reshape
   * @param reshapedDims the target dimensions (empty for a scalar)
   * @returns a new tensor viewing the same texture
   */
  reshapeUnpacked(input: Tensor, reshapedDims: readonly number[]): Tensor {
    const inputTD = this.getOrCreateTextureData(input, TextureType.unpacked);
    const newTextureLayout: TextureLayout = {
      channels: inputTD.channels,
      height: inputTD.height,
      width: inputTD.width,
      // handle reshaping into scalar Tensors
      shape: reshapedDims.length !== 0 ? reshapedDims : [1],
      strides: ShapeUtil.computeStrides(reshapedDims),
      unpackedShape: reshapedDims,
    };
    const newTextureData = this.createTextureDataFromTexture(newTextureLayout, input.type, inputTD.texture);
    return newTextureData.tensor;
  }

  /**
   * Reshapes a packed tensor.
   *
   * When `isReshapeCheap` holds, the result shares the input texture with new metadata. Otherwise the
   * input is reshaped to its `processDims3D` form, the packed 3D reshape program produces the 3D output,
   * and that output is reshaped to `reshapedDims`.
   * @param input the tensor to reshape
   * @param reshapedDims the target dimensions (empty for a scalar)
   * @returns the reshaped tensor
   */
  reshapePacked(input: Tensor, reshapedDims: readonly number[]): Tensor {
    const inputTD = this.getOrCreateTextureData(input, TextureType.packed);

    // check if the reshape is 'cheap'
    if (isReshapeCheap(input.dims, reshapedDims)) {
      const newTextureLayout: TextureLayout = {
        channels: inputTD.channels,
        height: inputTD.height,
        width: inputTD.width,
        // handle reshaping into scalar Tensors
        shape: reshapedDims.length !== 0 ? reshapedDims : [1],
        strides: ShapeUtil.computeStrides(reshapedDims),
        unpackedShape: reshapedDims,
        isPacked: true,
      };
      const newTextureData = this.createTextureDataFromTexture(newTextureLayout, input.type, inputTD.texture);
      return newTextureData.tensor;
    }

    const squeezedInputShape = processDims3D(input.dims);
    const squeezedOutputShape = processDims3D(reshapedDims);

    const squeezedInputTensor = this.reshapePacked(input, squeezedInputShape);
    const squeezedOutputTensor = this.run(
      createPackedReshape3DProgramInfoLoader(this, squeezedInputTensor, squeezedOutputShape),
      [squeezedInputTensor],
    );
    const outputTensor = this.reshapePacked(squeezedOutputTensor, reshapedDims);
    return outputTensor;
  }

  /**
   * Re-labels the input's unpacked texture with a new data type without running a program;
   * the returned tensor shares the input's texture and layout.
   * @param input the tensor to cast
   * @param type the new data type
   * @returns a new tensor with `type`
   */
  cast(input: Tensor, type: Tensor.DataType): Tensor {
    const inputTD = this.getOrCreateTextureData(input, TextureType.unpacked);
    const newTextureData = this.createTextureDataFromTexture(inputTD as TextureLayout, type, inputTD.texture);
    return newTextureData.tensor;
  }

  /**
   * Wraps an existing WebGL texture in a TextureData and caches it (packed or unpacked per `layout.isPacked`).
   * When `tensor` is omitted, a new Tensor of `layout.unpackedShape` is created whose data is read back
   * from this texture via `readTexture` / `readTextureAsync`.
   */
  private createTextureDataFromTexture(
    layout: TextureLayout,
    dataType: Tensor.DataType,
    texture: WebGLTexture,
    tensor?: Tensor,
    tensorId?: Tensor.Id,
  ): TextureData {
    const textureData: TextureData = {
      ...layout,
      tensor:
        tensor ||
        new Tensor(
          layout.unpackedShape,
          dataType,
          (_id: Tensor.Id) => this.readTexture(textureData),
          async (_id: Tensor.Id) => this.readTextureAsync(textureData),
          undefined,
          tensorId,
        ),
      texture,
    };
    this.setTextureData(textureData.tensor.dataId, textureData, layout.isPacked);
    return textureData;
  }

  /** Looks up cached texture data: on the session for initializers, otherwise in this handler's caches. */
  private getTextureData(tensorId: Tensor.Id, isPacked = false): TextureData | undefined {
    return this.session.isInitializer(tensorId)
      ? this.session.getTextureData(tensorId, isPacked)
      : isPacked
        ? this.packedTextureDataCache.get(tensorId)
        : this.unpackedTextureDataCache.get(tensorId);
  }

  /** Caches texture data: on the session for initializers, otherwise in this handler's packed/unpacked cache. */
  setTextureData(tensorId: Tensor.Id, td: TextureData, isPacked = false): void {
    if (this.session.isInitializer(tensorId)) {
      this.session.setTextureData(tensorId, td, isPacked);
    } else {
      (isPacked ? this.packedTextureDataCache : this.unpackedTextureDataCache).set(tensorId, td);
    }
  }

  /** Returns whether texture data of the requested packing kind is cached for `tensor`. */
  isTextureLayoutCached(tensor: Tensor, isPacked = false): boolean {
    return !!this.getTextureData(tensor.dataId, isPacked);
  }

  /**
   * Clears the texture manager's active textures, releases every texture data held in this handler's
   * caches, and resets both caches. Texture data stored on the session (initializers) is not touched here.
   */
  dispose(): void {
    this.session.textureManager.clearActiveTextures();
    this.packedTextureDataCache.forEach((td) => this.session.textureManager.releaseTexture(td));
    this.packedTextureDataCache = new Map();
    this.unpackedTextureDataCache.forEach((td) => this.session.textureManager.releaseTexture(td));
    this.unpackedTextureDataCache = new Map();
  }

  /**
   * Reads texture contents back to the CPU. Packed data is unpacked first; when float32 download is not
   * supported, the texture is first encoded as uint8 and read back as floats.
   */
  readTexture(textureData: TextureData): Tensor.NumberType {
    if (textureData.isPacked) {
      return this.readTexture(this.unpack(textureData));
    }
    if (!this.session.backend.glContext.isFloat32DownloadSupported) {
      return this.session.textureManager.readUint8TextureAsFloat(encodeAsUint8(this, textureData));
    }
    return this.session.textureManager.readTexture(textureData, textureData.tensor.type, textureData.channels);
  }

  /** Async counterpart of `readTexture`; the uint8 fallback path reads synchronously. */
  async readTextureAsync(textureData: TextureData): Promise<Tensor.NumberType> {
    if (textureData.isPacked) {
      return this.readTextureAsync(this.unpack(textureData));
    }
    if (!this.session.backend.glContext.isFloat32DownloadSupported) {
      return this.session.textureManager.readUint8TextureAsFloat(encodeAsUint8(this, textureData));
    }
    return this.session.textureManager.readTextureAsync(textureData, textureData.tensor.type, textureData.channels);
  }

  /** Runs the program from `createPackProgramInfoLoader` on `input.tensor` and returns its output. */
  pack(input: TextureData): TextureData {
    const outputTextureData = this.executeProgram(createPackProgramInfoLoader(this, input.tensor), [input.tensor]);
    return outputTextureData;
  }

  /** Runs the program from `createUnpackProgramInfoLoader` on `input.tensor` and returns its output. */
  unpack(input: TextureData): TextureData {
    const outputTextureData = this.executeProgram(createUnpackProgramInfoLoader(this, input.tensor), [input.tensor]);
    return outputTextureData;
  }
}
/*
 * Scraped page footer (GitVerse cookie notice), translated from Russian:
 * Use of cookies.
 * We use cookie files in accordance with the Privacy Policy and the Cookie Usage Policy.
 * By clicking "Accept", you give SberTech JSC consent to process your personal data in order to improve
 * our website and the GitVerse Service, and to make them more convenient to use.
 * You can disable the use of cookies yourself in your browser settings.
 */