1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramShaderCacheInfo } from '../types';
10
import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common';
14
reduceLogSumExpShared,
21
reduceSumSquareShared,
22
} from './reduce-shared';
24
/**
 * Validates the inputs of a Reduce* node: one data tensor, optionally followed by an `axes`
 * tensor that must be 1-D.
 *
 * @throws Error when there are zero or more than two inputs, or when the axes input is not 1-D.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  const hasValidCount = Boolean(inputs) && inputs.length >= 1 && inputs.length <= 2;
  if (!hasValidCount) {
    throw new Error('Reduce op requires 1 or 2 inputs.');
  }

  const hasAxesInput = inputs.length === 2;
  if (hasAxesInput && inputs[1].dims.length !== 1) {
    throw new Error('Invalid axes input dims.');
  }
};

/** Attributes shared by the Reduce* operators in this file. */
export interface ReduceAttributes extends AttributeWithCacheKey {
  /** Keep every reduced axis in the output as a dimension of size 1 instead of dropping it. */
  keepDims: boolean;
  /** With an empty `axes`, copy the input through instead of reducing over every axis. */
  noopWithEmptyAxes: boolean;
  /** Axes to reduce; negative values are normalized against the input rank before use. */
  axes: number[];
}

/**
 * Produces the WGSL fragments of one concrete reduction for the naive kernel.
 *
 * Tuple layout, in the order the fragments are spliced into the shader:
 *  - [0] statements emitted before the reduction loops (e.g. seeding indices for max/min);
 *  - [1] further statements emitted before the loops;
 *  - [2] loop body, run once per reduced input element with `input_indices` set;
 *  - [3] statements emitted after the loops;
 *  - [4...] optional; when present they replace the default
 *    `output.setByOffset('global_idx', 'value')` store.
 */
export type ReduceOp = (
  input: IndicesHelper,
  output: IndicesHelper,
  axes: readonly number[],
) => [string, string, string, string, ...string[]];

/** Reduction used when nothing is reduced: each output element is a copy of its input element. */
const noOp: ReduceOp = (input) => {
  const copyElement = `var value = ${input.getByIndices('input_indices')};`;
  return ['', '', copyElement, ''];
};
/**
 * Builds the ProgramInfo for the naive reduction kernel: one shader invocation per output
 * element, each walking every reduced input position with nested loops.
 *
 * @param name program name.
 * @param shaderCache shader cache hint and input dependencies.
 * @param inputs only `inputs[0]` is read; it supplies the input shape and element type.
 * @param reduceOp builds the WGSL fragments of the concrete reduction.
 * @param axesInput axes to reduce; normalized against the input rank via ShapeUtil.normalizeAxes.
 * @param outputDataType element type of the output tensor.
 * @param keepDims keep reduced axes as size-1 dimensions instead of dropping them.
 * @param noopWithEmptyAxes when true, an empty axis list reduces nothing instead of everything.
 * @returns program info dispatching ceil(outputSize / 64) workgroups.
 */
