onnxruntime

Форк
0
/
conv-transpose.ts 
359 строк · 13.2 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { TensorView } from '../../tensor-view';
5
import { ComputeContext } from '../types';
6

7
import { createConv2DTransposeMatMulProgramInfo } from './3rd-party/conv_backprop_mm_webgpu';
8
import { createConvTranspose2DProgramInfo } from './3rd-party/conv_backprop_webgpu';
9
import { ConvAttributes } from './conv';
10
import { parseInternalActivationAttributes } from './fuse-utils';
11
import { createTransposeProgramInfo } from './transpose';
12

13
const computeTotalPad = (
14
  inDim: number,
15
  stride: number,
16
  adj: number,
17
  kernel: number,
18
  dilation: number,
19
  outSize: number,
20
) => (inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize;
21

22
// Splits `totalPad` between the `head` and `tail` entries of `pads` in place,
// following the ONNX auto_pad convention: SAME_UPPER puts the extra element
// (for odd totals) at the tail, SAME_LOWER at the head. Any other mode
// ('NOTSET'/'VALID') leaves `pads` untouched.
const distributePadding = (totalPad: number, autoPad: string, pads: number[], head: number, tail: number) => {
  const lower = Math.floor(totalPad / 2);
  const upper = totalPad - lower;
  if (autoPad === 'SAME_UPPER') {
    pads[head] = lower;
    pads[tail] = upper;
  } else if (autoPad === 'SAME_LOWER') {
    pads[head] = upper;
    pads[tail] = lower;
  }
};
32

33
// Infers the ConvTranspose output shape and auto-pads in place:
// - when `outputShape` is empty, each spatial output size is computed from the
//   input size, stride, dilation, kernel size, output padding and pads;
// - when `autoPad` is SAME_UPPER/SAME_LOWER, `pads` is filled via
//   `distributePadding` from the total pad computed by `computeTotalPad`;
// - `outputPadding` is defaulted to zeros when empty.
// Finally, the batch dim is spliced in at position 0 and the channel dim at
// position 3 (NHWC) or 1 (NCHW), so `outputShape` ends up as a full rank-4
// shape. All three array parameters are mutated; nothing is returned.
const calculateOutputShapeAndPads = (
  inputShape: readonly number[],
  kernelShape: readonly number[],
  dilations: readonly number[],
  autoPad: string,
  group: number,
  pads: number[],
  strides: readonly number[],
  isChannelLast: boolean,
  outputPadding: number[],
  outputShape: number[],
) => {
  // Number of spatial dims (2 for 2D conv-transpose).
  const spatialRank = inputShape.length - 2;
  // Empty outputShape means "infer it"; otherwise it holds the spatial sizes.
  const updateOutputShape = outputShape.length === 0;
  if (outputPadding.length === 0) {
    for (let i = 0; i < spatialRank; ++i) {
      outputPadding.push(0);
    }
  }
  const batchSize = inputShape[0];
  // kernelShape here already includes the channel dims (see
  // getAdjustedConvTransposeAttributes), so index 3/1 is the per-group output
  // channel count.
  const outChannels = kernelShape[isChannelLast ? 3 : 1] * group;
  // `i` indexes spatial attribute arrays; `j` indexes the corresponding
  // spatial dims of inputShape/kernelShape (offset depends on layout).
  for (let i = 0, j = inputShape.length - spatialRank - (isChannelLast ? 1 : 0); i < spatialRank; ++i, ++j) {
    const inSize = inputShape[j];
    const outSize = updateOutputShape ? inSize * strides[i] : outputShape[i];
    // Note: pads[i] is read here, then possibly overwritten by
    // distributePadding below before the outputShape computation reads it.
    const totalPad = computeTotalPad(inSize, strides[i], pads[i], kernelShape[j], dilations[i], outSize);
    distributePadding(totalPad, autoPad, pads, i, i + spatialRank);
    if (updateOutputShape) {
      outputShape.push(
        strides[i] * (inSize - 1) +
          outputPadding[i] +
          (kernelShape[j] - 1) * dilations[i] +
          1 -
          pads[i] -
          pads[i + spatialRank],
      );
    }
  }
  outputShape.splice(0, 0, batchSize);
  outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels);
};
73

