onnxruntime

Форк
0
419 строк · 15.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { TensorView } from '../../tensor-view';
5
import { PoolConvUtil } from '../../util';
6
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
7
import { ComputeContext } from '../types';
8

9
import { createConv2DMatMulProgramInfo } from './3rd-party/conv2d_mm_webgpu';
10
import { computeConv3DInfo, createConv3DNaiveProgramInfo } from './3rd-party/conv3d_naive_webgpu';
11
import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu';
12
import { createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo } from './conv-grouped';
13
import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils';
14
import { createNaiveMatmulProgramInfo } from './matmul';
15
import { createTransposeProgramInfo } from './transpose';
16

17
export const calculateOutputShape = (
18
  inputShape: readonly number[],
19
  kernelShape: readonly number[],
20
  dilations: readonly number[],
21
  adjustPads: readonly number[],
22
  strides: readonly number[],
23
  isChannelLast: boolean,
24
): number[] => {
25
  const batchSize = inputShape[0];
26
  const inputSpatialShape = inputShape.slice(isChannelLast ? 1 : 2, isChannelLast ? 3 : 4);
27
  const spatialRank = inputSpatialShape.length;
28
  const outChannels = kernelShape[0];
29
  const kernelSpatialShape = kernelShape.slice(2);
30
  const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1));
31
  const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]);
32
  const outputShape = inputSpatialShapeWithPad.map((v, i) =>
33
    Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]),
34
  );
35
  outputShape.splice(0, 0, batchSize);
36
  outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels);
37
  return outputShape;
38
};
39

40
/**
 * Attributes of the Conv operator, combined with fused-activation attributes
 * and a cache key used to de-duplicate compiled GPU programs.
 */
export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey {
  // ONNX auto_pad mode: 'NOTSET' | 'VALID' | 'SAME_UPPER' | 'SAME_LOWER'.
  readonly autoPad: string;
  // Dilation factor per spatial axis.
  readonly dilations: readonly number[];
  // Data layout of input/output tensors.
  readonly format: 'NHWC' | 'NCHW';
  // Number of groups the input channels are split into.
  readonly group: number;
  // Spatial kernel shape; may be empty (inferred from the weight tensor).
  readonly kernelShape: readonly number[];
  // Pads as [begin_0, ..., begin_n, end_0, ..., end_n] over spatial axes.
  readonly pads: readonly number[];
  // Stride per spatial axis.
  readonly strides: readonly number[];
  // True when the weight tensor is a constant (enables caching its transpose).
  readonly wIsConst: boolean;
}
50

51
// Permutation for transposing the weight tensor from [M, C/group, KH, KW]
// to [KH, KW, C/group, M], the layout used by the channels-last conv programs.
const weightTransposeAttribute = [2, 3, 1, 0];
53

54
/**
 * Validates Conv inputs and attributes; throws on any violation.
 * Checks follow the ONNX Conv spec:
 * https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv
 */
const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttributes): void => {
  if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
    throw new Error('Conv requires 2 or 3 inputs');
  }

  const dataDims = inputs[0].dims;
  const filterDims = inputs[1].dims;

  if (dataDims.length > 5) {
    throw new Error('greater than 5D is not supported');
  }

  if (dataDims.length !== filterDims.length) {
    throw new Error('filter does not have same dimension as input');
  }

  // The filter's per-group input-channel count times group must match the
  // data's channel count.
  const channelAxis = attributes.format === 'NHWC' ? dataDims.length - 1 : 1;
  if (dataDims[channelAxis] !== filterDims[1] * attributes.group) {
    throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
  }

  // An optional bias must be 1-D with one element per output feature map.
  if (inputs.length === 3 && (inputs[2].dims.length !== 1 || filterDims[0] !== inputs[2].dims[0])) {
    throw new Error('invalid bias');
  }

  // dilations/strides must match the spatial rank; pads carry begin+end pairs.
  const spatialRank = dataDims.length - 2;
  if (attributes.dilations.length !== spatialRank) {
    throw new Error(`dilations should be ${spatialRank}D`);
  }
  if (attributes.strides.length !== spatialRank) {
    throw new Error(`strides should be ${spatialRank}D`);
  }
  if (attributes.pads.length !== spatialRank * 2) {
    throw new Error(`pads should be ${spatialRank * 2}D`);
  }

  // A non-empty kernelShape attribute must cover exactly the spatial dims of
  // the weight tensor (its first two dims are M and C/group).
  if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== filterDims.length - 2) {
    throw new Error('invalid kernel shape');
  }
};
103

