1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { InferenceHandler } from '../../../backend';
6
import { Graph } from '../../../graph';
7
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
8
import { Tensor } from '../../../tensor';
9
import { PoolConvUtil } from '../../../util';
10
import { WebGLInferenceHandler } from '../inference-handler';
12
import { createUnpackedGroupedConvProgramInfoLoader } from './conv-grouped';
13
import { conv2DPacked } from './conv-pack';
14
import { createDotProductProgramInfoLoader } from './dot-product';
15
import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils';
16
import { createIm2ColProgramInfoLoader } from './im2col';
17
import { createMatmulProgramInfoLoader } from './matmul';
19
/**
 * Computes the output shape of an N-D convolution.
 *
 * For every spatial axis `i` (with `r` = number of spatial axes):
 *   dilatedKernel[i] = k[i] + (k[i] - 1) * (dilations[i] - 1)
 *   out[i]           = floor((in[i] + pads[i] + pads[i + r] - dilatedKernel[i] + strides[i]) / strides[i])
 *
 * @param inputShape  shape of X: `[batch, channels, ...spatial]`.
 * @param kernelShape shape of W: index 0 is the number of output channels, indices 2.. are the kernel's spatial dims.
 * @param dilations   one dilation per spatial axis.
 * @param adjustPads  `2 * r` pads; entries `i` and `i + r` are both added to spatial axis `i`.
 * @param strides     one stride per spatial axis.
 * @returns `[batch, outChannels, ...outputSpatial]`.
 */
export const calculateOutputShape = (
  inputShape: readonly number[],
  kernelShape: readonly number[],
  dilations: readonly number[],
  adjustPads: readonly number[],
  strides: readonly number[],
): number[] => {
  const batchSize = inputShape[0];
  const inputSpatialShape = inputShape.slice(2);
  const spatialRank = inputSpatialShape.length;
  const outChannels = kernelShape[0];
  const kernelSpatialShape = kernelShape.slice(2);
  // Effective kernel extent once dilation gaps are inserted between taps.
  const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1));
  const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]);
  // Equivalent to floor((paddedIn - dilatedKernel) / stride) + 1, written with a single division.
  const outputSpatialShape = inputSpatialShapeWithPad.map((v, i) =>
    Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]),
  );
  return [batchSize, outChannels, ...outputSpatialShape];
};
/**
 * Attributes of the ONNX `Conv` operator, combined with the internal activation attributes
 * (see `parseInternalActivationAttributes`) and a cache key.
 */
export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey {
  /** ONNX `auto_pad` mode; parsed with a default of `'NOTSET'`. */
  readonly autoPad: string;
  /** One dilation per spatial axis. */
  readonly dilations: readonly number[];
  /** Number of groups; X's channel count must equal `W.dims[1] * group`. */
  readonly group: number;
  /** Kernel spatial dims; empty means "infer from the weight tensor's dims[2..]". */
  readonly kernelShape: readonly number[];
  /** `2 * spatialRank` pads; entries `i` and `i + spatialRank` are both applied to spatial axis `i`. */
  readonly pads: readonly number[];
  /** One stride per spatial axis. */
  readonly strides: readonly number[];
}
/**
 * Entry point of the `Conv` operator for the WebGL backend.
 *
 * Validates the inputs and attributes first (throws `Error` on invalid inputs), then delegates to the
 * 2-D implementation.
 *
 * @param inferenceHandler handler used to run the WebGL programs.
 * @param inputs `[X, W]` or `[X, W, bias]`.
 * @param attributes parsed `Conv` attributes.
 * @returns a single-element array holding the convolution output.
 */
export const conv: OperatorImplementation<ConvAttributes> = (
  inferenceHandler: InferenceHandler,
  inputs: Tensor[],
  attributes: ConvAttributes,
): Tensor[] => {
  validateInputs(inputs, attributes); // currently will fail if not conv2D
  return conv2d(inferenceHandler, inputs, attributes);
};
/**
 * Dispatches a 2-D convolution to one of the WebGL kernels.
 *
 * Selection, in order:
 *  1. `group > 1` -> unpacked grouped-convolution program.
 *  2. pack mode, 1x1 kernel, batch 1, all strides 1 and all pads 0 -> single matmul over the flattened input.
 *  3. pack mode, 4-D input, batch 1, kernel not 1x1 -> packed convolution.
 *  4. anything else (including strided/padded or batched 1x1 kernels) -> unpacked im2col + dot-product convolution.
 *
 * The matmul shortcut (2) reshapes X to `[C, H * W]`, dropping the batch dimension, and it never applies strides
 * or pads (they only feed the output-shape computation). It is therefore only taken when those are trivial;
 * otherwise the result would be silently wrong.
 *
 * @returns a single-element array holding the convolution output.
 */
