onnxruntime

Форк
0
91 строка · 3.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { ShapeUtil } from '../../../util';
6
import { getGlsl } from '../glsl-source';
7
import { WebGLInferenceHandler } from '../inference-handler';
8
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
9

10
import { getActivationSnippet, InternalActivationAttributes } from './fuse-utils';
11
import { calculateIm2ColDims } from './im2col';
12

13
const createDotProductProgramMetadata = (hasBias: boolean, attributes: InternalActivationAttributes) => ({
14
  name: 'ConvDotProduct',
15
  inputNames: hasBias ? ['Im2Col', 'K', 'B'] : ['Im2Col', 'K'],
16
  inputTypes: hasBias
17
    ? [TextureType.unpacked, TextureType.packedLastDimension, TextureType.unpacked]
18
    : [TextureType.unpacked, TextureType.packedLastDimension],
19
  cacheKey: attributes.activationCacheKey,
20
});
21

22
const createDotProductProgramInfo = (
23
  inferenceHandler: WebGLInferenceHandler,
24
  metadata: ProgramMetadata,
25
  inputs: readonly Tensor[],
26
  outputShape: number[],
27
  attributes: InternalActivationAttributes,
28
): ProgramInfo => {
29
  const xshape = inputs[0].dims;
30
  const kshape = inputs[1].dims;
31
  const adjustedKernelShape = [kshape[0], Math.ceil((xshape[1] * kshape[2] * kshape[3]) / 4)];
32
  const im2colShape = calculateIm2ColDims(xshape, kshape, outputShape);
33
  const [kWidth, kHeight] = inferenceHandler.calculateTextureWidthAndHeight(
34
    adjustedKernelShape,
35
    TextureType.packedLastDimension,
36
  );
37

38
  const im2colStrides = ShapeUtil.computeStrides(im2colShape);
39
  const [im2colWidth, im2colHeight] = inferenceHandler.calculateTextureWidthAndHeight(
40
    im2colShape,
41
    TextureType.packedLastDimension,
42
  );
43
  const rank = outputShape.length;
44

45
  const initValue = inputs.length < 3 ? '0.0' : '_B(b)';
46
  const sharedDim = Math.ceil((xshape[1] * kshape[2] * kshape[3]) / 4);
47
  const { activationFunction, applyActivation } = getActivationSnippet(attributes);
48
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
49
  const shaderSource = `
50
${activationFunction}
51
float process(int indices[${rank}]) {
52
  int b[1];
53
  b[0] = indices[1];
54
  int im2col[4];
55
  im2col[0] = indices[0];
56
  im2col[1] = indices[2];
57
  im2col[2] = indices[3];
58
  int im2colOffset = im2col[0] * ${im2colStrides[0]} + im2col[1] * ${im2colStrides[1]} + im2col[2] * ${
59
    im2colStrides[2]
60
  };
61
  int kernelOffset = indices[1] * ${adjustedKernelShape[1]};
62
  float value = ${initValue};
63
  for (int i = 0; i < ${sharedDim}; ++i) {
64
    vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colWidth}, ${im2colHeight});
65
    vec2 kernelCoords = offsetToCoords(kernelOffset, ${kWidth}, ${kHeight});
66
    value += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
67
    ++im2colOffset;
68
    ++kernelOffset;
69
  }
70
  ${applyActivation}
71
  return value;
72
}`;
73
  return {
74
    ...metadata,
75
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
76
    shaderSource,
77
  };
78
};
79

80
export const createDotProductProgramInfoLoader = (
81
  inferenceHandler: WebGLInferenceHandler,
82
  inputs: readonly Tensor[],
83
  outputShape: number[],
84
  attributes: InternalActivationAttributes,
85
): ProgramInfoLoader => {
86
  const metadata = createDotProductProgramMetadata(inputs.length > 2, attributes);
87
  return {
88
    ...metadata,
89
    get: () => createDotProductProgramInfo(inferenceHandler, metadata, inputs, outputShape, attributes),
90
  };
91
};
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.