74
/** Attributes of the ConvTranspose operator; extends the shared Conv attributes. */
export interface ConvTransposeAttributes extends ConvAttributes {
  /** ONNX `output_padding`: extra size added to each spatial output dim; empty means all zeros. */
  readonly outputPadding: readonly number[];
  /** ONNX `output_shape`: explicit spatial output sizes; empty means "infer from the other attributes". */
  readonly outputShape: readonly number[];
}
78

79
// Resolves all defaulted/inferable ConvTranspose attributes for a concrete set
// of inputs: kernelShape is taken from the weight tensor when unspecified,
// all-zero dilations/strides default to 1s, and pads/outputShape are inferred
// by calculateOutputShapeAndPads. Returns a new attributes object; the input
// `attributes` is never modified.
const getAdjustedConvTransposeAttributes = <T extends ConvTransposeAttributes>(
  attributes: T,
  inputs: readonly TensorView[],
): T => {
  const kernelShape = attributes.kernelShape.slice();
  // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
  if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) {
    kernelShape.length = 0;
    for (let i = 2; i < inputs[1].dims.length; ++i) {
      kernelShape.push(inputs[1].dims[i]);
    }
  }
  const isChannelsLast = attributes.format === 'NHWC';
  // Extend the spatial kernel shape with the weight's channel dims so that it
  // lines up index-for-index with the (NCHW or NHWC) input shape.
  kernelShape.splice(0, 0, inputs[1].dims[0]);
  kernelShape.splice(isChannelsLast ? 3 : 1, 0, inputs[1].dims[1]);

  // Work on copies: calculateOutputShapeAndPads mutates these in place.
  const pads = attributes.pads.slice();
  const outputShape = attributes.outputShape.slice();
  const outputPadding = attributes.outputPadding.slice();
  const inputShape = inputs[0].dims;
  // An all-zero dilations array means "unspecified": default to 1 per spatial dim.
  let dilations = attributes.dilations.slice();
  if (dilations.reduce((a, b) => a + b, 0) === 0) {
    const spatialRank = inputs[0].dims.length - 2;
    dilations = new Array(spatialRank).fill(1);
  }
  // Same defaulting rule for strides.
  let strides = attributes.strides.slice();
  if (strides.reduce((a, b) => a + b, 0) === 0) {
    const spatialRank = inputs[0].dims.length - 2;
    strides = new Array(spatialRank).fill(1);
  }
  // If outputShape is not specified in the attributes of this op, infer it from the parameters
  // Similarly, automatically infer pads if not specified
  calculateOutputShapeAndPads(
    inputShape,
    kernelShape,
    dilations,
    attributes.autoPad,
    attributes.group,
    pads,
    strides,
    isChannelsLast,
    outputPadding,
    outputShape,
  );

  // always return a new object so does not modify the original attributes
  const newAttributes: T = Object.assign({}, attributes);
  Object.assign(newAttributes, { kernelShape, pads, outputPadding, outputShape, dilations, strides });
  return newAttributes;
};
129

130
export const parseConvTransposeAttributes = (attributes: Record<string, unknown>): ConvTransposeAttributes => {
131
  const activationAttributes = parseInternalActivationAttributes(attributes);
132
  // TODO : Make this generic enough to compute default attributes for multi-dimensional conv
133
  const format = attributes.format as 'NHWC' | 'NCHW';
134
  const autoPad = ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][
135
    typeof attributes.autoPad == 'undefined' ? 0 : (attributes.autoPad as number)
136
  ];
137
  const dilations = attributes.dilations as [number, number];
138
  const group = attributes.group as number;
139
  const kernelShape = attributes.kernelShape as [number, number];
140
  const pads = attributes.pads as [number, number, number, number];
141
  const strides = attributes.strides as [number, number];
142
  const wIsConst = (attributes.wIsConst as () => boolean)();
143
  const outputPadding = attributes.outputPadding as [number, number, number, number];
144
  const outputShape = attributes.outputShape as [number, number];
145
  return {
146
    autoPad,
147
    format,
148
    dilations,
149
    group,
150
    kernelShape,
151
    outputPadding,
152
    outputShape,
153
    pads,
154
    strides,
155
    wIsConst,
156
    ...activationAttributes,
157
    cacheKey: `${attributes.format};${activationAttributes.activation};`,
158
  };
159
};
160

