onnxruntime

Форк
0
308 строк · 11.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { InferenceHandler } from '../../../backend';
6
import { Graph } from '../../../graph';
7
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
8
import { Tensor } from '../../../tensor';
9
import { getGlsl } from '../glsl-source';
10
import { WebGLInferenceHandler } from '../inference-handler';
11
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
12

13
import { ConvAttributes } from './conv';
14
import { getActivationSnippet, parseInternalActivationAttributes } from './fuse-utils';
15

16
/**
 * Total padding (begin + end) needed along one spatial axis so that a transposed convolution
 * produces `outSize` elements. This is the ONNX ConvTranspose formula:
 *   total_padding = stride * (in - 1) + output_padding + ((kernel - 1) * dilation + 1) - out
 *
 * @param inDim input size along the axis
 * @param stride stride along the axis
 * @param adj output padding (ONNX `output_padding`) along the axis
 * @param kernel kernel size along the axis
 * @param dilation dilation along the axis
 * @param outSize desired output size along the axis
 * @returns the total padding to split between both ends of the axis
 */
const computeTotalPad = (
  inDim: number,
  stride: number,
  adj: number,
  kernel: number,
  dilation: number,
  outSize: number,
) => (inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize;

/**
 * Splits `totalPad` between `pads[head]` (begin) and `pads[tail]` (end) of one axis, in place.
 * SAME_UPPER puts the odd extra unit at the end and SAME_LOWER puts it at the beginning.
 * Any other `autoPad` value leaves `pads` untouched.
 */
const distributePadding = (totalPad: number, autoPad: string, pads: number[], head: number, tail: number) => {
  const smallPad = Math.floor(totalPad / 2);
  if (autoPad === 'SAME_UPPER') {
    pads[head] = smallPad;
    pads[tail] = totalPad - smallPad;
  } else if (autoPad === 'SAME_LOWER') {
    pads[head] = totalPad - smallPad;
    pads[tail] = smallPad;
  }
};

/**
 * Resolves the spatial output shape and, for SAME_* auto-padding, the pads of a ConvTranspose.
 * Both `pads` and `outputShape` are modified in place.
 *
 * @param inputShape full X shape (batch, channels, ...spatial)
 * @param kernelShape spatial kernel shape
 * @param dilations per-axis dilations
 * @param autoPad ONNX `auto_pad` value
 * @param pads begin pads followed by end pads; overwritten for SAME_UPPER / SAME_LOWER
 * @param strides per-axis strides
 * @param outputPadding per-axis ONNX `output_padding`
 * @param outputShape spatial output shape; when empty it is filled in from the other parameters
 */
const calculateOutputShapeAndPads = (
  inputShape: readonly number[],
  kernelShape: readonly number[],
  dilations: readonly number[],
  autoPad: string,
  pads: number[],
  strides: readonly number[],
  outputPadding: readonly number[],
  outputShape: number[],
) => {
  const spatialRank = inputShape.length - 2;
  // An empty outputShape means "infer it"; otherwise the given shape is only used to derive pads.
  const updateShape = outputShape.length === 0;
  for (let i = 0; i < spatialRank; ++i) {
    const inDim = inputShape[i + 2];
    // When the output shape is inferred, the pad target is in * stride. Only SAME_* auto-padding
    // actually consumes the resulting total pad.
    const outSize = updateShape ? inDim * strides[i] : outputShape[i];
    // The adjustment term of the ONNX formula is output_padding, not the current pad value.
    const totalPad = computeTotalPad(inDim, strides[i], outputPadding[i], kernelShape[i], dilations[i], outSize);
    distributePadding(totalPad, autoPad, pads, i, i + spatialRank);
    if (updateShape) {
      outputShape.push(
        strides[i] * (inDim - 1) +
          outputPadding[i] +
          (kernelShape[i] - 1) * dilations[i] +
          1 -
          pads[i] -
          pads[i + spatialRank],
      );
    }
  }
};
64

65
/**
 * Attributes of the ConvTranspose operator: everything Conv accepts, plus the two
 * transpose-specific spatial attributes.
 */
export interface ConvTransposeAttributes extends ConvAttributes {
  /** Per-axis ONNX `output_padding`. */
  readonly outputPadding: ReadonlyArray<number>;
  /** Spatial ONNX `output_shape`; empty when the shape should be inferred. */
  readonly outputShape: ReadonlyArray<number>;
}
69

70
/**
 * ConvTranspose entry point for the WebGL backend.
 *
 * Validates inputs and attributes, then delegates to the 2-D implementation. Validation currently
 * rejects anything other than a 2-D ConvTranspose.
 */
export const convTranspose: OperatorImplementation<ConvTransposeAttributes> = (
  inferenceHandler: InferenceHandler,
  inputs: Tensor[],
  attributes: ConvTransposeAttributes,
): Tensor[] => {
  // Throws when the inputs do not describe a supported (2-D, float32) ConvTranspose.
  validateInputs(inputs, attributes);
  const outputs = convTranspose2d(inferenceHandler, inputs, attributes);
  return outputs;
};
78

79
/**
 * 2-D ConvTranspose: resolves kernel shape, pads and output shape, then runs the unpacked shader.
 * Returns a single-element array holding the output tensor.
 */
const convTranspose2d: OperatorImplementation<ConvTransposeAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ConvTransposeAttributes,
): Tensor[] => {
  const output = convTranspose2DUnpacked(
    inferenceHandler,
    inputs,
    getAdjustedConvTransposeAttributes(attributes, inputs),
  );
  return [output];
};
87

