onnxruntime

Форк
0
513 строк · 18.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { env } from 'onnxruntime-common';
5

6
import { DataType } from '../../../wasm-common';
7
import { TensorView } from '../../tensor-view';
8
import { PoolConvUtil, ShapeUtil } from '../../util';
9
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
10
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
11

12
import {
13
  createTensorShapeVariables,
14
  getElementAt,
15
  IndicesHelper,
16
  inputVariable,
17
  outputVariable,
18
  ShaderHelper,
19
  UniformsArrayType,
20
} from './common';
21

22
// TODO: support:
// - ceil_mode                 "test_maxpool_2d_ceil"
// - storage_order             "test_maxpool_with_argmax_2d_precomputed_strides"
// - [MaxPool] dilations       "test_maxpool_2d_dilations"
// - [MaxPool] output[1]       "test_maxpool_with_argmax_2d_precomputed_pads"
27

28
/**
 * Verifies that a pooling operator was handed exactly one input tensor.
 *
 * The check is skipped entirely unless `env.webgpu.validateInputContent` is truthy.
 *
 * @param inputs - Input tensors supplied to the operator.
 * @throws Error when validation is enabled and `inputs` is missing or does not contain exactly one tensor.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!env.webgpu.validateInputContent) {
    return;
  }
  const hasSingleInput = !!inputs && inputs.length === 1;
  if (!hasSingleInput) {
    throw new Error('Pool ops requires 1 input.');
  }
};
33

34
/**
 * Resolves the effective pooling attributes for a concrete input and computes the output shape.
 *
 * The input shape is first expressed channel-first (for 'NHWC' the trailing dimension is moved to index 1).
 * Copies of kernelShape / strides / dilations / pads are handed to `PoolConvUtil.adjustPoolAttributes` and
 * `PoolConvUtil.computePoolOutputShape`; the arrays held by `attributes` are never passed to those helpers.
 *
 * @param input - Tensor being pooled.
 * @param attributes - Parsed operator attributes.
 * @param isGlobalOperator - Forwarded to the `PoolConvUtil` helpers to select global pooling.
 * @returns A tuple of (shallow copy of `attributes` carrying the copied/adjusted arrays, output shape expressed
 *          in the same layout as the input).
 */
const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes | MaxPoolAttributes>(
  input: TensorView,
  attributes: AttributeType,
  isGlobalOperator: boolean,
): [AttributeType, number[]] => {
  const isChannelsLast = attributes.format === 'NHWC';
  const inputShapeAsChannelFirst = input.dims.slice();
  if (isChannelsLast) {
    // Move channel to the second position.
    inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!);
  }

  // Only attribute objects that own a 'dilations' property get one on the returned copy.
  const hasDilations = Object.prototype.hasOwnProperty.call(attributes, 'dilations');
  const kernelShape = attributes.kernelShape.slice();
  const strides = attributes.strides.slice();
  const dilations: number[] = hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : [];
  const pads = attributes.pads.slice();
  PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShapeAsChannelFirst, kernelShape, strides, dilations, pads);

  const outputShapeAsChannelFirst = PoolConvUtil.computePoolOutputShape(
    isGlobalOperator,
    inputShapeAsChannelFirst,
    strides,
    dilations,
    kernelShape,
    pads,
    attributes.autoPad,
  );

  // cacheKey is carried over unchanged by the copy of `attributes`.
  const newAttributes = Object.assign(
    {},
    attributes,
    { kernelShape, strides, pads },
    hasDilations ? { dilations } : {},
  );

  if (!isChannelsLast) {
    return [newAttributes, outputShapeAsChannelFirst];
  }

  // Move the channel dimension back to the end for 'NHWC'.
  const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
  outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
  return [newAttributes, outputShapeAsChannelLast];
};
71

72
/**
 * Builds the uniform values and uniform declarations for a pooling shader and reports which padding checks
 * the shader needs.
 *
 * For kernels with at most two dimensions the width (and, for 2-D, height) kernel size, stride and start/end
 * pads are emitted as scalar uniforms. For larger kernels the kernel strides, pads and strides are emitted as
 * array uniforms.
 *
 * `pads` is read as all start values followed by all end values.
 *
 * @param outputShape - Shape of the output tensor.
 * @param attributes - Pooling attributes whose kernelShape / strides / pads are read.
 * @returns A tuple of (uniform values, uniform declarations, hasPads, pwStartEndNotZero, phStartEndNotZero).
 *          For kernels with at most two dimensions `hasPads` is always `true`; for larger kernels the two
 *          per-axis flags are always `false`.
 * @throws Error when the kernel has more than two dimensions and the format is 'NHWC'.
 */
