onnxruntime

Форк
0
99 строк · 3.7 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Logger } from '../../../instrument';
5
import { Tensor } from '../../../tensor';
6
import { getGlsl } from '../glsl-source';
7
import { WebGLInferenceHandler } from '../inference-handler';
8
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
9

10
import { calculateOutputShape, ConvAttributes } from './conv';
11
import { getActivationSnippet } from './fuse-utils';
12

13
const createUnpackedGroupedConvProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({
14
  name: 'GroupedConv',
15
  inputNames: hasBias ? ['X', 'W', 'Bias'] : ['X', 'W'],
16
  inputTypes: hasBias
17
    ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked]
18
    : [TextureType.unpacked, TextureType.unpacked],
19
  cacheHint,
20
});
21

22
const createUnpackedGroupedConvProgramInfo = (
23
  inferenceHandler: WebGLInferenceHandler,
24
  inputs: readonly Tensor[],
25
  metadata: ProgramMetadata,
26
  attributes: ConvAttributes,
27
): ProgramInfo => {
28
  const hasBias = inputs.length > 2;
29
  const processBias = hasBias ? 'value += getBias(output_channel);' : '';
30
  const xShape = inputs[0].dims.slice();
31
  const wShape = inputs[1].dims.slice();
32
  const outputChannelsPerGroup = wShape[0] / attributes.group;
33
  Logger.verbose(
34
    'GroupedConv',
35
    `autpPad:${attributes.autoPad}, dilations:${attributes.dilations}, group:${attributes.group}, kernelShape:${
36
      attributes.kernelShape
37
    }, pads:${attributes.pads}, strides:${attributes.strides}`,
38
  );
39
  const outputShape = calculateOutputShape(xShape, wShape, attributes.dilations, attributes.pads, attributes.strides);
40
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
41
  const { activationFunction, applyActivation } = getActivationSnippet(attributes);
42

43
  const shaderSource = `
44
  const ivec2 strides = ivec2(${attributes.strides[0]}, ${attributes.strides[1]});
45
  const ivec2 pads = ivec2(${attributes.pads[0]}, ${attributes.pads[1]});
46
  ${activationFunction}
47
  void main() {
48
    ivec4 coords = getOutputCoords();
49
    int batch = coords.x;
50
    int output_channel = coords.y;
51
    ivec2 xRCCorner = coords.zw * strides - pads;
52
    int group_id = output_channel / ${outputChannelsPerGroup};
53

54
    float value = 0.0;
55
    for (int wInChannel = 0; wInChannel < ${wShape[1]}; wInChannel++) {
56
      int input_channel = group_id * ${wShape[1]} + wInChannel;
57
      for (int wHeight = 0; wHeight < ${wShape[2]}; wHeight++) {
58
        int xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]};
59

60
        if (xHeight < 0 || xHeight >= ${xShape[2]}) {
61
          continue;
62
        }
63

64
        for (int wWidth = 0; wWidth < ${wShape[3]}; wWidth++) {
65
          int xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]};
66
          if (xWidth < 0 || xWidth >= ${xShape[3]}) {
67
            continue;
68
          }
69

70
          float xVal = getX(batch, input_channel, xWidth, xHeight);
71
          float wVal = getW(output_channel, wInChannel, wWidth, wHeight);
72
          value += xVal*wVal;
73
        }
74
      }
75
    }
76
    ${processBias}
77
    ${applyActivation}
78
    ${glsl.output} = vec4(value, .0, .0, .0);
79
  }
80
`;
81
  return {
82
    ...metadata,
83
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
84
    shaderSource,
85
    hasMain: true,
86
  };
87
};
88

89
export const createUnpackedGroupedConvProgramInfoLoader = (
90
  inferenceHandler: WebGLInferenceHandler,
91
  inputs: readonly Tensor[],
92
  attributes: ConvAttributes,
93
): ProgramInfoLoader => {
94
  const metadata = createUnpackedGroupedConvProgramMetadata(inputs.length > 2, attributes.cacheKey);
95
  return {
96
    ...metadata,
97
    get: () => createUnpackedGroupedConvProgramInfo(inferenceHandler, inputs, metadata, attributes),
98
  };
99
};
100

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.