export const createReduceProgramInfo = (
  name: string,
  shaderCache: ProgramShaderCacheInfo,
  inputs: readonly TensorView[],
  reduceOp: ReduceOp,
  axesInput: number[],
  outputDataType: DataType,
  keepDims = false,
  noopWithEmptyAxes = false,
): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const inputRank = inputShape.length;
  const axes = ShapeUtil.normalizeAxes(axesInput, inputRank);
  const reduceOnAllAxes = !noopWithEmptyAxes && axes.length === 0;
  const isReduced = (axis: number): boolean => reduceOnAllAxes || axes.indexOf(axis) !== -1;

  // Kept axes pass through; reduced axes become 1 with keepDims and are dropped otherwise.
  const outputShape: number[] = [];
  for (let axis = 0; axis < inputRank; axis++) {
    if (!isReduced(axis)) {
      outputShape.push(inputShape[axis]);
    } else if (keepDims) {
      outputShape.push(1);
    }
  }
  const outputRank = outputShape.length;
  const outputSize = ShapeUtil.size(outputShape);

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const input = inputVariable('_A', inputs[0].dataType, inputRank);
    const output = outputVariable('output', outputDataType, outputRank);
    const ops = reduceOp(input, output, axes);
    const needsLastIndex = ops[2].includes('last_index');

    // Statements copying each kept output coordinate into the matching input coordinate.
    const idxCopy: string[] = [];
    // Each reduced axis wraps the loop nest built so far, so higher axes end up outermost.
    let reduceOps = ops[2];
    let outputAxis = 0;
    for (let k = 0; k < inputRank; k++) {
      if (isReduced(k)) {
        if (keepDims) {
          outputAxis++; // the size-1 output dimension still occupies an output index slot
        }
        reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputShape[k]}; j${k}++) {
                  ${needsLastIndex ? `let last_index = j${k};` : ''}
                  ${input.indicesSet('input_indices', k, `j${k}`)}
                  ${reduceOps}
                }`;
      } else {
        idxCopy.push(`${input.indicesSet('input_indices', k, output.indicesGet('output_indices', outputAxis))};`);
        outputAxis++;
      }
    }

    return `

        ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}

        ${shaderHelper.mainStart()}
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
          var input_indices: ${input.type.indices};
          let output_indices = ${output.offsetToIndices('global_idx')};

          ${idxCopy.join('\n')}
          ${ops[0]}       // init ops for reduce max/min
          ${ops[1]}
          ${reduceOps}
          ${ops[3]}
          ${ops.length === 4 ? output.setByOffset('global_idx', 'value') : ops.slice(4).join('\n')}
        }`;
  };

  return {
    name,
    shaderCache,
    getShaderSource,
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: outputDataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms: [
        { type: DataType.uint32, data: outputSize },
        ...createTensorShapeVariables(inputShape, outputShape),
      ],
    }),
  };
};

/**
 * Builds reduce attributes whose `axes` come from the second (int64) input tensor, while
 * `keepDims` and `noopWithEmptyAxes` are taken from the node attributes.
 *
 * @param inputs node inputs; `inputs[1]` is the 1-D axes tensor.
 * @param attributes node attributes supplying `keepDims` and `noopWithEmptyAxes`.
 * @returns fresh attributes (with cache key) carrying the runtime axes.
 */
export const createReduceAttributesFromInputs = (
  inputs: readonly TensorView[],
  attributes: ReduceAttributes,
): ReduceAttributes => {
  const axesTensor = inputs[1];
  // A zero-length axes tensor is never read and yields an empty axis list.
  const axes: number[] =
    axesTensor.dims[0] > 0 ? Array.from(axesTensor.getBigInt64Array(), (axis) => Number(axis)) : [];
  return createAttributeWithCacheKey({
    axes,
    keepDims: attributes.keepDims,
    noopWithEmptyAxes: attributes.noopWithEmptyAxes,
  });
};

/**
 * Compiles and dispatches one naive reduction over `inputs[0]`.
 *
 * When a second input is present, `axes` are read from it. When `noopWithEmptyAxes` is set and
 * no axes are given, the copy-only `noOp` replaces `reduceOp`.
 *
 * @param context compute context providing inputs and `compute`.
 * @param name program name.
 * @param attributes node attributes.
 * @param reduceOp WGSL fragment builder of the concrete reduction.
 */
const runReduceProgram = (
  context: ComputeContext,
  name: string,
  attributes: ReduceAttributes,
  reduceOp: ReduceOp,
): void => {
  const { inputs } = context;
  const effective: ReduceAttributes =
    inputs.length === 1 ? attributes : createReduceAttributesFromInputs(inputs, attributes);
  const { axes, keepDims, noopWithEmptyAxes, cacheKey } = effective;
  const selectedOp = noopWithEmptyAxes && axes.length === 0 ? noOp : reduceOp;

  const programInfo = createReduceProgramInfo(
    name,
    { hint: cacheKey, inputDependencies: ['rank'] },
    [inputs[0]],
    selectedOp,
    axes,
    inputs[0].dataType,
    keepDims,
    noopWithEmptyAxes,
  );
  // Only the data tensor (input 0) is bound to the program.
  context.compute(programInfo, { inputs: [0] });
};

/** ReduceLogSum with the naive kernel: sum the reduced elements, then take `log`. */
const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(0);`;
    const accumulate = `value += ${input.getByIndices('input_indices')};`;
    const finish = 'value = log(value);';
    return [init, '', accumulate, finish];
  };
  runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp);
};

