onnxruntime

Форк
0
/
conv-grouped.ts 
249 строк · 10.2 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
8

9
import {
10
  createTensorShapeVariables,
11
  getMaxComponents,
12
  inputVariable,
13
  outputVariable,
14
  ShaderHelper,
15
  tensorTypeToWsglStorageType,
16
  UniformsArrayType,
17
} from './common';
18
import { calculateOutputShape, ConvAttributes } from './conv';
19
import { appendActivationUniforms, appendActivationUniformsData, getActivationSnippet } from './fuse-utils';
20

21
/**
 * Naive grouped convolution; supports 1-D and 2-D convolution.
 *
 * One shader invocation computes one output element: it loops over the input channels that belong
 * to the output channel's group and over the kernel window, skipping taps that fall into padding.
 * The shader indexes every tensor with 4 indices, so a 1-D convolution has to reach this function
 * already expressed as a 2-D one (NOTE(review): presumably with a unit spatial axis — confirm
 * against the conv1d caller).
 *
 * Weights are always read as `[outputChannel, inputChannelInGroup, kernelHeight, kernelWidth]`;
 * only the activation/output layout follows `attributes.format` (NCHW or NHWC).
 *
 * @param inputs - `[x, w]` or `[x, w, bias]`; when present, bias is added per output channel.
 * @param attributes - convolution attributes (group, dilations, pads, strides, format, activation).
 * @param squeezeOutputShapeFunction - an optional function to squeeze the output shape, only used in conv1d
 * @returns the program description (shader source, cache key and dispatch data).
 */
