1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { TensorView } from '../../tensor-view';
5
import { PoolConvUtil } from '../../util';
6
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
7
import { ComputeContext } from '../types';
9
import { createConv2DMatMulProgramInfo } from './3rd-party/conv2d_mm_webgpu';
10
import { computeConv3DInfo, createConv3DNaiveProgramInfo } from './3rd-party/conv3d_naive_webgpu';
11
import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu';
12
import { createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo } from './conv-grouped';
13
import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils';
14
import { createNaiveMatmulProgramInfo } from './matmul';
15
import { createTransposeProgramInfo } from './transpose';
17
export const calculateOutputShape = (
18
inputShape: readonly number[],
19
kernelShape: readonly number[],
20
dilations: readonly number[],
21
adjustPads: readonly number[],
22
strides: readonly number[],
23
isChannelLast: boolean,
25
const batchSize = inputShape[0];
26
const inputSpatialShape = inputShape.slice(isChannelLast ? 1 : 2, isChannelLast ? 3 : 4);
27
const spatialRank = inputSpatialShape.length;
28
const outChannels = kernelShape[0];
29
const kernelSpatialShape = kernelShape.slice(2);
30
const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1));
31
const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]);
32
const outputShape = inputSpatialShapeWithPad.map((v, i) =>
33
Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]),
35
outputShape.splice(0, 0, batchSize);
36
outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels);
40
export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey {
41
readonly autoPad: string;
42
readonly dilations: readonly number[];
43
readonly format: 'NHWC' | 'NCHW';
44
readonly group: number;
45
readonly kernelShape: readonly number[];
46
readonly pads: readonly number[];
47
readonly strides: readonly number[];
48
readonly wIsConst: boolean;
51
// for transposing weight tensor from [M, C/group, KH, KW] to [KH, KW, C/group, M]
const weightTransposeAttribute = [2, 3, 1, 0];
54
const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttributes): void => {
55
// Refer to the below link for all input checks
56
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv
57
if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
58
throw new Error('Conv requires 2 or 3 inputs');
61
if (inputs[0].dims.length > 5) {
62
throw new Error('greater than 5D is not supported');
65
if (inputs[0].dims.length !== inputs[1].dims.length) {
66
throw new Error('filter does not have same dimension as input');
69
// FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
70
const dataChannel = inputs[0].dims[attributes.format === 'NHWC' ? inputs[0].dims.length - 1 : 1];
71
const filterInChannel = inputs[1].dims[1] * attributes.group;
72
if (dataChannel !== filterInChannel) {
73
throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
76
// if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
77
if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[1].dims[0] !== inputs[2].dims[0])) {
78
throw new Error('invalid bias');
81
const spatialRank = inputs[0].dims.length - 2;
82
// wrong dilations dimension
83
if (attributes.dilations.length !== spatialRank) {
84
throw new Error(`dilations should be ${spatialRank}D`);
87
// Wrong strides dimension
88
if (attributes.strides.length !== spatialRank) {
89
throw new Error(`strides should be ${spatialRank}D`);
92
// Wrong pads dimension
93
if (attributes.pads.length !== spatialRank * 2) {
94
throw new Error(`pads should be ${spatialRank * 2}D`);
97
// if kernelShape is specified, it's data length must be 2 less than dims length of the weights tensor
98
// (the first 2 dims are batch_size and channels)
99
if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) {
100
throw new Error('invalid kernel shape');
104
const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
105
const kernelShape = attributes.kernelShape.slice();
106
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
107
for (let i = 2; i < inputs[1].dims.length; ++i) {
108
if (kernelShape[i - 2] === 0) {
109
kernelShape[i - 2] = inputs[1].dims[i];
112
const pads = attributes.pads.slice();
113
PoolConvUtil.adjustPadsBasedOnAutoPad(
116
attributes.dilations,
119
attributes.format === 'NHWC',
123
// always return a new object so does not modify the original attributes
124
const newAttributes: T = Object.assign({}, attributes);
125
Object.assign(newAttributes, { kernelShape, pads });
126
return newAttributes;
129
export const parseConvAttributes = (attributes: Record<string, unknown>): ConvAttributes => {
130
const activationAttributes = parseInternalActivationAttributes(attributes);
131
// TODO : Make this generic enough to compute default attributes for multi-dimensional conv
132
const format = attributes.format as 'NHWC' | 'NCHW';
133
const autoPad = ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number];
134
const dilations = attributes.dilations as number[];
135
const group = attributes.group as number;
136
const kernelShape = attributes.kernel_shape as number[];
137
const pads = attributes.pads as number[];
138
const strides = attributes.strides as number[];
139
const wIsConst = (attributes.w_is_const as () => boolean)();
150
...activationAttributes,
151
cacheKey: `${attributes.format};${activationAttributes.activation};`,
156
context: ComputeContext,
157
inputs: readonly TensorView[],
158
attributes: ConvAttributes,
159
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
163
// const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
164
const isChannelsLast = attributes.format === 'NHWC';
165
if (attributes.group !== 1) {
166
// NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other
167
// GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs.
168
// [webgpu]Conv - conv - vectorize group - B
169
// [webgpu]Conv - conv - vectorize group - D
170
const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere');
172
enableGroupedConvVectorize &&
174
inputs[1].dims[0] === attributes.group &&
175
inputs[1].dims[1] === 1 &&
176
attributes.dilations[0] === 1 &&
177
attributes.dilations[1] === 1
179
const outputShape = calculateOutputShape(
182
attributes.dilations,
187
const transposedWeight =
188
(context.kernelCustomData.wT as TensorView | undefined) ??
189
context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
191
outputs: [attributes.wIsConst ? -2 : -1],
193
if (attributes.wIsConst && !context.kernelCustomData.wT) {
194
context.kernelCustomData.wT = transposedWeight;
196
const convInputs = [inputs[0], transposedWeight];
197
if (inputs.length === 3) {
198
convInputs.push(inputs[2]);
201
createGroupedConvVectorizeProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction),
202
{ inputs: convInputs },
205
context.compute(createGroupedConvProgramInfo(inputs, attributes, squeezeOutputShapeFunction));
210
const hasBias = inputs.length === 3;
211
const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2];
212
const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3];
213
const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
214
const weightHeight = inputs[1].dims[2];
215
const weightWidth = inputs[1].dims[3];
217
const outputShape = calculateOutputShape(
220
attributes.dilations,
225
const outHeight = outputShape[isChannelsLast ? 1 : 2];
226
const outWidth = outputShape[isChannelsLast ? 2 : 3];
227
const outChannels = outputShape[isChannelsLast ? 3 : 1];
231
weightHeight === inputHeight &&
232
weightWidth === inputWidth &&
233
attributes.pads[0] === 0 &&
234
attributes.pads[1] === 0;
237
(weightHeight === 1 &&
239
attributes.dilations[0] === 1 &&
240
attributes.dilations[1] === 1 &&
241
attributes.strides[0] === 1 &&
242
attributes.strides[1] === 1 &&
243
attributes.pads[0] === 0 &&
244
attributes.pads[1] === 0)
247
const batch = outputShape[0];
248
let xReshaped, wReshaped, matmulOutputShape;
249
const matmulInputs = [];
250
if (isChannelsLast) {
251
const transposedWeight =
252
(context.kernelCustomData.wT as TensorView | undefined) ??
253
context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
255
outputs: [attributes.wIsConst ? -2 : -1],
257
if (attributes.wIsConst && !context.kernelCustomData.wT) {
258
context.kernelCustomData.wT = transposedWeight;
261
const sharedDim = inputHeight * inputWidth * inputChannels;
262
xReshaped = inputs[0].reshape([1, batch, sharedDim]);
263
wReshaped = transposedWeight.reshape([1, sharedDim, outChannels]);
264
matmulOutputShape = [1, batch, outChannels];
266
xReshaped = inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]);
267
wReshaped = transposedWeight.reshape([1, inputChannels, outChannels]);
268
matmulOutputShape = [batch, outHeight * outWidth, outChannels];
270
matmulInputs.push(xReshaped);
271
matmulInputs.push(wReshaped);
273
xReshaped = inputs[0].reshape([batch, inputChannels, inputHeight * inputWidth]);
274
wReshaped = inputs[1].reshape([1, outChannels, inputChannels]);
275
matmulOutputShape = [batch, outChannels, outHeight * outWidth];
276
matmulInputs.push(wReshaped);
277
matmulInputs.push(xReshaped);
280
matmulInputs.push(inputs[2]);
282
const N = matmulOutputShape[2];
283
const K = matmulInputs[0].dims[matmulInputs[0].dims.length - 1];
284
// Tune the threshold.
285
if (N < 8 && K < 8) {
287
createNaiveMatmulProgramInfo(
293
squeezeOutputShapeFunction,
295
{ inputs: matmulInputs },
299
createMatmulProgramInfo(
305
squeezeOutputShapeFunction,
307
{ inputs: matmulInputs },
313
// TODO: implement conv2dWithIm2Col()
315
const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true;
317
// STEP.1: transpose weight
318
const transposedWeight =
319
(context.kernelCustomData.wT as TensorView | undefined) ??
320
context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
322
outputs: [attributes.wIsConst ? -2 : -1],
324
if (attributes.wIsConst && !context.kernelCustomData.wT) {
325
context.kernelCustomData.wT = transposedWeight;
328
// STEP.2: prepare reshaped inputs
329
const convInputs = [inputs[0], transposedWeight];
331
convInputs.push(inputs[2]);
334
// STEP.3: compute matmul
335
const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
336
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
337
const dimInner = weightHeight * weightWidth * inputChannels;
339
createConv2DMatMulProgramInfo(
347
sequentialAccessByThreads,
348
squeezeOutputShapeFunction,
350
{ inputs: convInputs },
354
const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => {
355
// extend the input to 2D by adding H dimension
356
const isChannelLast = attributes.format === 'NHWC';
358
context.inputs[0].reshape(
360
? // [N, W, C] -> [N, H=1, W, C]
361
[context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]]
362
: // [N, C, W] -> [N, C, H=1, W]
363
[context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]],
365
//[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW]
366
context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]),
368
if (context.inputs.length === 3) {
369
inputs.push(context.inputs[2]);
371
const pads = [0, attributes.pads[0], 0, attributes.pads[1]];
372
const strides = [1].concat(attributes.strides);
373
const dilations = [1].concat(attributes.dilations);
374
const kernelShape = [1].concat(attributes.kernelShape);
375
const adjustedAttributes = getAdjustedConvAttributes(
376
{ ...attributes, pads, strides, dilations, kernelShape },
379
conv2d(context, inputs, adjustedAttributes, (outputShape) =>
380
isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]],
384
const conv3d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => {
385
const format = attributes.format === 'NHWC' ? 'channelsLast' : 'channelsFirst';
386
const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);
387
const pads = attributes.autoPad === 'NOTSET' ? attributes.pads : attributes.autoPad;
388
const convInfo = computeConv3DInfo(
389
inputs[0].dims as [number, number, number, number, number],
390
inputs[1].dims as [number, number, number, number, number],
391
attributes.strides as number | [number, number, number],
392
attributes.dilations as number | [number, number, number],
393
pads as string | number[],
398
createConv3DNaiveProgramInfo(
402
[convInfo.filterDepth, convInfo.filterHeight, convInfo.filterWidth],
403
[convInfo.padInfo.front, convInfo.padInfo.top, convInfo.padInfo.left],
409
export const conv = (context: ComputeContext, attributes: ConvAttributes): void => {
410
validateInputs(context.inputs, attributes);
411
if (context.inputs[0].dims.length === 3) {
412
conv1d(context, attributes);
413
} else if (context.inputs[0].dims.length === 5) {
414
conv3d(context, context.inputs, attributes);
416
const adjustedAttributes = getAdjustedConvAttributes(attributes, context.inputs);
417
conv2d(context, context.inputs, adjustedAttributes);