const conv2d: OperatorImplementation<ConvAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ConvAttributes,
): Tensor[] => {
  // Works on a copy with kernelShape inferred and pads resolved for autoPad.
  const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);
  const packMode = inferenceHandler.session.pack;
  const isPointwise = adjustedAttributes.kernelShape[0] === 1 && adjustedAttributes.kernelShape[1] === 1;
  const isSingleBatch = inputs[0].dims[0] === 1;
  const canUsePointwiseMatmul =
    isPointwise &&
    isSingleBatch &&
    adjustedAttributes.strides.every((stride) => stride === 1) &&
    adjustedAttributes.pads.every((pad) => pad === 0);

  if (adjustedAttributes.group > 1) {
    const result = inferenceHandler.run(
      createUnpackedGroupedConvProgramInfoLoader(inferenceHandler, inputs, adjustedAttributes),
      inputs,
    );
    return [result];
  } else if (packMode && canUsePointwiseMatmul) {
    return [conv2DUnpackedPointwise(inferenceHandler, inputs, adjustedAttributes)];
  } else if (packMode && inputs[0].dims.length === 4 && isSingleBatch && !isPointwise) {
    return [conv2DPacked(inferenceHandler, inputs, adjustedAttributes)];
  } else {
    return [conv2DUnpacked(inferenceHandler, inputs, adjustedAttributes)];
  }
};
/**
 * Computes a 1x1 ("pointwise") convolution as one matrix multiplication:
 * `W[M, C] x X[C, H * W]` (plus the optional bias), reshaped to the convolution output shape.
 *
 * Preconditions (not checked here): batch size 1, all strides 1 and all pads 0. The reshape of X drops the
 * batch dimension, and strides/pads are used only to compute the output shape, never applied to the data.
 *
 * @param inputs `[X, W]` or `[X, W, bias]`.
 * @param attributes adjusted `Conv` attributes (also forwarded to the matmul program).
 * @returns the output tensor with the shape given by `calculateOutputShape`.
 */
const conv2DUnpackedPointwise = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: readonly Tensor[],
  attributes: ConvAttributes,
): Tensor => {
  const xshape = inputs[0].dims;
  const kshape = inputs[1].dims;
  const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);
  // X: [1, C, H, W] -> [C, H * W];  W: [M, C, 1, 1] -> [M, C]
  const reshapedX = inferenceHandler.reshapeUnpacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]);
  const reshapedK = inferenceHandler.reshapeUnpacked(inputs[1], [kshape[0], kshape[1]]);

  // The optional bias is forwarded as a third matmul input.
  const matmulInputs = inputs.length > 2 ? [reshapedK, reshapedX, inputs[2]] : [reshapedK, reshapedX];
  const matmulOutput = inferenceHandler.run(createMatmulProgramInfoLoader(matmulInputs, attributes), matmulInputs);
  return inferenceHandler.reshapeUnpacked(matmulOutput, outputShape);
};
/**
 * General 2-D convolution on unpacked textures, in two passes:
 *  1. an im2col program built from X and W, producing `xIm2Col`;
 *  2. a dot-product program over `xIm2Col`, W and (when present) the bias.
 *
 * NOTE(review): this copy of the block is incomplete. Missing pieces:
 *  - the return-type/arrow line of the signature;
 *  - the program-input argument and closing `);` of the im2col `run(...)` call;
 *  - the tail of the dot-product `run(...)` call (presumably `dotProductInputs`, `);`, then `return output;`
 *    and `};`).
 * Restore them from upstream rather than guessing; in particular the im2col program inputs cannot be
 * inferred from this file.
 */