104
const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
105
  const kernelShape = attributes.kernelShape.slice();
106
  // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
107
  for (let i = 2; i < inputs[1].dims.length; ++i) {
108
    if (kernelShape[i - 2] === 0) {
109
      kernelShape[i - 2] = inputs[1].dims[i];
110
    }
111
  }
112
  const pads = attributes.pads.slice();
113
  PoolConvUtil.adjustPadsBasedOnAutoPad(
114
    inputs[0].dims,
115
    attributes.strides,
116
    attributes.dilations,
117
    kernelShape,
118
    pads,
119
    attributes.format === 'NHWC',
120
    attributes.autoPad,
121
  );
122

123
  // always return a new object so does not modify the original attributes
124
  const newAttributes: T = Object.assign({}, attributes);
125
  Object.assign(newAttributes, { kernelShape, pads });
126
  return newAttributes;
127
};
128

129
export const parseConvAttributes = (attributes: Record<string, unknown>): ConvAttributes => {
130
  const activationAttributes = parseInternalActivationAttributes(attributes);
131
  // TODO : Make this generic enough to compute default attributes for multi-dimensional conv
132
  const format = attributes.format as 'NHWC' | 'NCHW';
133
  const autoPad = ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number];
134
  const dilations = attributes.dilations as number[];
135
  const group = attributes.group as number;
136
  const kernelShape = attributes.kernel_shape as number[];
137
  const pads = attributes.pads as number[];
138
  const strides = attributes.strides as number[];
139
  const wIsConst = (attributes.w_is_const as () => boolean)();
140

141
  return {
142
    autoPad,
143
    format,
144
    dilations,
145
    group,
146
    kernelShape,
147
    pads,
148
    strides,
149
    wIsConst,
150
    ...activationAttributes,
151
    cacheKey: `${attributes.format};${activationAttributes.activation};`,
152
  };
153
};
154

155
/**
 * Runs a 2-D convolution, dispatching among several WebGPU implementations:
 * - grouped conv (optionally vectorized) when group != 1;
 * - a plain matmul when the conv degenerates to one (1x1 kernel with unit
 *   strides/dilations and zero pads, or a kernel covering the whole input);
 * - the tiled conv2d-as-matmul program otherwise.
 *
 * @param context compute context providing inputs, kernelCustomData and compute().
 * @param inputs [data, weight] or [data, weight, bias].
 * @param attributes adjusted conv attributes (kernelShape/pads already resolved).
 * @param squeezeOutputShapeFunction optional mapper applied to the 4-D output
 *        shape (used by conv1d to drop the inserted H dimension).
 */