/** ReduceL1 with the naive kernel: sum of absolute values of the reduced elements. */
const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(0);`;
    const accumulate = `value += abs(${input.getByIndices('input_indices')});`;
    return [init, '', accumulate, ''];
  };
  runReduceProgram(context, 'ReduceL1', attributes, reduceOp);
};

/** ReduceL2 with the naive kernel: square root of the sum of squares of the reduced elements. */
const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const valueType = output.type.value;
    const init = `var t = ${valueType}(0); var value = ${valueType}(0);`;
    const accumulate = `t = ${input.getByIndices('input_indices')}; value += (t * t);`;
    const finish = 'value = sqrt(value);';
    return [init, '', accumulate, finish];
  };
  runReduceProgram(context, 'ReduceL2', attributes, reduceOp);
};

/** ReduceLogSumExp with the naive kernel: `log` of the sum of `exp` of the reduced elements. */
const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(0);`;
    const accumulate = `value += exp(${input.getByIndices('input_indices')});`;
    const finish = 'value = log(value);';
    return [init, '', accumulate, finish];
  };
  runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp);
};

/**
 * ReduceMax with the naive kernel. `value` is seeded with the element at index 0 of every
 * reduced axis, then folded with `max` over the reduction loops.
 */
const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, _output, axes) => {
    const reducesEveryAxis = axes.length === 0;
    const idxZero = Array.from({ length: input.rank }, (_unused, k) => k)
      .filter((k) => reducesEveryAxis || axes.indexOf(k) >= 0)
      .map((k) => input.indicesSet('input_indices', k, 0));

    const seed = `var value = ${input.getByIndices('input_indices')};`;
    const fold = `value = max(value, ${input.getByIndices('input_indices')});`;
    return [idxZero.join('\n'), seed, fold, ''];
  };
  runReduceProgram(context, 'ReduceMax', attributes, reduceOp);
};

/**
 * ReduceMean with the naive kernel. It accumulates in f32 and divides by the number of reduced
 * elements, which is baked into the shader source as a literal.
 */
const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output, axes) => {
    const reducesEveryAxis = axes.length === 0;
    let size = 1.0;
    for (let axis = 0; axis < input.rank; axis++) {
      if (reducesEveryAxis || axes.indexOf(axis) >= 0) {
        // TODO: this depends on the input dims. If we want to use uniform, this need to be updated.
        size *= context.inputs[0].dims[axis];
      }
    }

    const accumulate = `sum += f32(${input.getByIndices('input_indices')});`;
    const finish = `let value = ${output.type.value}(sum / ${size});`;
    return ['var sum = f32(0);', '', accumulate, finish];
  };
  runReduceProgram(context, 'ReduceMean', attributes, reduceOp);
};

/**
 * ReduceMin with the naive kernel. `value` is seeded with the element at index 0 of every
 * reduced axis, then folded with `min` over the reduction loops.
 */
const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, _output, axes) => {
    const idxZero: string[] = [];
    for (let k = 0; k < input.rank; k++) {
      if (axes.indexOf(k) >= 0 || axes.length === 0) {
        // Emit the index write through the indices helper (exactly like reduceMaxNaive) rather
        // than hand-writing `input_indices[k] = 0;`, so the statement always matches the
        // helper's own representation of `input_indices`.
        idxZero.push(input.indicesSet('input_indices', k, 0));
      }
    }

    return [
      `${idxZero.join('\n')}`,
      `var value = ${input.getByIndices('input_indices')};`,
      `value = min(value, ${input.getByIndices('input_indices')});`,
      '',
    ];
  };
  runReduceProgram(context, 'ReduceMin', attributes, reduceOp);
};

/** ReduceProd with the naive kernel: product of the reduced elements, starting from 1. */
const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(1);`;
    const accumulate = `value *= ${input.getByIndices('input_indices')};`;
    return [init, '', accumulate, ''];
  };
  runReduceProgram(context, 'ReduceProd', attributes, reduceOp);
};

/** ReduceSum with the naive kernel: plain sum of the reduced elements. */
const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(0);`;
    const accumulate = `value += ${input.getByIndices('input_indices')};`;
    return [init, '', accumulate, ''];
  };
  runReduceProgram(context, 'ReduceSum', attributes, reduceOp);
};

