// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { TensorView } from '../../tensor-view';
import { ComputeContext } from '../types';

import { createConv2DTransposeMatMulProgramInfo } from './3rd-party/conv_backprop_mm_webgpu';
import { createConvTranspose2DProgramInfo } from './3rd-party/conv_backprop_webgpu';
import { ConvAttributes } from './conv';
import { parseInternalActivationAttributes } from './fuse-utils';
import { createTransposeProgramInfo } from './transpose';
const computeTotalPad = (
20
) => (inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize;
22
const distributePadding = (totalPad: number, autoPad: string, pads: number[], head: number, tail: number) => {
23
const smallPad = Math.floor(totalPad / 2);
24
if (autoPad === 'SAME_UPPER') {
25
pads[head] = smallPad;
26
pads[tail] = totalPad - smallPad;
27
} else if (autoPad === 'SAME_LOWER') {
28
pads[head] = totalPad - smallPad;
29
pads[tail] = smallPad;
33
const calculateOutputShapeAndPads = (
34
inputShape: readonly number[],
35
kernelShape: readonly number[],
36
dilations: readonly number[],
40
strides: readonly number[],
41
isChannelLast: boolean,
42
outputPadding: number[],
43
outputShape: number[],
45
const spatialRank = inputShape.length - 2;
46
const updateOutputShape = outputShape.length === 0;
47
if (outputPadding.length === 0) {
48
for (let i = 0; i < spatialRank; ++i) {
49
outputPadding.push(0);
52
const batchSize = inputShape[0];
53
const outChannels = kernelShape[isChannelLast ? 3 : 1] * group;
54
for (let i = 0, j = inputShape.length - spatialRank - (isChannelLast ? 1 : 0); i < spatialRank; ++i, ++j) {
55
const inSize = inputShape[j];
56
const outSize = updateOutputShape ? inSize * strides[i] : outputShape[i];
57
const totalPad = computeTotalPad(inSize, strides[i], pads[i], kernelShape[j], dilations[i], outSize);
58
distributePadding(totalPad, autoPad, pads, i, i + spatialRank);
59
if (updateOutputShape) {
61
strides[i] * (inSize - 1) +
63
(kernelShape[j] - 1) * dilations[i] +
66
pads[i + spatialRank],
70
outputShape.splice(0, 0, batchSize);
71
outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels);
74
export interface ConvTransposeAttributes extends ConvAttributes {
75
readonly outputPadding: readonly number[];
76
readonly outputShape: readonly number[];
79
const getAdjustedConvTransposeAttributes = <T extends ConvTransposeAttributes>(
81
inputs: readonly TensorView[],
83
const kernelShape = attributes.kernelShape.slice();
84
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
85
if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) {
86
kernelShape.length = 0;
87
for (let i = 2; i < inputs[1].dims.length; ++i) {
88
kernelShape.push(inputs[1].dims[i]);
91
const isChannelsLast = attributes.format === 'NHWC';
92
kernelShape.splice(0, 0, inputs[1].dims[0]);
93
kernelShape.splice(isChannelsLast ? 3 : 1, 0, inputs[1].dims[1]);
95
const pads = attributes.pads.slice();
96
const outputShape = attributes.outputShape.slice();
97
const outputPadding = attributes.outputPadding.slice();
98
const inputShape = inputs[0].dims;
99
let dilations = attributes.dilations.slice();
100
if (dilations.reduce((a, b) => a + b, 0) === 0) {
101
const spatialRank = inputs[0].dims.length - 2;
102
dilations = new Array(spatialRank).fill(1);
104
let strides = attributes.strides.slice();
105
if (strides.reduce((a, b) => a + b, 0) === 0) {
106
const spatialRank = inputs[0].dims.length - 2;
107
strides = new Array(spatialRank).fill(1);
109
// If outputShape is not specified in the attributes of this op, infer it from the parameters
110
// Similarly, automatically infer pads if not specified
111
calculateOutputShapeAndPads(
124
// always return a new object so does not modify the original attributes
125
const newAttributes: T = Object.assign({}, attributes);
126
Object.assign(newAttributes, { kernelShape, pads, outputPadding, outputShape, dilations, strides });
127
return newAttributes;
130
export const parseConvTransposeAttributes = (attributes: Record<string, unknown>): ConvTransposeAttributes => {
131
const activationAttributes = parseInternalActivationAttributes(attributes);
132
// TODO : Make this generic enough to compute default attributes for multi-dimensional conv
133
const format = attributes.format as 'NHWC' | 'NCHW';
134
const autoPad = ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][
135
typeof attributes.autoPad == 'undefined' ? 0 : (attributes.autoPad as number)
137
const dilations = attributes.dilations as [number, number];
138
const group = attributes.group as number;
139
const kernelShape = attributes.kernelShape as [number, number];
140
const pads = attributes.pads as [number, number, number, number];
141
const strides = attributes.strides as [number, number];
142
const wIsConst = (attributes.wIsConst as () => boolean)();
143
const outputPadding = attributes.outputPadding as [number, number, number, number];
144
const outputShape = attributes.outputShape as [number, number];
156
...activationAttributes,
157
cacheKey: `${attributes.format};${activationAttributes.activation};`,
161
const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => {
162
// Refer to the below link for all input checks
163
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
164
if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
165
throw new Error('Conv requires 2 or 3 inputs');
168
// TODO : Need to add support for multi-dimensional conv
169
if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) {
170
throw new Error('currently only support 2-dimensional conv');
173
if (inputs[0].dims.length !== inputs[1].dims.length) {
174
throw new Error('filter does not have same dimension as input');
177
// FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
178
const dataChannel = inputs[0].dims[attributes.format === 'NHWC' ? inputs[0].dims.length - 1 : 1];
179
const filterInChannel = inputs[1].dims[0];
180
if (dataChannel !== filterInChannel) {
181
throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
184
const featureMaps = inputs[1].dims[1] * attributes.group;
186
// if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
187
if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[2].dims[0] !== featureMaps)) {
188
throw new Error('invalid bias');
191
const spatialRank = inputs[0].dims.length - 2;
192
const dilationsSet = attributes.dilations.reduce((a, b) => a + b, 0) > 0;
193
// wrong dilations dimension
194
if (dilationsSet && attributes.dilations.length !== spatialRank) {
195
throw new Error(`dilations should be ${spatialRank}D`);
198
const stridesSet = attributes.strides.reduce((a, b) => a + b, 0) > 0;
199
// Wrong strides dimension
200
if (stridesSet && attributes.strides.length !== spatialRank) {
201
throw new Error(`strides should be ${spatialRank}D`);
204
// Wrong pads dimension
205
const padsSet = attributes.pads.reduce((a, b) => a + b, 0) > 0;
206
if (padsSet && attributes.pads.length !== spatialRank * 2) {
207
throw new Error(`pads should be ${spatialRank * 2}D`);
210
// Wrong output padding dimension
211
if (attributes.outputPadding.length !== spatialRank && attributes.outputPadding.length !== 0) {
212
throw new Error(`output_padding should be ${spatialRank}D`);
215
// if kernelShape is specified, it's data length must be 2 less than dims length of the weights tensor
216
// (the first 2 dims are batch_size and channels)
217
const kernelShapeSet = attributes.kernelShape.reduce((a, b) => a + b, 0) > 0;
220
attributes.kernelShape.length !== 0 &&
221
attributes.kernelShape.length !== inputs[1].dims.length - 2
223
throw new Error('invalid kernel shape');
226
// as with kernelShape, must have same number of spatial dims as input
227
if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) {
228
throw new Error('invalid output shape');
232
// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C]
const weightTransposePerm = [2, 3, 1, 0];
const convTranspose2d = (
236
context: ComputeContext,
237
inputs: readonly TensorView[],
238
attributes: ConvTransposeAttributes,
240
const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs);
241
const isChannelsLast = attributes.format === 'NHWC';
242
const outputShape = adjustedAttributes.outputShape;
243
const outChannels = outputShape[isChannelsLast ? 3 : 1];
244
const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
245
// Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's
246
// not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit
247
// utilization rate is very low.
248
if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) {
249
context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes));
252
const outHeight = outputShape[isChannelsLast ? 1 : 2];
253
const outWidth = outputShape[isChannelsLast ? 2 : 3];
254
const weightHeight = inputs[1].dims[2];
255
const weightWidth = inputs[1].dims[3];
257
const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
258
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
259
const dimInner = weightHeight * weightWidth * inputChannels;
261
const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true;
263
// STEP.1: transpose weight
264
const transposedWeight =
265
(context.kernelCustomData.wT as TensorView | undefined) ??
266
context.compute(createTransposeProgramInfo(inputs[1], weightTransposePerm), {
268
outputs: [attributes.wIsConst ? -2 : -1],
270
if (attributes.wIsConst && !context.kernelCustomData.wT) {
271
context.kernelCustomData.wT = transposedWeight;
274
// STEP.2: prepare reshaped inputs
275
const convTransposeInputs = [inputs[0], transposedWeight];
276
const hasBias = inputs.length === 3;
278
if (!isChannelsLast && inputs[2].dims.length === 1) {
279
convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1]));
281
convTransposeInputs.push(inputs[2]);
285
// STEP.3: compute matmul
287
createConv2DTransposeMatMulProgramInfo(
295
sequentialAccessByThreads,
297
{ inputs: convTransposeInputs },
301
const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => {
302
// extend the input to 2D by adding H dimension
303
const isChannelLast = attributes.format === 'NHWC';
306
context.inputs[0].reshape(
308
? // [N, W, C] -> [N, H=1, W, C]
309
[context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]]
310
: // [N, C, W] -> [N, C, H=1, W]
311
[context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]],
313
//[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW]
314
context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]),
316
if (context.inputs.length === 3) {
317
inputs.push(context.inputs[2]);
319
let kernelShape = attributes.kernelShape;
320
if (kernelShape.length === 0 || kernelShape[0] === 0) {
321
kernelShape = [context.inputs[1].dims[2]];
323
let dilations = attributes.dilations;
324
if (dilations.length === 0 || dilations[0] === 0) {
327
let strides = attributes.strides;
328
if (strides.length === 0 || strides[0] === 0) {
331
let pads = attributes.pads;
332
if (pads.length === 0) {
335
pads = [0, pads[0], 0, pads[1]];
336
strides = [1].concat(strides);
337
dilations = [1].concat(dilations);
338
kernelShape = [1].concat(kernelShape);
339
const adjustedAttributes = getAdjustedConvTransposeAttributes(
340
{ ...attributes, pads, strides, dilations, kernelShape },
344
createConvTranspose2DProgramInfo(inputs, adjustedAttributes, (outputShape) =>
346
? [outputShape[0], outputShape[2], outputShape[3]]
347
: [outputShape[0], outputShape[1], outputShape[3]],
352
export const convTranspose = (context: ComputeContext, attributes: ConvTransposeAttributes): void => {
353
validateInputs(context.inputs, attributes);
354
if (context.inputs[0].dims.length === 3) {
355
convTranspose1d(context, attributes);
357
convTranspose2d(context, context.inputs, attributes);