88
/**
 * Builds the program metadata for the ConvTranspose shader.
 *
 * @param hasBias whether a third input `B` (bias) is bound
 * @param cacheHint cache hint forwarded into the metadata
 */
const createConvTransposeProgramMetadata = (hasBias: boolean, cacheHint: string) => {
  const inputNames = ['X', 'W'];
  const inputTypes = [TextureType.unpacked, TextureType.unpacked];
  // The optional bias is appended as a third unpacked input.
  if (hasBias) {
    inputNames.push('B');
    inputTypes.push(TextureType.unpacked);
  }
  return { name: 'ConvTranspose', inputNames, inputTypes, cacheHint };
};
96

97
/**
 * Generates the unpacked (one value per texel) GLSL program for a 2-D ConvTranspose.
 *
 * Each output texel accumulates, over the input channels of its group and over every kernel tap,
 * the X values whose strided position maps onto the output location. Accumulation starts from
 * the bias when a third input is present, and the activation snippet is applied last.
 *
 * @param inferenceHandler handler used to look up the GL context version
 * @param inputs X, W and optionally B
 * @param metadata program metadata produced by createConvTransposeProgramMetadata
 * @param attributes adjusted attributes (kernel shape, pads and output shape already resolved)
 */
const createUnpackedConvTransposeProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: readonly Tensor[],
  metadata: ProgramMetadata,
  attributes: ConvTransposeAttributes,
): ProgramInfo => {
  const { dims: xShape } = inputs[0];
  const { dims: wShape } = inputs[1];
  // Without a bias input every output starts accumulating from zero.
  const valueInit = inputs.length > 2 ? 'getB(output_channel)' : '0.0';
  // W is read as getW(input_channel, output_channel_within_group, kernel...), so dim 1 is per-group outputs.
  const outputChannelsPerGroup = wShape[1];
  const inputChannelsPerGroup = wShape[0] / attributes.group;
  const outputShape = [xShape[0], wShape[1] * attributes.group, ...attributes.outputShape];
  const { version } = inferenceHandler.session.backend.glContext;
  const glsl = getGlsl(version);
  const { activationFunction, applyActivation } = getActivationSnippet(attributes);

  // NOTE(review): in the shader below, `loc.x` / `wLocIn.x` follow output/input dim 2. They are
  // offset by pads[0] and bounded by xShape[2], yet they are passed as the LAST index of getX. The
  // kernel loop bounded by wShape[2] is likewise passed as the last index of getW. Assuming the
  // samplers take indices in dims order, this transposes H and W. Confirm against the sampler index
  // order and test with non-square X/W before relying on it.
  const shaderSource = `
  const ivec2 strides = ivec2(${attributes.strides[0]}, ${attributes.strides[1]});
  const ivec2 pads = ivec2(${attributes.pads[0]}, ${attributes.pads[1]});
  ${activationFunction}
  void main() {
    ivec4 coords = getOutputCoords();
    int batch = coords.x;
    int output_channel = coords.y;

    ivec2 loc = coords.zw + pads;

    int group_id = output_channel / ${outputChannelsPerGroup};
    int wOutChannel = output_channel - group_id * ${outputChannelsPerGroup};

    float value = ${valueInit};
    for (int inChannelOffset = 0; inChannelOffset < ${inputChannelsPerGroup}; inChannelOffset++) {
      int input_channel = group_id * ${inputChannelsPerGroup} + inChannelOffset;
      for (int wWOff = 0; wWOff < ${wShape[2]}; wWOff++) {
        for (int wHOff = 0; wHOff < ${wShape[3]}; wHOff++) {
          ivec2 wOff = ivec2(wWOff * ${attributes.dilations[0]}, wHOff * ${attributes.dilations[1]});
          ivec2 wLoc = loc - wOff;
          ivec2 wLocIn = wLoc / strides;
          if (
            wLocIn * strides == wLoc &&
            wLocIn.x >= 0 && wLocIn.x < ${xShape[2]} &&
            wLocIn.y >= 0 && wLocIn.y < ${xShape[3]}
          ) {
            float xVal = getX(batch, input_channel, wLocIn.y, wLocIn.x);
            float wVal = getW(input_channel, wOutChannel, wHOff, wWOff);
            value += xVal * wVal;
          }
        }
      }
    }
    ${applyActivation}
    ${glsl.output} = vec4(value, .0, .0, .0);
  }
`;

  const output = { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked };
  return { ...metadata, output, shaderSource, hasMain: true };
};
158