161
// Throws an Error when the inputs/attributes violate the ONNX ConvTranspose
// contract; returns nothing on success. Checks are performed in a fixed order,
// so the first violated constraint determines the error message.
const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => {
  // Refer to the below link for all input checks
  // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
  if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
    // NOTE(review): message says 'Conv' although this is ConvTranspose; kept
    // as-is because the string is observable runtime behavior.
    throw new Error('Conv requires 2 or 3 inputs');
  }

  // TODO : Need to add support for multi-dimensional conv
  if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) {
    throw new Error('currently only support 2-dimensional conv');
  }

  if (inputs[0].dims.length !== inputs[1].dims.length) {
    throw new Error('filter does not have same dimension as input');
  }

  // FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
  const dataChannel = inputs[0].dims[attributes.format === 'NHWC' ? inputs[0].dims.length - 1 : 1];
  const filterInChannel = inputs[1].dims[0];
  if (dataChannel !== filterInChannel) {
    throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
  }

  const featureMaps = inputs[1].dims[1] * attributes.group;

  // if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
  if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[2].dims[0] !== featureMaps)) {
    throw new Error('invalid bias');
  }

  const spatialRank = inputs[0].dims.length - 2;
  // An all-zero attribute array means "not specified"; only specified arrays
  // are checked for the right rank (defaults are applied later).
  const dilationsSet = attributes.dilations.reduce((a, b) => a + b, 0) > 0;
  // wrong dilations dimension
  if (dilationsSet && attributes.dilations.length !== spatialRank) {
    throw new Error(`dilations should be ${spatialRank}D`);
  }

  const stridesSet = attributes.strides.reduce((a, b) => a + b, 0) > 0;
  // Wrong strides dimension
  if (stridesSet && attributes.strides.length !== spatialRank) {
    throw new Error(`strides should be ${spatialRank}D`);
  }

  // Wrong pads dimension
  const padsSet = attributes.pads.reduce((a, b) => a + b, 0) > 0;
  if (padsSet && attributes.pads.length !== spatialRank * 2) {
    throw new Error(`pads should be ${spatialRank * 2}D`);
  }

  // Wrong output padding dimension
  if (attributes.outputPadding.length !== spatialRank && attributes.outputPadding.length !== 0) {
    throw new Error(`output_padding should be ${spatialRank}D`);
  }

  // if kernelShape is specified, it's data length must be 2 less than dims length of the weights tensor
  // (the first 2 dims are batch_size and channels)
  const kernelShapeSet = attributes.kernelShape.reduce((a, b) => a + b, 0) > 0;
  if (
    kernelShapeSet &&
    attributes.kernelShape.length !== 0 &&
    attributes.kernelShape.length !== inputs[1].dims.length - 2
  ) {
    throw new Error('invalid kernel shape');
  }

  // as with kernelShape, must have same number of spatial dims as input
  if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) {
    throw new Error('invalid output shape');
  }
};
231

232
// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C]
// (the layout consumed by the matmul-based conv-transpose program)
const weightTransposePerm = [2, 3, 1, 0];
234

