1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { env } from 'onnxruntime-common';
6
import { DataType } from '../../../wasm-common';
7
import { TensorView } from '../../tensor-view';
8
import { PoolConvUtil, ShapeUtil } from '../../util';
9
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
10
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
13
createTensorShapeVariables,
23
// - ceil_mode "test_maxpool_2d_ceil"
24
// - storage_order "test_maxpool_with_argmax_2d_precomputed_strides"
25
// - [MaxPool] dilations "test_maxpool_2d_dilations"
26
// - [MaxPool] output[1] "test_maxpool_with_argmax_2d_precomputed_pads"
28
const validateInputs = (inputs: readonly TensorView[]): void => {
29
if (env.webgpu.validateInputContent && (!inputs || inputs.length !== 1)) {
30
throw new Error('Pool ops requires 1 input.');
34
const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes | MaxPoolAttributes>(
36
attributes: AttributeType,
37
isGlobalOperator: boolean,
38
): [AttributeType, number[]] => {
39
const isChannelsLast = attributes.format === 'NHWC';
40
const inputShapeAsChannelFirst = input.dims.slice();
42
inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position.
44
const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations');
45
const kernelShape = attributes.kernelShape.slice();
46
const strides = attributes.strides.slice();
47
const dilations: number[] = hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : [];
48
const pads = attributes.pads.slice();
49
PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShapeAsChannelFirst, kernelShape, strides, dilations, pads);
51
const outputShapeAsChannelFirst = PoolConvUtil.computePoolOutputShape(
53
inputShapeAsChannelFirst,
61
const newAttributes = Object.assign({}, attributes);
63
Object.assign(newAttributes, { kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey });
65
Object.assign(newAttributes, { kernelShape, strides, pads, cacheKey: attributes.cacheKey });
67
const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
68
outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
69
return [newAttributes, isChannelsLast ? outputShapeAsChannelLast : outputShapeAsChannelFirst];
72
const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes | MaxPoolAttributes>(
73
outputShape: readonly number[],
74
attributes: AttributeType,
75
): [ProgramUniform[], UniformsArrayType, boolean, boolean, boolean] => {
76
const isChannelsLast = attributes.format === 'NHWC';
77
const outputSize = ShapeUtil.size(outputShape);
78
const kernelSize = ShapeUtil.size(attributes.kernelShape);
79
const programUniforms: ProgramUniform[] = [
80
{ type: DataType.uint32, data: outputSize },
81
{ type: DataType.uint32, data: kernelSize },
83
const uniforms: UniformsArrayType = [
84
{ name: 'outputSize', type: 'u32' },
85
{ name: 'kernelSize', type: 'u32' },
87
if (attributes.kernelShape.length <= 2) {
88
const kw = attributes.kernelShape[attributes.kernelShape.length - 1];
89
const sw = attributes.strides[attributes.strides.length - 1];
90
const pwStart = attributes.pads[attributes.pads.length / 2 - 1];
91
const pwEnd = attributes.pads[attributes.pads.length - 1];
92
const pwStartEndNotZero = !!(pwStart + pwEnd);
94
{ type: DataType.uint32, data: kw },
95
{ type: DataType.uint32, data: sw },
96
{ type: DataType.uint32, data: pwStart },
97
{ type: DataType.uint32, data: pwEnd },
100
{ name: 'kw', type: 'u32' },
101
{ name: 'sw', type: 'u32' },
102
{ name: 'pwStart', type: 'u32' },
103
{ name: 'pwEnd', type: 'u32' },
106
let phStartEndNotZero = false;
107
if (attributes.kernelShape.length === 2) {
108
const kh = attributes.kernelShape[attributes.kernelShape.length - 2];
109
const sh = attributes.strides[attributes.strides.length - 2];
110
const phStart = attributes.pads[attributes.pads.length / 2 - 2];
111
const phEnd = attributes.pads[attributes.pads.length - 2];
112
phStartEndNotZero = !!(phStart + phEnd);
113
programUniforms.push(
114
{ type: DataType.uint32, data: kh },
115
{ type: DataType.uint32, data: sh },
116
{ type: DataType.uint32, data: phStart },
117
{ type: DataType.uint32, data: phEnd },
121
{ name: 'kh', type: 'u32' },
122
{ name: 'sh', type: 'u32' },
123
{ name: 'phStart', type: 'u32' },
124
{ name: 'phEnd', type: 'u32' },
127
return [programUniforms, uniforms, true, pwStartEndNotZero, phStartEndNotZero];
129
if (isChannelsLast) {
130
throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
132
const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape);
133
programUniforms.push(
134
{ type: DataType.uint32, data: kernelStrides },
135
{ type: DataType.uint32, data: attributes.pads },
136
{ type: DataType.uint32, data: attributes.strides },
139
{ name: 'kernelStrides', type: 'u32', length: kernelStrides.length },
140
{ name: 'pads', type: 'u32', length: attributes.pads.length },
141
{ name: 'strides', type: 'u32', length: attributes.strides.length },
144
const hasPads = attributes.pads.reduce((sum, cur) => sum + cur);
145
return [programUniforms, uniforms, !!hasPads, false, false];
149
const generatePoolingCode = <AttributeType extends AveragePoolAttributes | MaxPoolAttributes>(
150
shaderHelper: ShaderHelper,
153
outputShapeRank: number,
154
attributes: AttributeType,
158
uniforms: UniformsArrayType,
160
pwStartEndNotZero: boolean,
161
phStartEndNotZero: boolean,
163
const isChannelsLast = attributes.format === 'NHWC';
164
const dataType = x.type.value;
165
const output = outputVariable('output', x.type.tensor, outputShapeRank);
167
if (attributes.kernelShape.length <= 2) {
171
const dimIdxW = rank - (isChannelsLast ? 2 : 1);
172
if (pwStartEndNotZero) {
174
for (var i: u32 = 0u; i < uniforms.kw; i++) {
175
xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
176
if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}]
177
>= uniforms.x_shape[${dimIdxW}]) {
181
let x_val = x[${x.indicesToOffset('xIndices')}];
186
for (var i: u32 = 0u; i < uniforms.kw; i++) {
187
xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
188
let x_val = x[${x.indicesToOffset('xIndices')}];
193
if (attributes.kernelShape.length === 2) {
194
const dimIdxH = rank - (isChannelsLast ? 3 : 2);
195
if (phStartEndNotZero) {
197
for (var j: u32 = 0u; j < uniforms.kh; j++) {
198
xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
199
if (xIndices[${dimIdxH}] < 0 || xIndices[${dimIdxH}] >= uniforms.x_shape[${dimIdxH}]) {
200
pad += i32(uniforms.kw);
206
for (var j: u32 = 0u; j < uniforms.kh; j++) {
207
xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
215
const poolingCode = `
216
${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}
218
${shaderHelper.mainStart()}
219
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
221
let indices = ${output.offsetToIndices('global_idx')};
222
var xIndices = ${output.offsetToIndices('global_idx')};
224
var value = ${dataType}(${start});
231
output[global_idx] = value;
235
if (isChannelsLast) {
236
throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
238
const stridesRank = attributes.kernelShape.length;
239
const padsRank = attributes.pads.length;
243
if (xIndices[j] >= uniforms.x_shape[j]) {
250
let x_val = x[${x.indicesToOffset('xIndices')}];
256
let x_val = x[${x.indicesToOffset('xIndices')}];
260
const poolingCode = `
261
${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}
263
${shaderHelper.mainStart()}
264
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
265
let indices = ${output.offsetToIndices('global_idx')};
266
var xIndices = ${output.offsetToIndices('global_idx')};
268
var offsets: array<u32, ${stridesRank}>;
270
var value = ${dataType}(${start});
274
for (var i: u32 = 0u; i < uniforms.kernelSize; i++) {
276
for (var j = 0u; j < ${stridesRank - 1}u; j++) {
277
offsets[j] = offset / ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
278
offset -= offsets[j] * ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
280
offsets[${stridesRank - 1}] = offset;
283
for (var j = ${rank - stridesRank}u; j < ${rank}u; j++) {
284
xIndices[j] = indices[j] * ${getElementAt(
286
`j - ${rank - stridesRank}u`,
289
+ offsets[j - ${rank - stridesRank}u] - ${getElementAt('uniforms.pads', 'j - 2u', padsRank)};
294
output[global_idx] = value;
300
export interface FormatAttributes {
301
readonly format: 'NHWC' | 'NCHW';
304
export interface PoolCommonAttributes extends FormatAttributes {
305
readonly autoPad: string;
306
readonly ceilMode: number;
307
readonly kernelShape: readonly number[];
308
readonly strides: readonly number[];
309
readonly pads: readonly number[];
312
const createShaderKeyFromAttributes = (attributes: PoolCommonAttributes): string =>
313
`${attributes.format};${attributes.ceilMode};${attributes.autoPad};${attributes.kernelShape.length}`;
315
const createAveragePoolShaderKeyFromAttributes = (attributes: AveragePoolAttributes): string =>
316
`${createShaderKeyFromAttributes(attributes)};${attributes.countIncludePad}`;
318
const createMaxPoolShaderKeyFromAttributes = (attributes: MaxPoolAttributes): string =>
319
`${createShaderKeyFromAttributes(attributes)};${attributes.storageOrder};${attributes.dilations}`;
321
const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCommonAttributes => ({
322
format: attributes.format as FormatAttributes['format'],
323
autoPad: ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number],
324
ceilMode: attributes.ceil_mode as number,
325
kernelShape: attributes.kernel_shape as [number, number],
326
strides: attributes.strides as [number, number],
327
pads: attributes.pads as [number, number, number, number],
330
export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
331
readonly countIncludePad: boolean;
334
const createAveragePoolProgramInfo = (
337
isGlobalOperator: boolean,
338
attributes: AveragePoolAttributes,
340
const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape(
345
const x = inputVariable('x', input.dataType, input.dims.length);
346
const dataType = x.type.value;
348
const op1 = 'value += x_val;';
350
if (adjustedAttributes.countIncludePad) {
351
op2 += `value /= ${dataType}(uniforms.kernelSize);`;
353
op2 += `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`;
355
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo(
359
programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));
360
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
364
hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`,
368
outputs: [{ dims: outputShape, dataType: input.dataType }],
369
dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) },
372
getShaderSource: (shaderHelper) =>
390
export const parseAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
391
const countIncludePad = (attributes.count_include_pad as number) === 0 ? false : true;
393
const attr = parsePoolCommonAttributes(attributes);
394
// TODO: support attribute 'ceil_mode'
395
if (attr.ceilMode !== 0) {
396
throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
398
const averagePoolAttributes = { countIncludePad, ...attr, cacheKey: '' };
399
return { ...averagePoolAttributes, cacheKey: createAveragePoolShaderKeyFromAttributes(averagePoolAttributes) };
402
export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
403
validateInputs(context.inputs);
404
context.compute(createAveragePoolProgramInfo('AveragePool', context.inputs[0], false, attributes));
407
const globalPoolAttributes = {
410
countIncludePad: false,
418
export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
419
const format = attributes.format as FormatAttributes['format'];
420
return { format, ...globalPoolAttributes, cacheKey: format };
423
export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
424
validateInputs(context.inputs);
425
context.compute(createAveragePoolProgramInfo('GlobalAveragePool', context.inputs[0], true, attributes));
428
export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
429
readonly storageOrder: number;
430
readonly dilations: number[];
433
const createMaxPoolProgramInfo = (
436
isGlobalOperator: boolean,
437
attributes: MaxPoolAttributes,
439
const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape(
445
value = max(x_val, value);
448
const x = inputVariable('x', input.dataType, input.dims.length);
449
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
450
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo(
454
programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));
458
hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`,
462
outputs: [{ dims: outputShape, dataType: input.dataType }],
463
dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) },
466
getShaderSource: (shaderHelper) =>
475
input.dataType === DataType.float16 ? -65504 : -1e5,
484
export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
485
validateInputs(context.inputs);
486
context.compute(createMaxPoolProgramInfo('MaxPool', context.inputs[0], false, attributes));
489
export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
490
const storageOrder = attributes.storage_order as number;
491
const dilations = attributes.dilations as [number, number];
493
const attr = parsePoolCommonAttributes(attributes);
494
// TODO: support attribute 'ceil_mode' and 'storage_order'
495
if (storageOrder !== 0) {
496
throw new Error('column major storage order is not yet supported for MaxPool');
498
if (attr.ceilMode !== 0) {
499
throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
501
const maxPoolAttributes = { storageOrder, dilations, ...attr, cacheKey: '' };
502
return { ...maxPoolAttributes, cacheKey: createMaxPoolShaderKeyFromAttributes(maxPoolAttributes) };
505
export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
506
const format = attributes.format as FormatAttributes['format'];
507
return { format, ...globalPoolAttributes, cacheKey: format };
510
export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
511
validateInputs(context.inputs);
512
context.compute(createMaxPoolProgramInfo('GlobalMaxPool', context.inputs[0], true, attributes));