159
/**
 * Wraps the ConvTranspose program in a loader: the metadata is computed now, and the shader
 * source is only generated when `get` is called.
 */
const createUnpackedConvTransposeProgramInfoLoader = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: readonly Tensor[],
  attributes: ConvTransposeAttributes,
): ProgramInfoLoader => {
  const hasBias = inputs.length > 2;
  const metadata = createConvTransposeProgramMetadata(hasBias, attributes.cacheKey);
  const get = () => createUnpackedConvTransposeProgramInfo(inferenceHandler, inputs, metadata, attributes);
  return { ...metadata, get };
};
170

171
/**
 * Runs the unpacked ConvTranspose program on `inputs` and returns the resulting tensor.
 */
const convTranspose2DUnpacked = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: readonly Tensor[],
  attributes: ConvTransposeAttributes,
): Tensor => {
  const programInfoLoader = createUnpackedConvTransposeProgramInfoLoader(inferenceHandler, inputs, attributes);
  return inferenceHandler.run(programInfoLoader, inputs);
};
182

183
/**
 * Returns a copy of `attributes` with kernel shape, pads and output shape fully resolved for `inputs`.
 *
 * - kernelShape: taken from the attributes, or from W's dims after the first two when none is given.
 * - outputShape / pads: computed by calculateOutputShapeAndPads on local copies.
 *
 * The caller's `attributes` object is never modified, and `cacheKey` is copied explicitly.
 */
const getAdjustedConvTransposeAttributes = <T extends ConvTransposeAttributes>(attributes: T, inputs: Tensor[]): T => {
  const weightDims = inputs[1].dims;
  const kernelShape = attributes.kernelShape.length === 0 ? weightDims.slice(2) : attributes.kernelShape.slice();
  const pads = attributes.pads.slice();
  const outputShape = attributes.outputShape.slice();

  // Fills outputShape when it is empty and adjusts pads for SAME_* auto-padding (both mutated in place).
  calculateOutputShapeAndPads(
    inputs[0].dims,
    kernelShape,
    attributes.dilations,
    attributes.autoPad,
    pads,
    attributes.strides,
    attributes.outputPadding,
    outputShape,
  );

  return Object.assign({}, attributes, { kernelShape, pads, outputShape, cacheKey: attributes.cacheKey });
};
213