const conv2d = (
  context: ComputeContext,
  inputs: readonly TensorView[],
  attributes: ConvAttributes,
  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): void => {
  // check attributes

  // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
  const isChannelsLast = attributes.format === 'NHWC';
  if (attributes.group !== 1) {
    // NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other
    // GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs.
    // [webgpu]Conv - conv - vectorize group - B
    // [webgpu]Conv - conv - vectorize group - D
    const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere');
    // The vectorized path requires channels-last depthwise-style weights
    // (one input channel per group) and unit dilations.
    if (
      enableGroupedConvVectorize &&
      isChannelsLast &&
      inputs[1].dims[0] === attributes.group &&
      inputs[1].dims[1] === 1 &&
      attributes.dilations[0] === 1 &&
      attributes.dilations[1] === 1
    ) {
      const outputShape = calculateOutputShape(
        inputs[0].dims,
        inputs[1].dims,
        attributes.dilations,
        attributes.pads,
        attributes.strides,
        isChannelsLast,
      );
      // Reuse the cached transposed weight for constant weights; otherwise
      // transpose it now (persistently, -2, when wIsConst).
      const transposedWeight =
        (context.kernelCustomData.wT as TensorView | undefined) ??
        context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
          inputs: [1],
          outputs: [attributes.wIsConst ? -2 : -1],
        })[0];
      if (attributes.wIsConst && !context.kernelCustomData.wT) {
        context.kernelCustomData.wT = transposedWeight;
      }
      const convInputs = [inputs[0], transposedWeight];
      if (inputs.length === 3) {
        convInputs.push(inputs[2]);
      }
      context.compute(
        createGroupedConvVectorizeProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction),
        { inputs: convInputs },
      );
    } else {
      context.compute(createGroupedConvProgramInfo(inputs, attributes, squeezeOutputShapeFunction));
    }
    return;
  }

  const hasBias = inputs.length === 3;
  const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2];
  const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3];
  const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
  const weightHeight = inputs[1].dims[2];
  const weightWidth = inputs[1].dims[3];

  const outputShape = calculateOutputShape(
    inputs[0].dims,
    inputs[1].dims,
    attributes.dilations,
    attributes.pads,
    attributes.strides,
    isChannelsLast,
  );
  const outHeight = outputShape[isChannelsLast ? 1 : 2];
  const outWidth = outputShape[isChannelsLast ? 2 : 3];
  const outChannels = outputShape[isChannelsLast ? 3 : 1];

  // sameSize: channels-last kernel that covers the entire (unpadded) input,
  // so each batch element reduces to a single dot product.
  const sameSize =
    isChannelsLast &&
    weightHeight === inputHeight &&
    weightWidth === inputWidth &&
    attributes.pads[0] === 0 &&
    attributes.pads[1] === 0;
  if (
    sameSize ||
    (weightHeight === 1 &&
      weightWidth === 1 &&
      attributes.dilations[0] === 1 &&
      attributes.dilations[1] === 1 &&
      attributes.strides[0] === 1 &&
      attributes.strides[1] === 1 &&
      attributes.pads[0] === 0 &&
      attributes.pads[1] === 0)
  ) {
    // conv2dByMatMul
    const batch = outputShape[0];
    let xReshaped, wReshaped, matmulOutputShape;
    const matmulInputs = [];
    if (isChannelsLast) {
      // Reuse the cached transposed weight for constant weights; otherwise
      // transpose it now (persistently, -2, when wIsConst).
      const transposedWeight =
        (context.kernelCustomData.wT as TensorView | undefined) ??
        context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
          inputs: [1],
          outputs: [attributes.wIsConst ? -2 : -1],
        })[0];
      if (attributes.wIsConst && !context.kernelCustomData.wT) {
        context.kernelCustomData.wT = transposedWeight;
      }
      if (sameSize) {
        const sharedDim = inputHeight * inputWidth * inputChannels;
        xReshaped = inputs[0].reshape([1, batch, sharedDim]);
        wReshaped = transposedWeight.reshape([1, sharedDim, outChannels]);
        matmulOutputShape = [1, batch, outChannels];
      } else {
        xReshaped = inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]);
        wReshaped = transposedWeight.reshape([1, inputChannels, outChannels]);
        matmulOutputShape = [batch, outHeight * outWidth, outChannels];
      }
      matmulInputs.push(xReshaped);
      matmulInputs.push(wReshaped);
    } else {
      // NCHW: multiply W (as [1, M, C]) by X (as [batch, C, H*W]).
      xReshaped = inputs[0].reshape([batch, inputChannels, inputHeight * inputWidth]);
      wReshaped = inputs[1].reshape([1, outChannels, inputChannels]);
      matmulOutputShape = [batch, outChannels, outHeight * outWidth];
      matmulInputs.push(wReshaped);
      matmulInputs.push(xReshaped);
    }
    if (hasBias) {
      matmulInputs.push(inputs[2]);
    }
    const N = matmulOutputShape[2];
    const K = matmulInputs[0].dims[matmulInputs[0].dims.length - 1];
    // Tune the threshold.
    // Small N and K: the naive matmul avoids the tiled program's overhead.
    if (N < 8 && K < 8) {
      context.compute(
        createNaiveMatmulProgramInfo(
          matmulInputs,
          attributes,
          outputShape,
          matmulOutputShape,
          isChannelsLast,
          squeezeOutputShapeFunction,
        ),
        { inputs: matmulInputs },
      );
    } else {
      context.compute(
        createMatmulProgramInfo(
          matmulInputs,
          attributes,
          outputShape,
          matmulOutputShape,
          isChannelsLast,
          squeezeOutputShapeFunction,
        ),
        { inputs: matmulInputs },
      );
    }
    return;
  }

  // TODO: implement conv2dWithIm2Col()

  const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true;

  // STEP.1: transpose weight
  const transposedWeight =
    (context.kernelCustomData.wT as TensorView | undefined) ??
    context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
      inputs: [1],
      outputs: [attributes.wIsConst ? -2 : -1],
    })[0];
  if (attributes.wIsConst && !context.kernelCustomData.wT) {
    context.kernelCustomData.wT = transposedWeight;
  }

  // STEP.2: prepare reshaped inputs
  const convInputs = [inputs[0], transposedWeight];
  if (hasBias) {
    convInputs.push(inputs[2]);
  }

  // STEP.3: compute matmul
  const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
  const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
  const dimInner = weightHeight * weightWidth * inputChannels;
  context.compute(
    createConv2DMatMulProgramInfo(
      convInputs,
      attributes,
      outputShape,
      dimAOuter,
      dimBOuter,
      dimInner,
      hasBias,
      sequentialAccessByThreads,
      squeezeOutputShapeFunction,
    ),
    { inputs: convInputs },
  );
};
353

