onnxruntime

Форк
0
101 строка · 3.2 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
8

9
import { ConvAttributes } from './conv';
10
import { unpackFromChannel } from './packing-utils';
11

12
const createPackedIm2ColProgramMetadata = (cacheHint: string) => ({
13
  name: 'Im2Col (packed)',
14
  inputNames: ['A'],
15
  inputTypes: [TextureType.packed],
16
  cacheHint,
17
});
18

19
const createPackedIm2ColProgramInfo = (
20
  inferenceHandler: WebGLInferenceHandler,
21
  metadata: ProgramMetadata,
22
  x: Tensor,
23
  w: Tensor,
24
  outputShape: readonly number[],
25
  attributes: ConvAttributes,
26
): ProgramInfo => {
27
  const xshape = x.dims;
28
  const wshape = w.dims;
29
  const rowDim = 2;
30
  const colDim = 3;
31
  const rank = outputShape.length;
32
  const im2colShape = [wshape[1] * wshape[2] * wshape[3], outputShape[2] * outputShape[3]];
33
  const kernelSize = wshape[2] * wshape[3];
34
  const unpackChannel = unpackFromChannel();
35
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
36
  let unrolled = '';
37

38
  for (let row = 0; row <= 1; row++) {
39
    for (let col = 0; col <= 1; col++) {
40
      unrolled += `
41
            blockIndex = rc.x + ${col};
42
            pos = rc.y + ${row};
43

44
            if(blockIndex < ${im2colShape[1]} && pos < ${im2colShape[0]}) {
45
              offsetY = int(blockIndex / (${outputShape[rank - 1]})) * ${attributes.strides[0]} -
46
                ${attributes.pads[0]};
47
              d0 = offsetY + ${attributes.dilations[0]} * (imod(pos, ${kernelSize}) / ${wshape[2]});
48

49
              if(d0 < ${xshape[rowDim]} && d0 >= 0) {
50
                offsetX = imod(blockIndex, ${outputShape[rank - 1]}) * ${attributes.strides[1]} -
51
                  ${attributes.pads[1]};
52
                d1 = offsetX + ${attributes.dilations[1]} * imod(imod(pos, ${kernelSize}), ${wshape[2]});
53

54
                if(d1 < ${xshape[colDim]} && d1 >= 0) {
55

56
                  ch = int(float(pos)/ ${kernelSize}.);
57
                    innerDims = vec2(d0, d1);
58
                    result[${row * 2 + col}] = getChannel(
59
                      getA(0, ch, int(innerDims.x),
60
                      int(innerDims.y)), innerDims);
61
                }
62
              }
63
            }
64

65
          `;
66
    }
67
  }
68

69
  const shaderSource = `
70
      ${unpackChannel}
71

72
      void main() {
73
        ivec2 rc = getOutputCoords();
74
          vec4 result = vec4(0.0);
75
          int blockIndex, pos, offsetY, d0, offsetX, d1, ch;
76
          vec2 innerDims;
77
          ${unrolled}
78
          ${glsl.output} = result;
79
      }
80
            `;
81
  return {
82
    ...metadata,
83
    output: { dims: im2colShape, type: x.type, textureType: TextureType.packed },
84
    shaderSource,
85
    hasMain: true,
86
  };
87
};
88

89
export const createPackedIm2ColProgramInfoLoader = (
90
  inferenceHandler: WebGLInferenceHandler,
91
  x: Tensor,
92
  w: Tensor,
93
  outputShape: readonly number[],
94
  attributes: ConvAttributes,
95
): ProgramInfoLoader => {
96
  const metadata = createPackedIm2ColProgramMetadata(attributes.cacheKey);
97
  return {
98
    ...metadata,
99
    get: () => createPackedIm2ColProgramInfo(inferenceHandler, metadata, x, w, outputShape, attributes),
100
  };
101
};
102

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.