const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes | MaxPoolAttributes>(
  outputShape: readonly number[],
  attributes: AttributeType,
): [ProgramUniform[], UniformsArrayType, boolean, boolean, boolean] => {
  const { kernelShape, strides, pads } = attributes;
  const isChannelsLast = attributes.format === 'NHWC';

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: ShapeUtil.size(outputShape) },
    { type: DataType.uint32, data: ShapeUtil.size(kernelShape) },
  ];
  const uniforms: UniformsArrayType = [
    { name: 'outputSize', type: 'u32' },
    { name: 'kernelSize', type: 'u32' },
  ];

  if (kernelShape.length > 2) {
    if (isChannelsLast) {
      throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
    }
    const kernelStrides = ShapeUtil.computeStrides(kernelShape);
    programUniforms.push(
      { type: DataType.uint32, data: kernelStrides },
      { type: DataType.uint32, data: pads },
      { type: DataType.uint32, data: strides },
    );
    uniforms.push(
      { name: 'kernelStrides', type: 'u32', length: kernelStrides.length },
      { name: 'pads', type: 'u32', length: pads.length },
      { name: 'strides', type: 'u32', length: strides.length },
    );

    // Explicit initial value so an empty `pads` array yields 0 instead of throwing.
    const padSum = pads.reduce((sum, cur) => sum + cur, 0);
    return [programUniforms, uniforms, !!padSum, false, false];
  }

  // Width is the last kernel axis; its start pad is the last entry of the first half of `pads`.
  const kw = kernelShape[kernelShape.length - 1];
  const sw = strides[strides.length - 1];
  const pwStart = pads[pads.length / 2 - 1];
  const pwEnd = pads[pads.length - 1];
  const pwStartEndNotZero = !!(pwStart + pwEnd);
  programUniforms.push(
    { type: DataType.uint32, data: kw },
    { type: DataType.uint32, data: sw },
    { type: DataType.uint32, data: pwStart },
    { type: DataType.uint32, data: pwEnd },
  );
  uniforms.push(
    { name: 'kw', type: 'u32' },
    { name: 'sw', type: 'u32' },
    { name: 'pwStart', type: 'u32' },
    { name: 'pwEnd', type: 'u32' },
  );

  let phStartEndNotZero = false;
  if (kernelShape.length === 2) {
    // Height is the second-to-last kernel axis.
    const kh = kernelShape[kernelShape.length - 2];
    const sh = strides[strides.length - 2];
    const phStart = pads[pads.length / 2 - 2];
    const phEnd = pads[pads.length - 2];
    phStartEndNotZero = !!(phStart + phEnd);
    programUniforms.push(
      { type: DataType.uint32, data: kh },
      { type: DataType.uint32, data: sh },
      { type: DataType.uint32, data: phStart },
      { type: DataType.uint32, data: phEnd },
    );
    uniforms.push(
      { name: 'kh', type: 'u32' },
      { name: 'sh', type: 'u32' },
      { name: 'phStart', type: 'u32' },
      { name: 'phEnd', type: 'u32' },
    );
  }
  return [programUniforms, uniforms, true, pwStartEndNotZero, phStartEndNotZero];
};
148

149
/**
 * Generates the WGSL source of a pooling kernel.
 *
 * Two shader shapes are produced:
 * - `kernelShape.length <= 2`: explicit loops over the kernel width (and, for 2-D kernels, height) driven by
 *   the `kw/sw/pwStart` (and `kh/sh/phStart`) uniforms.
 * - `kernelShape.length > 2`: a single loop over `uniforms.kernelSize` that decomposes the flat kernel offset
 *   through `uniforms.kernelStrides`.
 *
 * In both shapes each invocation computes one output element: `value` starts at `start`, `op1` is emitted for
 * every in-bounds input element (available as `x_val`), `pad` counts skipped out-of-bounds positions, and
 * `op2` is emitted once before `value` is stored.
 *
 * @param shaderHelper - Helper used to register uniforms, declare variables and emit the entry point.
 * @param x - Indices helper of the input tensor.
 * @param rank - Rank of the input tensor.
 * @param outputShapeRank - Rank of the output tensor.
 * @param attributes - Pooling attributes; kernelShape, pads and format are read.
 * @param op1 - WGSL statement(s) applied per in-bounds input element.
 * @param op2 - WGSL statement(s) applied once after all kernel loops.
 * @param start - Initial value of the accumulator.
 * @param uniforms - Uniform declarations to register.
 * @param hasPads - Kernels with more than two dimensions only: emit the per-element bounds check.
 * @param pwStartEndNotZero - Kernels with at most two dimensions: emit the bounds check along width.
 * @param phStartEndNotZero - 2-D kernels: emit the bounds check along height.
 * @returns The WGSL source.
 * @throws Error when the kernel has more than two dimensions and the format is 'NHWC'.
 */