214
/**
 * Reads ConvTranspose attributes (plus any internal fused-activation attributes) from a graph node.
 * Missing attributes get 2-D defaults.
 */
export const parseConvTransposeAttributes: OperatorInitialization<ConvTransposeAttributes> = (
  node: Graph.Node,
): ConvTransposeAttributes => {
  const { attributes } = node;
  const activationAttributes = parseInternalActivationAttributes(attributes);
  // TODO : Make this generic enough to compute default attributes for multi-dimensional conv
  // (the defaults below assume a 2-D ConvTranspose).
  return createAttributeWithCacheKey({
    autoPad: attributes.getString('auto_pad', 'NOTSET'),
    dilations: attributes.getInts('dilations', [1, 1]),
    group: attributes.getInt('group', 1),
    kernelShape: attributes.getInts('kernel_shape', []),
    outputPadding: attributes.getInts('output_padding', [0, 0]),
    outputShape: attributes.getInts('output_shape', []),
    pads: attributes.getInts('pads', [0, 0, 0, 0]),
    strides: attributes.getInts('strides', [1, 1]),
    ...activationAttributes,
  });
};
241

242
/**
 * Validates ConvTranspose inputs and attributes before the WebGL kernel is built.
 *
 * Only 2-D ConvTranspose (4-D X and W) with float32 tensors is supported.
 *
 * @param inputs X, W and optionally B
 * @param attributes parsed ConvTranspose attributes
 * @throws Error when the input count, ranks, channel counts, group, bias, attribute lengths or
 *   tensor types are invalid
 */
const validateInputs = (inputs: Tensor[], attributes: ConvTransposeAttributes): void => {
  // Refer to the below link for all input checks
  // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
  if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
    throw new Error('Conv requires 2 or 3 inputs');
  }

  // TODO : Need to add support for multi-dimensional conv
  if (inputs[0].dims.length !== 4 || inputs[1].dims.length !== 4) {
    throw new Error('currently only support 2-dimensional conv');
  }

  // FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
  const dataChannel = inputs[0].dims[1];
  const filterInChannel = inputs[1].dims[0];
  if (dataChannel !== filterInChannel) {
    throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
  }

  // group must be a positive divisor of the input channel count. The shader generator divides
  // W.dims[0] by group and emits the quotient as a GLSL int loop bound, so a non-integral
  // result would produce an invalid shader.
  if (attributes.group <= 0 || dataChannel % attributes.group !== 0) {
    throw new Error('invalid group');
  }

  const featureMaps = inputs[1].dims[1] * attributes.group;

  // if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
  if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[2].dims[0] !== featureMaps)) {
    throw new Error('invalid bias');
  }

  const spatialRank = inputs[0].dims.length - 2;
  // wrong dilations dimension
  if (attributes.dilations.length !== spatialRank) {
    throw new Error(`dilations should be ${spatialRank}D`);
  }

  // Wrong strides dimension
  if (attributes.strides.length !== spatialRank) {
    throw new Error(`strides should be ${spatialRank}D`);
  }

  // Wrong pads dimension (begin and end pad per spatial axis)
  if (attributes.pads.length !== spatialRank * 2) {
    throw new Error(`pads should be ${spatialRank * 2}D`);
  }

  // Wrong output padding dimension
  if (attributes.outputPadding.length !== spatialRank) {
    throw new Error(`output_padding should be ${spatialRank}D`);
  }

  // if kernelShape is specified, its length must be 2 less than the rank of the weights tensor
  // (the first 2 dims of W are input channels and output channels per group)
  if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) {
    throw new Error('invalid kernel shape');
  }

  // as with kernelShape, must have same number of spatial dims as input
  if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) {
    throw new Error('invalid output shape');
  }

  // TODO : Need to add support for float64
  if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') {
    throw new Error('ConvTranspose input(X,W) should be float tensor');
  }

  if (inputs.length === 3 && inputs[2].type !== 'float32') {
    throw new Error('ConvTranspose input(bias) should be float tensor');
  }
};
309

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.