235
// 2D ConvTranspose. Resolves attribute defaults, then dispatches to either the
// naive per-output-element program (grouped conv, or degenerate 1x1 channel
// case) or the matmul-based program with a transposed-weight cache.
const convTranspose2d = (
  context: ComputeContext,
  inputs: readonly TensorView[],
  attributes: ConvTransposeAttributes,
): void => {
  const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs);
  const isChannelsLast = attributes.format === 'NHWC';
  const outputShape = adjustedAttributes.outputShape;
  const outChannels = outputShape[isChannelsLast ? 3 : 1];
  const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
  // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's
  // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit
  // utilization rate is very low.
  if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) {
    context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes));
    return;
  }
  const outHeight = outputShape[isChannelsLast ? 1 : 2];
  const outWidth = outputShape[isChannelsLast ? 2 : 3];
  const weightHeight = inputs[1].dims[2];
  const weightWidth = inputs[1].dims[3];

  // Matmul dimensions: outer dims are spatial-size and channel count (order
  // depends on layout), inner dim is one full kernel window over all channels.
  const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
  const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
  const dimInner = weightHeight * weightWidth * inputChannels;

  // Hard-coded to true; originally gated on an Intel-adapter check.
  const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true;

  // STEP.1: transpose weight
  // Reuse the cached transposed weight when available; otherwise run a
  // transpose program. NOTE(review): output slot -2 vs -1 appears to control
  // whether the backend keeps the result for reuse (constant weight) — confirm
  // against the ComputeContext docs.
  const transposedWeight =
    (context.kernelCustomData.wT as TensorView | undefined) ??
    context.compute(createTransposeProgramInfo(inputs[1], weightTransposePerm), {
      inputs: [1],
      outputs: [attributes.wIsConst ? -2 : -1],
    })[0];
  if (attributes.wIsConst && !context.kernelCustomData.wT) {
    context.kernelCustomData.wT = transposedWeight;
  }

  // STEP.2: prepare reshaped inputs
  const convTransposeInputs = [inputs[0], transposedWeight];
  const hasBias = inputs.length === 3;
  if (hasBias) {
    if (!isChannelsLast && inputs[2].dims.length === 1) {
      // NCHW: reshape the 1D bias to [C, 1, 1] so it broadcasts over the spatial dims.
      convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1]));
    } else {
      convTransposeInputs.push(inputs[2]);
    }
  }

  // STEP.3: compute matmul
  context.compute(
    createConv2DTransposeMatMulProgramInfo(
      convTransposeInputs,
      adjustedAttributes,
      outputShape,
      dimAOuter,
      dimBOuter,
      dimInner,
      hasBias,
      sequentialAccessByThreads,
    ),
    { inputs: convTransposeInputs },
  );
};
300

301
// 1D ConvTranspose, implemented by lifting the rank-3 tensors to rank-4
// (inserting a height dimension of size 1), running the 2D program, and
// squeezing the extra dimension back out of the output shape.
const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => {
  // extend the input to 2D by adding H dimension
  const isChannelLast = attributes.format === 'NHWC';

  const inputs = [
    context.inputs[0].reshape(
      isChannelLast
        ? // [N, W, C] -> [N, H=1, W, C]
          [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]]
        : // [N, C, W] -> [N, C, H=1, W]
          [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]],
    ),
    //[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW]
    context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]),
  ];
  if (context.inputs.length === 3) {
    inputs.push(context.inputs[2]);
  }
  // Resolve 1D defaults (unspecified arrays / zero entries) before extending
  // each attribute with the dummy H dimension.
  let kernelShape = attributes.kernelShape;
  if (kernelShape.length === 0 || kernelShape[0] === 0) {
    kernelShape = [context.inputs[1].dims[2]];
  }
  let dilations = attributes.dilations;
  if (dilations.length === 0 || dilations[0] === 0) {
    dilations = [1];
  }
  let strides = attributes.strides;
  if (strides.length === 0 || strides[0] === 0) {
    strides = [1];
  }
  let pads = attributes.pads;
  if (pads.length === 0) {
    pads = [0, 0];
  }
  // Prepend the H dimension: kernel/stride/dilation get 1, pads get 0 for the
  // H head/tail (pads layout is [headH, headW, tailH, tailW]).
  pads = [0, pads[0], 0, pads[1]];
  strides = [1].concat(strides);
  dilations = [1].concat(dilations);
  kernelShape = [1].concat(kernelShape);
  const adjustedAttributes = getAdjustedConvTransposeAttributes(
    { ...attributes, pads, strides, dilations, kernelShape },
    inputs,
  );
  // Squeeze the H=1 dimension out of the computed 4D output shape.
  context.compute(
    createConvTranspose2DProgramInfo(inputs, adjustedAttributes, (outputShape) =>
      isChannelLast
        ? [outputShape[0], outputShape[2], outputShape[3]]
        : [outputShape[0], outputShape[1], outputShape[3]],
    ),
  );
};
351

352
/**
 * Entry point for the ConvTranspose operator: validates the inputs, then
 * dispatches to the 1D path (rank-3 data tensor) or the 2D path (rank-4).
 */
export const convTranspose = (context: ComputeContext, attributes: ConvTransposeAttributes): void => {
  validateInputs(context.inputs, attributes);
  // A rank-3 data tensor means 1D ConvTranspose.
  if (context.inputs[0].dims.length === 3) {
    convTranspose1d(context, attributes);
    return;
  }
  convTranspose2d(context, context.inputs, attributes);
};
360

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.