const generatePoolingCode = <AttributeType extends AveragePoolAttributes | MaxPoolAttributes>(
  shaderHelper: ShaderHelper,
  x: IndicesHelper,
  rank: number,
  outputShapeRank: number,
  attributes: AttributeType,
  op1: string,
  op2: string,
  start: number,
  uniforms: UniformsArrayType,
  hasPads: boolean,
  pwStartEndNotZero: boolean,
  phStartEndNotZero: boolean,
): string => {
  const isChannelsLast = attributes.format === 'NHWC';
  const dataType = x.type.value;
  const output = outputVariable('output', x.type.tensor, outputShapeRank);

  if (attributes.kernelShape.length <= 2) {
    // Width is the last axis for channel-first input and the second-to-last for 'NHWC'.
    const dimIdxW = rank - (isChannelsLast ? 2 : 1);
    const codeW = pwStartEndNotZero
      ? `
                for (var i: u32 = 0u; i < uniforms.kw; i++) {
                  xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
                  if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}]
                      >= uniforms.x_shape[${dimIdxW}]) {
                    pad++;
                    continue;
                  }
                  let x_val = x[${x.indicesToOffset('xIndices')}];
                  ${op1}
                }`
      : `
                for (var i: u32 = 0u; i < uniforms.kw; i++) {
                  xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
                  let x_val = x[${x.indicesToOffset('xIndices')}];
                  ${op1}
                }`;

    // The height loop (2-D kernels only) wraps the width loop: codeH opens it, codeHEnd closes it.
    let codeH = '';
    let codeHEnd = '';
    if (attributes.kernelShape.length === 2) {
      const dimIdxH = rank - (isChannelsLast ? 3 : 2);
      codeH = phStartEndNotZero
        ? `
                for (var j: u32 = 0u; j < uniforms.kh; j++) {
                  xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
                  if (xIndices[${dimIdxH}] < 0 || xIndices[${dimIdxH}] >= uniforms.x_shape[${dimIdxH}]) {
                    pad += i32(uniforms.kw);
                    continue;
                  }
              `
        : `
                for (var j: u32 = 0u; j < uniforms.kh; j++) {
                  xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
                `;
      codeHEnd = `
              }
            `;
    }

    return `
            ${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}

            ${shaderHelper.mainStart()}
              ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}

              let indices = ${output.offsetToIndices('global_idx')};
              var xIndices = ${output.offsetToIndices('global_idx')};

              var value = ${dataType}(${start});
              var pad = 0;
              ${codeH}
              ${codeW}
              ${codeHEnd}
              ${op2}

              output[global_idx] = value;
            }`;
  }

  if (isChannelsLast) {
    throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
  }

  const stridesRank = attributes.kernelShape.length;
  const padsRank = attributes.pads.length;
  // Index of the first spatial axis in the input; used to map an input axis `j` to its kernel/stride/pad slot.
  const spatialStart = rank - stridesRank;

  // padCode also closes the per-axis `for` loop opened in the main template below.
  const padCode = hasPads
    ? `
                if (xIndices[j] >= uniforms.x_shape[j]) {
                  pad++;
                  isPad = true;
                  break;
                }
              }
              if (!isPad) {
                let x_val = x[${x.indicesToOffset('xIndices')}];
                ${op1}
              }`
    : `
              }
              let x_val = x[${x.indicesToOffset('xIndices')}];
              ${op1}
            `;

  return `
            ${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}

            ${shaderHelper.mainStart()}
              ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
              let indices = ${output.offsetToIndices('global_idx')};
              var xIndices = ${output.offsetToIndices('global_idx')};

              var offsets: array<u32, ${stridesRank}>;

              var value = ${dataType}(${start});
              var pad = 0;
              var isPad = false;

              for (var i: u32 = 0u; i < uniforms.kernelSize; i++) {
                var offset = i;
                for (var j = 0u; j < ${stridesRank - 1}u; j++) {
                  offsets[j] = offset / ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
                  offset -= offsets[j] * ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
                }
                offsets[${stridesRank - 1}] = offset;

                isPad = false;
                for (var j = ${spatialStart}u; j < ${rank}u; j++) {
                  xIndices[j] = indices[j] * ${getElementAt(
                    'uniforms.strides',
                    `j - ${spatialStart}u`,
                    stridesRank,
                  )}
                    + offsets[j - ${spatialStart}u] - ${getElementAt(
                      'uniforms.pads',
                      `j - ${spatialStart}u`,
                      padsRank,
                    )};
                  ${padCode}
              }
              ${op2}

              output[global_idx] = value;
            }`;
};
299

