onnxruntime

Форк
0
107 строк · 3.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { WebGLInferenceHandler } from '../inference-handler';
6
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
7

8
import { ConvAttributes } from './conv';
9

10
const createIm2ColProgramMetadata = (cacheHint: string) => ({
11
  name: 'Im2Col',
12
  inputNames: ['X'],
13
  inputTypes: [TextureType.unpacked],
14
  cacheHint,
15
});
16

17
const createIm2ColProgramInfo = (
18
  _inferenceHandler: WebGLInferenceHandler,
19
  metadata: ProgramMetadata,
20
  x: Tensor,
21
  w: Tensor,
22
  outputShape: readonly number[],
23
  attributes: ConvAttributes,
24
): ProgramInfo => {
25
  const xshape = x.dims;
26
  const wshape = w.dims;
27

28
  const rank = outputShape.length;
29
  const im2colDims = calculateIm2ColDims(xshape, wshape, outputShape, 4);
30

31
  const shaderSource = `
32
        const int XC = ${xshape[1]};
33
        const int XH = ${xshape[2]};
34
        const int XW = ${xshape[3]};
35
        const int KH = ${attributes.kernelShape[0]};
36
        const int KW = ${attributes.kernelShape[1]};
37
        const int dilationH = ${attributes.dilations[0]};
38
        const int dilationW = ${attributes.dilations[1]};
39
        const int strideH = ${attributes.strides[0]};
40
        const int strideW = ${attributes.strides[1]};
41
        const int padH = ${attributes.pads[0]};
42
        const int padW = ${attributes.pads[1]};
43
        const int KHKW = KH*KW;
44
        const int XCKHKW = XC * KHKW;
45
        const int outputChannels = 4;
46
        vec4 process(int indices[${rank}]) {
47
          int b  = indices[0]; // batch size
48
          int oh = indices[1] * strideH - padH; //output height
49
          int ow = indices[2] * strideW - padW; //output width
50
          int p = indices[3] * outputChannels; //patch
51
          vec4 value = vec4(0.0);
52
          for(int i=0; i < outputChannels; ++i) {
53
            if(p < XCKHKW) {
54
              int patchC = p / KHKW;
55
              int patchH = (p - patchC*KHKW) / KW;
56
              int patchW = (p - patchC*KHKW) - patchH * KW;
57
              int xh2 = oh + patchH * dilationH;
58
              int xw2 = ow + patchW * dilationW;
59
              int x[${xshape.length}];
60
              x[0] = b;
61
              x[1] = patchC;
62
              x[2] = xh2;
63
              x[3] = xw2;
64
              if(xh2 >= 0 &&
65
                  xh2 < XH &&
66
                  xw2 >= 0 &&
67
                  xw2 < XW) {
68
                value[i] = _X(x);
69
              }
70
            }
71
            ++p;
72
          }
73
          return value;
74
        }
75
        `;
76
  return {
77
    ...metadata,
78
    output: { dims: im2colDims, type: x.type, textureType: TextureType.packedLastDimension },
79
    shaderSource,
80
  };
81
};
82

83
export const createIm2ColProgramInfoLoader = (
84
  inferenceHandler: WebGLInferenceHandler,
85
  x: Tensor,
86
  w: Tensor,
87
  outputShape: readonly number[],
88
  attributes: ConvAttributes,
89
): ProgramInfoLoader => {
90
  const metadata = createIm2ColProgramMetadata(attributes.cacheKey);
91
  return {
92
    ...metadata,
93
    get: () => createIm2ColProgramInfo(inferenceHandler, metadata, x, w, outputShape, attributes),
94
  };
95
};
96

97
export const calculateIm2ColDims = (
98
  inputShape: readonly number[],
99
  kernelShape: readonly number[],
100
  outputShape: readonly number[],
101
  channels = 4,
102
): number[] => [
103
  outputShape[0],
104
  outputShape[2],
105
  outputShape[3],
106
  Math.ceil((inputShape[1] * kernelShape[2] * kernelShape[3]) / channels),
107
];
108

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.