/** ReduceSumSquare with the naive kernel: sum of squares of the reduced elements. */
const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const valueType = output.type.value;
    const init = `var t = ${valueType}(0); var value = ${valueType}(0);`;
    const accumulate = `t = ${input.getByIndices('input_indices')}; value += t * t;`;
    return [init, '', accumulate, ''];
  };
  runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp);
};

/**
 * Heuristic that chooses between the naive reduction kernel in this file and the
 * `reduce-shared` implementation.
 *
 * @param shape input tensor shape.
 * @param axes axes to reduce; negative values count from the end and are normalized here.
 * @param noopWithEmptyAxes node attribute; only consulted when `axes` is empty.
 * @returns true when the naive kernel should be used.
 */
const useNaiveReduceMethod = (
  shape: readonly number[],
  axes: readonly number[],
  noopWithEmptyAxes: boolean,
): boolean => {
  if (axes.length === 0) {
    // Empty axes: the naive path handles the no-op copy; a full reduction goes to reduce-shared.
    return noopWithEmptyAxes;
  }

  const rank = shape.length;
  // Normalize negative axes so they can be compared with the dimension indices below.
  const reducedAxes = new Set(axes.map((axis) => (axis < 0 ? axis + rank : axis)));

  let outputSize = 1;
  let reduceSize = 1;
  // Visit every input dimension. Bounding the loop by axes.length would skip trailing
  // dimensions and leave outputSize/reduceSize incomplete.
  for (let dim = 0; dim < rank; dim++) {
    if (reducedAxes.has(dim)) {
      reduceSize *= shape[dim];
    } else {
      outputSize *= shape[dim];
    }
  }

  // The condition data is very rough, although considering the count of Execution Unit (EU), the potential
  // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments
  // on some machines.
  return reduceSize < 32 && outputSize > 1024;
};

/** ReduceMean entry point: naive kernel when useNaiveReduceMethod selects it, else reduce-shared. */
export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => {
  if (!useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
    reduceMeanShared(context, attributes);
    return;
  }
  reduceMeanNaive(context, attributes);
};

/** ReduceL1 entry point: naive kernel when useNaiveReduceMethod selects it, else reduce-shared. */
export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => {
  if (!useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
    reduceL1Shared(context, attributes);
    return;
  }
  reduceL1Naive(context, attributes);
};

/** ReduceL2 entry point: naive kernel when useNaiveReduceMethod selects it, else reduce-shared. */
export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => {
  if (!useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
    reduceL2Shared(context, attributes);
    return;
  }
  reduceL2Naive(context, attributes);
};

/** ReduceLogSumExp entry point: naive kernel when useNaiveReduceMethod selects it, else reduce-shared. */
export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => {
  if (!useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
    reduceLogSumExpShared(context, attributes);
    return;
  }
  reduceLogSumExpNaive(context, attributes);
};

/** ReduceMax entry point: naive kernel when useNaiveReduceMethod selects it, else reduce-shared. */
export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => {
  if (!useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
    reduceMaxShared(context, attributes);
    return;
  }
  reduceMaxNaive(context, attributes);
};

/** ReduceMin entry point: naive kernel when useNaiveReduceMethod selects it, else reduce-shared. */
export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => {
  if (!useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
    reduceMinShared(context, attributes);
    return;
  }
  reduceMinNaive(context, attributes);
};

/** ReduceProd entry point: naive kernel when useNaiveReduceMethod selects it, else reduce-shared. */
export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => {
  if (!useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
    reduceProdShared(context, attributes);
    return;
  }
  reduceProdNaive(context, attributes);
};

/** ReduceSum entry point: naive kernel when useNaiveReduceMethod selects it, else reduce-shared. */
export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => {
  if (!useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
    reduceSumShared(context, attributes);
    return;
  }
  reduceSumNaive(context, attributes);
};

/** ReduceSumSquare entry point: naive kernel when useNaiveReduceMethod selects it, else reduce-shared. */
export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => {
  if (!useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
    reduceSumSquareShared(context, attributes);
    return;
  }
  reduceSumSquareNaive(context, attributes);
};

export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => {
407
if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) {
408
reduceLogSumNaive(context, attributes);
410
reduceLogSumShared(context, attributes);