300
/**
 * Attribute fragment carrying the tensor layout of an operator.
 */
export interface FormatAttributes {
  /** Tensor layout; the pooling code in this file treats 'NHWC' as channels-last. */
  readonly format: 'NHWC' | 'NCHW';
}
303

304
/**
 * Attributes shared by the average and max pooling operators.
 */
export interface PoolCommonAttributes extends FormatAttributes {
  /** Auto-padding mode name, forwarded to `PoolConvUtil.computePoolOutputShape`. */
  readonly autoPad: string;
  /** ONNX `ceil_mode` value; the attribute parsers in this file reject any non-zero value. */
  readonly ceilMode: number;
  /** Kernel extent per spatial axis. */
  readonly kernelShape: readonly number[];
  /** Stride per spatial axis. */
  readonly strides: readonly number[];
  /** Padding, read by this file as all start values followed by all end values. */
  readonly pads: readonly number[];
}
311

312
/**
 * Builds the part of the shader cache key that is common to all pooling operators.
 *
 * Only the number of kernel dimensions is included, not the kernel extents themselves.
 *
 * @param attributes - Common pooling attributes.
 * @returns `format;ceilMode;autoPad;kernelShape.length`.
 */
const createShaderKeyFromAttributes = (attributes: PoolCommonAttributes): string => {
  const { format, ceilMode, autoPad, kernelShape } = attributes;
  return `${format};${ceilMode};${autoPad};${kernelShape.length}`;
};
314

315
/**
 * Builds the shader cache key for AveragePool: the common pooling key followed by `countIncludePad`.
 *
 * @param attributes - Average pooling attributes.
 * @returns The cache key string.
 */
const createAveragePoolShaderKeyFromAttributes = (attributes: AveragePoolAttributes): string => {
  const commonKey = createShaderKeyFromAttributes(attributes);
  return `${commonKey};${attributes.countIncludePad}`;
};
317

318
/**
 * Builds the shader cache key for MaxPool: the common pooling key followed by `storageOrder` and the
 * string form of `dilations`.
 *
 * @param attributes - Max pooling attributes.
 * @returns The cache key string.
 */
const createMaxPoolShaderKeyFromAttributes = (attributes: MaxPoolAttributes): string => {
  const commonKey = createShaderKeyFromAttributes(attributes);
  return `${commonKey};${attributes.storageOrder};${attributes.dilations}`;
};
320

321
/**
 * Converts the raw attribute record of a pooling operator into {@link PoolCommonAttributes}.
 *
 * `auto_pad` is read as a numeric index into ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER']. The values are
 * cast, not validated.
 *
 * @param attributes - Raw attribute record with `format`, `auto_pad`, `ceil_mode`, `kernel_shape`, `strides`
 *                     and `pads` entries.
 * @returns The parsed common attributes.
 */
const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCommonAttributes => {
  const autoPadNames = ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'];
  return {
    format: attributes.format as FormatAttributes['format'],
    autoPad: autoPadNames[attributes.auto_pad as number],
    ceilMode: attributes.ceil_mode as number,
    // Kernels are not restricted to two dimensions, so these are plain number arrays rather than tuples.
    kernelShape: attributes.kernel_shape as readonly number[],
    strides: attributes.strides as readonly number[],
    pads: attributes.pads as readonly number[],
  };
};
329

330
/**
 * Attributes of the AveragePool / GlobalAveragePool operators.
 */