export const createGroupedConvProgramInfo = (
  inputs: readonly TensorView[],
  attributes: ConvAttributes,
  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
  const hasBias = inputs.length > 2;
  const processBias = hasBias ? 'value += b[output_channel];' : '';
  const xShape = inputs[0].dims;
  const wShape = inputs[1].dims;
  const outputChannelsPerGroup = wShape[0] / attributes.group;

  const isChannelLast = attributes.format === 'NHWC';
  // Axis positions of the channel and spatial dimensions in the activation/output tensors.
  const channelAxis = isChannelLast ? 3 : 1;
  const heightAxis = isChannelLast ? 1 : 2;
  const widthAxis = isChannelLast ? 2 : 3;
  const outputShape = calculateOutputShape(
    xShape,
    wShape,
    attributes.dilations,
    attributes.pads,
    attributes.strides,
    isChannelLast,
  );
  const outputSize = ShapeUtil.size(outputShape);

  // Dilations, strides and pads are signed so that the shader can compute input coordinates that
  // lie in the padding region as negative numbers and reject them with an explicit `< 0` test,
  // instead of relying on unsigned wrap-around.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.int32, data: attributes.dilations },
    { type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]] },
    { type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]] },
    { type: DataType.uint32, data: outputChannelsPerGroup },
  ];
  appendActivationUniformsData(attributes, programUniforms);
  programUniforms.push(...createTensorShapeVariables(xShape, wShape));
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
  if (hasBias) {
    programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
    inputDependencies.push('rank');
  }
  programUniforms.push(...createTensorShapeVariables(outputShape));

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const output = outputVariable('output', inputs[0].dataType, outputShape.length);
    const baseType = tensorTypeToWsglStorageType(output.type.tensor);
    const applyActivation = getActivationSnippet(attributes, output.type.value, baseType);
    const x = inputVariable('x', inputs[0].dataType, xShape.length);
    const w = inputVariable('w', inputs[1].dataType, wShape.length);
    const inputVars = [x, w];
    if (hasBias) {
      inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims.length));
    }

    // Must be declared in the same order as `programUniforms` above.
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'dilations', type: 'i32', length: attributes.dilations.length },
      { name: 'strides', type: 'i32', length: 2 },
      { name: 'pads', type: 'i32', length: 2 },
      { name: 'output_channels_per_group', type: 'u32' },
    ];
    appendActivationUniforms(attributes, uniforms);
    return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}

    let outputIndices = ${output.offsetToIndices('global_idx')};
    let batch: u32 = outputIndices[0];
    let output_channel: u32 = outputIndices[${channelAxis}];
    // Top-left input coordinate of this output's receptive field; negative inside the padding.
    let xRCCorner: vec2<i32> = vec2<i32>(
        i32(outputIndices[${heightAxis}]),
        i32(outputIndices[${widthAxis}])) * uniforms.strides - uniforms.pads;
    let group_id: u32 = output_channel / uniforms.output_channels_per_group;

    var value: ${output.type.value} = ${output.type.value}(0);
    for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {
      let input_channel = group_id * uniforms.w_shape[1] + wInChannel;
      for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {
        let xHeight = xRCCorner.x + i32(wHeight) * uniforms.dilations[0];
        if (xHeight < 0 || u32(xHeight) >= uniforms.x_shape[${heightAxis}]) {
          continue;
        }

        for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {
          let xWidth = xRCCorner.y + i32(wWidth) * uniforms.dilations[1];
          if (xWidth < 0 || u32(xWidth) >= uniforms.x_shape[${widthAxis}]) {
            continue;
          }

          let xVal = ${
            isChannelLast
              ? x.get('batch', 'u32(xHeight)', 'u32(xWidth)', 'input_channel')
              : x.get('batch', 'input_channel', 'u32(xHeight)', 'u32(xWidth)')
          };
          let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')};
          value += xVal * wVal;
        }
      }
    }
    ${processBias}
    ${applyActivation}
    ${output.setByOffset('global_idx', 'value')}
  }`;
  };
  return {
    name: 'GroupedConv',
    shaderCache: { hint: attributes.cacheKey, inputDependencies },
    getRunData: () => ({
      outputs: [
        {
          dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
          dataType: inputs[0].dataType,
        },
      ],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
143

144
/**
 * Vectorized grouped convolution for a channels-last, one-input-channel-per-output-channel case.
 *
 * Each shader invocation produces `outputNumber` horizontally adjacent output pixels for one vector
 * of `components` channels. The input columns those outputs need are loaded once per kernel row into
 * a private array (`x_vals`) and reused for every kernel column.
 *
 * Assumptions visible in the shader:
 * - activations and output are NHWC (`x.get(batch, h, w, c)`, `output.set(batch, row, col, c)`);
 * - `w` is indexed as `[kernelHeight, kernelWidth, inChannel, outputChannel]`, and only in-channel
 *   index 0 is read; `wShape[0]`/`wShape[1]` are baked into the shader as kernel height/width;
 * - output channel `c` reads input channel `c` (depthwise with multiplier 1).
 *
 * Strides, pads and dilations are all honoured; with dilation 1 the generated computation is the
 * same as a dilation-free implementation.
 *
 * @param inputs - `[x, w]` or `[x, w, bias]`; when present, bias is added per output channel.
 * @param attributes - convolution attributes (strides, pads, dilations, activation).
 * @param outputShape - the NHWC output shape computed by the caller.
 * @param squeezeOutputShapeFunction - an optional function to squeeze the output shape, only used in conv1d
 * @returns the program description (shader source, cache key and dispatch data).
 */
export const createGroupedConvVectorizeProgramInfo = (
  inputs: readonly TensorView[],
  attributes: ConvAttributes,
  outputShape: readonly number[],
  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
  const hasBias = inputs.length > 2;
  // Channels packed per vector, and output pixels along the width produced per invocation.
  const components = getMaxComponents(outputShape[3]);
  const outputNumber = getMaxComponents(outputShape[2]);
  const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
  const xShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[0].dims[2], inputs[0].dims[3] / components];
  const wShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[1].dims[3] / components];
  const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components];
  // [dilationHeight, dilationWidth]; a missing entry means "no dilation".
  const dilations: [number, number] = [attributes.dilations[0] ?? 1, attributes.dilations[1] ?? 1];

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]] },
    { type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]] },
    { type: DataType.int32, data: dilations },
  ];
  appendActivationUniformsData(attributes, programUniforms);
  programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader));
  // Consecutive input columns needed for `outputNumber` adjacent outputs: the last output starts
  // (outputNumber - 1) * strideW columns after the first, and one receptive field spans
  // (kernelW - 1) * dilationW + 1 columns. With dilationW == 1 this is
  // (outputNumber - 1) * strideW + kernelW.
  const xNumber = (outputNumber - 1) * attributes.strides[1] + (wShape[1] - 1) * dilations[1] + 1;
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
    const baseType = tensorTypeToWsglStorageType(output.type.tensor);
    const applyActivation = getActivationSnippet(attributes, output.type.value, baseType);
    const x = inputVariable('x', inputs[0].dataType, xShape.length, components);
    const w = inputVariable('w', inputs[1].dataType, wShape.length, components);
    const inputVars = [x, w];
    if (hasBias) {
      inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components));
    }
    const processBias = hasBias ? 'value += b[output_channel];' : '';
    // Must be declared in the same order as `programUniforms` above.
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'strides', type: 'i32', length: 2 },
      { name: 'pads', type: 'i32', length: 2 },
      { name: 'dilations', type: 'i32', length: 2 },
    ];
    appendActivationUniforms(attributes, uniforms);
    return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let width0 = uniforms.output_shape[3];
    let output_channel = global_idx % width0;
    var index1 = global_idx / width0;
    let width1 = uniforms.output_shape[2] / ${outputNumber}u;
    let col = (index1 % width1) * ${outputNumber}u;
    index1 = index1 / width1;
    let row = index1 % uniforms.output_shape[1];
    let batch = index1 / uniforms.output_shape[1];

    // Top-left input coordinate of the first output's receptive field; negative inside the padding.
    let x_corner = vec2<i32>(i32(row), i32(col)) * uniforms.strides - uniforms.pads;

    var x_vals: array<${x.type.value}, ${xNumber}>;
    var values: array<${output.type.value}, ${outputNumber}>;
    let input_channel = output_channel;
    // Using constants instead of uniforms for w's height/width gives better performance.
    for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) {
      let x_height = x_corner.x + i32(w_height) * uniforms.dilations[0];
      if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) {
        // Load the input row segment once; out-of-range columns are zero padding.
        for (var i = 0; i < ${xNumber}; i++) {
          let x_width = x_corner.y + i;
          if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) {
            x_vals[i] = ${x.get('batch', 'u32(x_height)', 'u32(x_width)', 'input_channel')};
          } else {
            x_vals[i] = ${x.type.value}(0);
          }
        }
        for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) {
          let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')};
          let x_offset = w_width * u32(uniforms.dilations[1]);
          for (var i = 0u; i < ${outputNumber}u; i++) {
            values[i] = fma(x_vals[i * u32(uniforms.strides[1]) + x_offset], w_val, values[i]);
          }
        }
      }
    }

    for (var i = 0u; i < ${outputNumber}u; i++) {
      var value = values[i];
      ${processBias}
      ${applyActivation}
      ${output.set('batch', 'row', 'col + i', 'output_channel', 'value')};
    }
  }`;
  };

  return {
    name: 'GroupedConv-Vectorize',
    shaderCache: {
      hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`,
      inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'],
    },
    getRunData: () => ({
      outputs: [
        {
          dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
          dataType: inputs[0].dataType,
        },
      ],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
250

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.