354
const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => {
355
  // extend the input to 2D by adding H dimension
356
  const isChannelLast = attributes.format === 'NHWC';
357
  const inputs = [
358
    context.inputs[0].reshape(
359
      isChannelLast
360
        ? // [N, W, C] -> [N, H=1, W, C]
361
          [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]]
362
        : // [N, C, W] -> [N, C, H=1, W]
363
          [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]],
364
    ),
365
    //[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW]
366
    context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]),
367
  ];
368
  if (context.inputs.length === 3) {
369
    inputs.push(context.inputs[2]);
370
  }
371
  const pads = [0, attributes.pads[0], 0, attributes.pads[1]];
372
  const strides = [1].concat(attributes.strides);
373
  const dilations = [1].concat(attributes.dilations);
374
  const kernelShape = [1].concat(attributes.kernelShape);
375
  const adjustedAttributes = getAdjustedConvAttributes(
376
    { ...attributes, pads, strides, dilations, kernelShape },
377
    inputs,
378
  );
379
  conv2d(context, inputs, adjustedAttributes, (outputShape) =>
380
    isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]],
381
  );
382
};
383

384
const conv3d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => {
385
  const format = attributes.format === 'NHWC' ? 'channelsLast' : 'channelsFirst';
386
  const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);
387
  const pads = attributes.autoPad === 'NOTSET' ? attributes.pads : attributes.autoPad;
388
  const convInfo = computeConv3DInfo(
389
    inputs[0].dims as [number, number, number, number, number],
390
    inputs[1].dims as [number, number, number, number, number],
391
    attributes.strides as number | [number, number, number],
392
    attributes.dilations as number | [number, number, number],
393
    pads as string | number[],
394
    false,
395
    format,
396
  );
397
  context.compute(
398
    createConv3DNaiveProgramInfo(
399
      inputs,
400
      adjustedAttributes,
401
      convInfo.outShape,
402
      [convInfo.filterDepth, convInfo.filterHeight, convInfo.filterWidth],
403
      [convInfo.padInfo.front, convInfo.padInfo.top, convInfo.padInfo.left],
404
      format,
405
    ),
406
  );
407
};
408

409
export const conv = (context: ComputeContext, attributes: ConvAttributes): void => {
410
  validateInputs(context.inputs, attributes);
411
  if (context.inputs[0].dims.length === 3) {
412
    conv1d(context, attributes);
413
  } else if (context.inputs[0].dims.length === 5) {
414
    conv3d(context, context.inputs, attributes);
415
  } else {
416
    const adjustedAttributes = getAdjustedConvAttributes(attributes, context.inputs);
417
    conv2d(context, context.inputs, adjustedAttributes);
418
  }
419
};
420

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.