export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
  /**
   * When true the accumulated sum is divided by the full kernel size; otherwise the number of positions
   * counted as padding by the shader is subtracted from the divisor.
   */
  readonly countIncludePad: boolean;
}
333

334
/**
 * Creates the program description for an average pooling kernel.
 *
 * The shader accumulates every in-bounds input element into `value` and finally divides by either the kernel
 * size (`countIncludePad`) or the kernel size minus the number of padded positions.
 *
 * @param name - Program name.
 * @param input - Tensor being pooled.
 * @param isGlobalOperator - Whether this is the global variant of the operator.
 * @param attributes - Parsed operator attributes; `cacheKey` feeds the shader cache hint.
 * @returns The program info, with a single output of the input's data type.
 */
const createAveragePoolProgramInfo = (
  name: string,
  input: TensorView,
  isGlobalOperator: boolean,
  attributes: AveragePoolAttributes,
): ProgramInfo => {
  const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape(
    input,
    attributes,
    isGlobalOperator,
  );
  const x = inputVariable('x', input.dataType, input.dims.length);
  const dataType = x.type.value;

  const op1 = 'value += x_val;';
  const op2 = adjustedAttributes.countIncludePad
    ? `value /= ${dataType}(uniforms.kernelSize);`
    : `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`;

  const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo(
    outputShape,
    adjustedAttributes,
  );
  programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];

  return {
    name,
    shaderCache: {
      // The padding flags select different shader text, so they are part of the cache hint.
      hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`,
      inputDependencies,
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: input.dataType }],
      dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource: (shaderHelper) =>
      generatePoolingCode(
        shaderHelper,
        x,
        input.dims.length,
        outputShape.length,
        adjustedAttributes,
        op1,
        op2,
        0.0,
        uniforms,
        hasPads,
        pwStartEndNotZero,
        phStartEndNotZero,
      ),
  };
};
389

390
/**
 * Parses the raw attribute record of an AveragePool node.
 *
 * `count_include_pad` is treated as false only when it is exactly 0. The cache key is derived from the parsed
 * attributes.
 *
 * @param attributes - Raw attribute record.
 * @returns The parsed attributes including their cache key.
 * @throws Error when `ceil_mode` is not 0.
 */
export const parseAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
  const common = parsePoolCommonAttributes(attributes);
  // TODO: support attribute 'ceil_mode'
  if (common.ceilMode !== 0) {
    throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
  }

  const countIncludePad = (attributes.count_include_pad as number) !== 0;
  const keyless = { countIncludePad, ...common, cacheKey: '' };
  const cacheKey = createAveragePoolShaderKeyFromAttributes(keyless);
  return { ...keyless, cacheKey };
};
401

402
/**
 * Runs AveragePool on the first input of `context`.
 *
 * @param context - Compute context providing the inputs and the `compute` entry point.
 * @param attributes - Parsed AveragePool attributes.
 */
export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
  validateInputs(context.inputs);
  const programInfo = createAveragePoolProgramInfo('AveragePool', context.inputs[0], false, attributes);
  context.compute(programInfo);
};
406

407
/**
 * Attribute template spread into the attributes of the global pooling operators.
 *
 * Kernel shape, strides, pads and dilations are left empty here. The array instances are shared by reference
 * by every attribute object built from this template, so they must not be mutated in place.
 */
const globalPoolAttributes = {
  autoPad: '',
  ceilMode: 0,
  countIncludePad: false,
  kernelShape: [],
  strides: [],
  pads: [],
  storageOrder: 0,
  dilations: [],
};
417

418
/**
 * Builds the attributes of a GlobalAveragePool node: the shared global pooling template plus the node's
 * `format`, which also serves as the cache key.
 *
 * @param attributes - Raw attribute record; only `format` is read.
 * @returns The attributes for the global average pooling operator.
 */
export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
  const layout = attributes.format as FormatAttributes['format'];
  return { format: layout, ...globalPoolAttributes, cacheKey: layout };
};
422

423
/**
 * Runs GlobalAveragePool on the first input of `context`.
 *
 * @param context - Compute context providing the inputs and the `compute` entry point.
 * @param attributes - Attributes of the global average pooling operator.
 */
export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
  validateInputs(context.inputs);
  const programInfo = createAveragePoolProgramInfo('GlobalAveragePool', context.inputs[0], true, attributes);
  context.compute(programInfo);
};
427

428
/**
 * Attributes of the MaxPool / GlobalMaxPool operators.
 */
export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
  /** ONNX `storage_order` value; `parseMaxPoolAttributes` rejects any value other than 0. */
  readonly storageOrder: number;
  /** Dilations, forwarded to the `PoolConvUtil` shape helpers and included in the shader cache key. */
  readonly dilations: number[];
}
432

433
/**
 * Creates the program description for a max pooling kernel.
 *
 * The shader folds every in-bounds input element into `value` with `max`.
 *
 * @param name - Program name.
 * @param input - Tensor being pooled.
 * @param isGlobalOperator - Whether this is the global variant of the operator.
 * @param attributes - Parsed operator attributes; `cacheKey` feeds the shader cache hint.
 * @returns The program info, with a single output of the input's data type.
 */
const createMaxPoolProgramInfo = (
  name: string,
  input: TensorView,
  isGlobalOperator: boolean,
  attributes: MaxPoolAttributes,
): ProgramInfo => {
  const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape(
    input,
    attributes,
    isGlobalOperator,
  );
  const op1 = `
      value = max(x_val, value);
    `;
  const op2 = '';
  const x = inputVariable('x', input.dataType, input.dims.length);
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
  const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo(
    outputShape,
    adjustedAttributes,
  );
  programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));

  // Initial value of the running maximum.
  // NOTE(review): for non-float16 inputs the running max starts at -1e5, so a window whose elements are all
  // below -1e5 would yield -1e5 — confirm whether a lower bound of the data type is intended here.
  const initialValue = input.dataType === DataType.float16 ? -65504 : -1e5;

  return {
    name,
    shaderCache: {
      // The padding flags select different shader text, so they are part of the cache hint.
      hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`,
      inputDependencies,
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: input.dataType }],
      dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource: (shaderHelper) =>
      generatePoolingCode(
        shaderHelper,
        x,
        input.dims.length,
        outputShape.length,
        adjustedAttributes,
        op1,
        op2,
        initialValue,
        uniforms,
        hasPads,
        pwStartEndNotZero,
        phStartEndNotZero,
      ),
  };
};
483

