// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
import { InferenceHandler } from '../../../backend';
import { Graph } from '../../../graph';
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
import { Tensor } from '../../../tensor';
import { PoolConvUtil } from '../../../util';
import { WebGLInferenceHandler } from '../inference-handler';

import { createUnpackedGroupedConvProgramInfoLoader } from './conv-grouped';
import { conv2DPacked } from './conv-pack';
import { createDotProductProgramInfoLoader } from './dot-product';
import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils';
import { createIm2ColProgramInfoLoader } from './im2col';
import { createMatmulProgramInfoLoader } from './matmul';

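// Computes the convolution output shape from the input shape, kernel shape, dilations,
// (already adjusted) pads and strides, using the standard formula
//   out[i] = floor((in[i] + padBegin[i] + padEnd[i] - dilatedKernel[i] + stride[i]) / stride[i])
// where dilatedKernel[i] = kernel[i] + (kernel[i] - 1) * (dilations[i] - 1).
// Illustrative example (not from the original source): input [1, 3, 224, 224], kernel [64, 3, 7, 7],
// dilations [1, 1], pads [3, 3, 3, 3], strides [2, 2] yields [1, 64, 112, 112].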
export const calculateOutputShape = (
  inputShape: readonly number[],
  kernelShape: readonly number[],
  dilations: readonly number[],
  adjustPads: readonly number[],
  strides: readonly number[],
): number[] => {
  const batchSize = inputShape[0];
  const inputSpatialShape = inputShape.slice(2);
  const spatialRank = inputSpatialShape.length;
  const outChannels = kernelShape[0];
  const kernelSpatialShape = kernelShape.slice(2);
  const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1));
  const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]);
  const outputSpatialShape = inputSpatialShapeWithPad.map((v, i) =>
    Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]),
  );
  const outputShape = [batchSize, outChannels].concat(...outputSpatialShape);
  return outputShape;
};

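// Conv attributes as defined by the ONNX Conv operator. Note that `pads` uses the ONNX layout
// [x1_begin, x2_begin, ..., x1_end, x2_end], which is why calculateOutputShape reads
// adjustPads[i] and adjustPads[i + spatialRank] as the begin/end padding of spatial dimension i,
// and `group` > 1 selects the grouped (e.g. depthwise) convolution path below.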
export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey {
  readonly autoPad: string;
  readonly dilations: readonly number[];
  readonly group: number;
  readonly kernelShape: readonly number[];
  readonly pads: readonly number[];
  readonly strides: readonly number[];
}

export const conv: OperatorImplementation<ConvAttributes> = (
  inferenceHandler: InferenceHandler,
  inputs: Tensor[],
  attributes: ConvAttributes,
): Tensor[] => {
  validateInputs(inputs, attributes); // currently will fail if not conv2D
  return conv2d(inferenceHandler, inputs, attributes);
};

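// Dispatches a 2-D convolution to one of the WebGL implementations below:
//   - group > 1                                     -> unpacked grouped convolution
//   - 1x1 kernel with packed textures               -> pointwise convolution expressed as a matmul
//   - packed textures, 4-D input with batch size 1  -> packed convolution (conv2DPacked)
//   - otherwise                                     -> unpacked im2col + dot-product convolution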
const conv2d: OperatorImplementation<ConvAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ConvAttributes,
): Tensor[] => {
  const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);
  const packMode = inferenceHandler.session.pack;
  const isPointwise = adjustedAttributes.kernelShape[0] === 1 && adjustedAttributes.kernelShape[1] === 1;
  if (adjustedAttributes.group > 1) {
    const result = inferenceHandler.run(
      createUnpackedGroupedConvProgramInfoLoader(inferenceHandler, inputs, adjustedAttributes),
      inputs,
    );
    return [result];
  } else if (isPointwise && packMode) {
    return [conv2DUnpackedPointwise(inferenceHandler, inputs, adjustedAttributes)];
  } else if (packMode && inputs[0].dims.length === 4 && inputs[0].dims[0] === 1 && !isPointwise) {
    return [conv2DPacked(inferenceHandler, inputs, adjustedAttributes)];
  } else {
    return [conv2DUnpacked(inferenceHandler, inputs, adjustedAttributes)];
  }
};

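// Pointwise (1x1 kernel) convolution implemented as a matrix multiplication. Illustrative shape
// sketch (not from the original source): the input is reshaped to [C, H * W] and the kernel from
// [M, C, 1, 1] to [M, C], so the matmul K x X produces the [M, H * W] output values, which are
// then reshaped back to the convolution output shape; an optional bias is forwarded as a third
// matmul input for fusion.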
const conv2DUnpackedPointwise = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: readonly Tensor[],
  attributes: ConvAttributes,
): Tensor => {
  const xshape = inputs[0].dims;
  const kshape = inputs[1].dims;
  const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);
  const reshapedX = inferenceHandler.reshapeUnpacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]);
  const reshapedK = inferenceHandler.reshapeUnpacked(inputs[1], [kshape[0], kshape[1]]);

  const matmulInputs = inputs.length > 2 ? [reshapedK, reshapedX, inputs[2]] : [reshapedK, reshapedX];
  const matmulOutput = inferenceHandler.run(createMatmulProgramInfoLoader(matmulInputs, attributes), matmulInputs);
  return inferenceHandler.reshapeUnpacked(matmulOutput, outputShape);
};

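// General unpacked convolution: the input is first expanded by an im2col program so that each
// output position owns a column of its receptive-field values, and the result is then combined
// with the kernel (and optional bias) by a dot-product program to produce the output tensor.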
const conv2DUnpacked = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: readonly Tensor[],
  attributes: ConvAttributes,
): Tensor => {
  const xshape = inputs[0].dims;
  const kshape = inputs[1].dims;
  const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);
  const xIm2Col = inferenceHandler.run(
    createIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes),
    [inputs[0]],
  );

  const dotProductInputs = inputs.length === 3 ? [xIm2Col, inputs[1], inputs[2]] : [xIm2Col, inputs[1]];
  const output = inferenceHandler.run(
    createDotProductProgramInfoLoader(inferenceHandler, inputs, outputShape, attributes),
    dotProductInputs,
  );
  return output;
};

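// Returns a copy of the attributes with kernelShape inferred from the weight tensor when the
// attribute is absent, and with pads resolved according to auto_pad. Illustrative example
// (not from the original source): for W of dims [64, 3, 7, 7] and an empty kernel_shape
// attribute, the inferred kernelShape is [7, 7].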
const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: Tensor[]): T => {
  const kernelShape = attributes.kernelShape.slice();
  // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
  if (attributes.kernelShape.length === 0) {
    for (let i = 2; i < inputs[1].dims.length; ++i) {
      kernelShape.push(inputs[1].dims[i]);
    }
  }
  const pads = attributes.pads.slice();
  PoolConvUtil.adjustPadsBasedOnAutoPad(
    inputs[0].dims,
    attributes.strides,
    attributes.dilations,
    kernelShape,
    pads,
    attributes.autoPad,
  );

  // always return a new object so it does not modify the original attributes
  const newAttributes: T = Object.assign({}, attributes);
  Object.assign(newAttributes, { kernelShape, pads, cacheKey: attributes.cacheKey });
  return newAttributes;
};

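// Parses the Conv node attributes, falling back to the 2-D defaults used throughout this file.
// Illustrative example (not from the original source): a node without explicit attributes yields
// autoPad 'NOTSET', dilations [1, 1], group 1, kernelShape [], pads [0, 0, 0, 0] and strides [1, 1],
// plus whatever activation attributes parseInternalActivationAttributes extracts for fusion.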
export const parseConvAttributes: OperatorInitialization<ConvAttributes> = (node: Graph.Node): ConvAttributes => {
  const attributes = node.attributes;
  const activationAttributes = parseInternalActivationAttributes(attributes);
  // TODO : Make this generic enough to compute default attributes for multi-dimensional conv
  const autoPad = attributes.getString('auto_pad', 'NOTSET');
  const dilations = attributes.getInts('dilations', [1, 1]);
  const group = attributes.getInt('group', 1);
  const kernelShape = attributes.getInts('kernel_shape', []);
  const pads = attributes.getInts('pads', [0, 0, 0, 0]);
  const strides = attributes.getInts('strides', [1, 1]);

  return createAttributeWithCacheKey({
    autoPad,
    dilations,
    group,
    kernelShape,
    pads,
    strides,
    ...activationAttributes,
  });
};

const validateInputs = (inputs: Tensor[], attributes: ConvAttributes): void => {
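  // Expected inputs for the 2-D case handled here: X of dims [N, C, H, W], W of dims
  // [M, C / group, kH, kW] and an optional 1-D bias of dims [M].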
  // Refer to the below link for all input checks
  // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
  if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
    throw new Error('Conv requires 2 or 3 inputs');
  }

  // TODO : Need to add support for multi-dimensional conv
  if (inputs[0].dims.length !== 4 || inputs[1].dims.length !== 4) {
    throw new Error('currently only support 2-dimensional conv');
  }

  // FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
  const dataChannel = inputs[0].dims[1];
  const filterInChannel = inputs[1].dims[1] * attributes.group;
  if (dataChannel !== filterInChannel) {
    throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
  }

  // if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
  if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[1].dims[0] !== inputs[2].dims[0])) {
    throw new Error('invalid bias');
  }

  const spatialRank = inputs[0].dims.length - 2;
  // Wrong dilations dimension
  if (attributes.dilations.length !== spatialRank) {
    throw new Error(`dilations should be ${spatialRank}D`);
  }

  // Wrong strides dimension
  if (attributes.strides.length !== spatialRank) {
    throw new Error(`strides should be ${spatialRank}D`);
  }

  // Wrong pads dimension
  if (attributes.pads.length !== spatialRank * 2) {
    throw new Error(`pads should be ${spatialRank * 2}D`);
  }

  // if kernelShape is specified, its length must be 2 less than the rank of the weight tensor
  // (the first 2 dims of W are the output channels and the input channels per group)
  if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) {
    throw new Error('invalid kernel shape');
  }

  // TODO : Need to add support for float64
  if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') {
    throw new Error('Conv input(X,W) should be float tensor');
  }

  if (inputs.length === 3 && inputs[2].type !== 'float32') {
    throw new Error('Conv input(bias) should be float tensor');
  }
};