const conv2DUnpacked = (
98
inferenceHandler: WebGLInferenceHandler,
99
inputs: readonly Tensor[],
100
attributes: ConvAttributes,
102
const xshape = inputs[0].dims;
103
const kshape = inputs[1].dims;
104
const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);
105
// Pass 1: im2col program parameterised by X, W, the final output shape and the attributes.
const xIm2Col = inferenceHandler.run(
106
createIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes),
110
// Pass 2 inputs: the im2col result and W, plus the bias when a third input was supplied.
const dotProductInputs = inputs.length === 3 ? [xIm2Col, inputs[1], inputs[2]] : [xIm2Col, inputs[1]];
111
const output = inferenceHandler.run(
112
// NOTE(review): the loader receives the original `inputs`, not `dotProductInputs` — presumably only for
// shape information; confirm against the dot-product loader.
createDotProductProgramInfoLoader(inferenceHandler, inputs, outputShape, attributes),
118
/**
 * Returns a shallow copy of `attributes` whose `kernelShape` and `pads` are resolved for this call.
 *
 * - `kernelShape`: copied; if the attribute was empty, it is filled from the weight tensor's dims[2..].
 * - `pads`: copied, then passed to `PoolConvUtil.adjustPadsBasedOnAutoPad`, which presumably rewrites it in
 *   place according to `autoPad` (the local is reused afterwards without reassignment) — confirm.
 * The original `attributes` object is not modified.
 *
 * NOTE(review): this copy of the block is incomplete. The closing braces of the `if`/`for`, most arguments of
 * the `adjustPadsBasedOnAutoPad(...)` call with its closing `);`, and the final `};` are missing. Restore them
 * from upstream; the argument order cannot be inferred from this file.
 */
const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: Tensor[]): T => {
119
const kernelShape = attributes.kernelShape.slice();
120
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
121
if (attributes.kernelShape.length === 0) {
122
// dims[0] and dims[1] of the weights are channel counts; dims[2..] are the kernel's spatial extent
for (let i = 2; i < inputs[1].dims.length; ++i) {
123
kernelShape.push(inputs[1].dims[i]);
126
const pads = attributes.pads.slice();
127
PoolConvUtil.adjustPadsBasedOnAutoPad(
130
attributes.dilations,
136
// always return a new object so that the original attributes are not modified
137
const newAttributes: T = Object.assign({}, attributes);
138
// cacheKey is re-assigned explicitly, presumably in case it is not an own enumerable property that
// Object.assign would copy — confirm against createAttributeWithCacheKey.
Object.assign(newAttributes, { kernelShape, pads, cacheKey: attributes.cacheKey });
139
return newAttributes;
142
/**
 * Builds `ConvAttributes` from an ONNX `Conv` graph node.
 *
 * Defaults (`'NOTSET'` auto_pad, `[1, 1]` dilations and strides, `[0, 0, 0, 0]` pads, group 1) assume a
 * 2-D convolution. An empty `kernel_shape` means the kernel shape is later inferred from the weight tensor.
 * Activation attributes parsed from the same node are merged in, and a cache key is attached.
 *
 * @param node the `Conv` node.
 * @returns the parsed attributes.
 */
export const parseConvAttributes: OperatorInitialization<ConvAttributes> = (node: Graph.Node): ConvAttributes => {
  const attributes = node.attributes;
  const activationAttributes = parseInternalActivationAttributes(attributes);
  // TODO : Make this generic enough to compute default attributes for multi-dimensional conv
  const autoPad = attributes.getString('auto_pad', 'NOTSET');
  const dilations = attributes.getInts('dilations', [1, 1]);
  const group = attributes.getInt('group', 1);
  const kernelShape = attributes.getInts('kernel_shape', []);
  const pads = attributes.getInts('pads', [0, 0, 0, 0]);
  const strides = attributes.getInts('strides', [1, 1]);

  return createAttributeWithCacheKey({
    autoPad,
    dilations,
    group,
    kernelShape,
    pads,
    strides,
    ...activationAttributes,
  });
};
/**
 * Validates the `Conv` inputs and attributes supported by this backend; throws `Error` on the first violation.
 *
 * Checks, in order:
 *  - 2 or 3 inputs (X, W, optional bias);
 *  - X and W are both 4-D (only 2-D convolution is supported);
 *  - X's channel count equals `W.dims[1] * group`;
 *  - the bias, if present, is 1-D with `W.dims[0]` elements;
 *  - dilations/strides have one entry per spatial axis and pads have two;
 *  - a non-empty kernelShape has one entry per spatial dim of W;
 *  - X, W (and the bias) are float32.
 *
 * NOTE(review): the closing braces of the `if` blocks and of the function are not visible in this view.
 */
const validateInputs = (inputs: Tensor[], attributes: ConvAttributes): void => {
165
// Refer to the below link for all input checks
166
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
167
if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
168
throw new Error('Conv requires 2 or 3 inputs');
171
// TODO : Need to add support for multi-dimensional conv
172
if (inputs[0].dims.length !== 4 || inputs[1].dims.length !== 4) {
173
throw new Error('currently only support 2-dimensional conv');
176
// FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
// (W's per-group input channels, dims[1], times group must equal X's channel count, dims[1])
177
const dataChannel = inputs[0].dims[1];
178
const filterInChannel = inputs[1].dims[1] * attributes.group;
179
if (dataChannel !== filterInChannel) {
180
throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
183
// if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
// (the number of output feature maps is W's dims[0])
184
if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[1].dims[0] !== inputs[2].dims[0])) {
185
throw new Error('invalid bias');
188
const spatialRank = inputs[0].dims.length - 2;
189
// wrong dilations dimension
190
if (attributes.dilations.length !== spatialRank) {
191
throw new Error(`dilations should be ${spatialRank}D`);
194
// Wrong strides dimension
195
if (attributes.strides.length !== spatialRank) {
196
throw new Error(`strides should be ${spatialRank}D`);
199
// Wrong pads dimension: one begin and one end pad per spatial axis
200
if (attributes.pads.length !== spatialRank * 2) {
201
throw new Error(`pads should be ${spatialRank * 2}D`);
204
// if kernelShape is specified, its length must be 2 less than the rank of the weights tensor
205
// (the first 2 dims of the weights are the output feature maps and the per-group input channels)
206
if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) {
207
throw new Error('invalid kernel shape');
210
// TODO : Need to add support for float64
211
if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') {
212
throw new Error('Conv input(X,W) should be float tensor');
215
if (inputs.length === 3 && inputs[2].type !== 'float32') {
216
throw new Error('Conv input(bias) should be float tensor');