484
/**
 * Runs MaxPool on the first input of `context`.
 *
 * @param context - Compute context providing the inputs and the `compute` entry point.
 * @param attributes - Parsed MaxPool attributes.
 */
export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
  validateInputs(context.inputs);
  const programInfo = createMaxPoolProgramInfo('MaxPool', context.inputs[0], false, attributes);
  context.compute(programInfo);
};
488

489
/**
 * Parses the raw attribute record of a MaxPool node.
 *
 * The cache key is derived from the parsed attributes.
 *
 * @param attributes - Raw attribute record; `storage_order` and `dilations` are read in addition to the
 *                     common pooling attributes.
 * @returns The parsed attributes including their cache key.
 * @throws Error when `storage_order` is not 0 or `ceil_mode` is not 0.
 */
export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
  const storageOrder = attributes.storage_order as number;
  // Dilations are not restricted to two dimensions, so this is a plain number array rather than a tuple.
  const dilations = attributes.dilations as number[];

  const attr = parsePoolCommonAttributes(attributes);
  // TODO: support attribute 'ceil_mode' and 'storage_order'
  if (storageOrder !== 0) {
    throw new Error('column major storage order is not yet supported for MaxPool');
  }
  if (attr.ceilMode !== 0) {
    throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
  }

  const maxPoolAttributes = { storageOrder, dilations, ...attr, cacheKey: '' };
  return { ...maxPoolAttributes, cacheKey: createMaxPoolShaderKeyFromAttributes(maxPoolAttributes) };
};
504

505
/**
 * Builds the attributes of a GlobalMaxPool node: the shared global pooling template plus the node's
 * `format`, which also serves as the cache key.
 *
 * @param attributes - Raw attribute record; only `format` is read.
 * @returns The attributes for the global max pooling operator.
 */
export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
  const layout = attributes.format as FormatAttributes['format'];
  return { format: layout, ...globalPoolAttributes, cacheKey: layout };
};
509

510
/**
 * Runs GlobalMaxPool on the first input of `context`.
 *
 * @param context - Compute context providing the inputs and the `compute` entry point.
 * @param attributes - Attributes of the global max pooling operator.
 */
export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
  validateInputs(context.inputs);
  const programInfo = createMaxPoolProgramInfo('GlobalMaxPool', context.inputs[0], true, attributes);
  context.compute(programInfo);
};
514

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.