1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo, ProgramShaderCacheInfo } from '../types';
9
import { inputVariable, outputVariable, ShaderHelper } from './common';
10
import { createReduceAttributesFromInputs, ReduceAttributes } from './reduce';
11
import { createTransposeProgramInfo } from './transpose';
13
const reduceOps: { [key: string]: string } = {
14
max: 'select(bestValue, candidate, candidate > bestValue)',
15
min: 'select(bestValue, candidate, candidate < bestValue)',
16
mean: 'bestValue + candidate',
17
sum: 'bestValue + candidate',
18
prod: 'bestValue * candidate',
19
sumSquare: 'bestValue + candidate * candidate',
20
logSumExp: 'bestValue + exp(candidate)',
21
l1: 'bestValue + abs(candidate)',
22
l2: 'bestValue + candidate * candidate',
23
logSum: 'bestValue + candidate',
26
const reduceSharedOps: { [key: string]: string } = {
27
max: 'select(bestValue, candidate, candidate > bestValue)',
28
min: 'select(bestValue, candidate, candidate < bestValue)',
29
mean: 'bestValue + candidate',
30
sum: 'bestValue + candidate',
31
prod: 'bestValue * candidate',
32
sumSquare: 'bestValue + candidate',
33
logSumExp: 'bestValue + candidate',
34
l1: 'bestValue + candidate',
35
l2: 'bestValue + candidate',
36
logSum: 'bestValue + candidate',
39
// WGSL initializer expression for the accumulator, keyed by reduce type; the
// shader uses it as `var bestValue = f32(${reduceInitValues[reduceType]});`.
// NOTE(review): all entry lines of this object literal (original lines 40-51)
// were lost in this extraction and cannot be reconstructed from this view
// alone (the max/min seeds in particular are not inferable) — restore them
// from the upstream source before building.
const reduceInitValues: { [key: string]: string } = {
52
const reduceOutputValues: { [key: string]: string } = {
57
sumSquare: 'bestValue',
58
logSumExp: 'log(bestValue)',
60
l2: 'sqrt(bestValue)',
61
logSum: 'log(bestValue)',
64
const getInnerMostAxes = (numInnerAxes: number, rank: number): number[] => {
66
for (let i = rank - numInnerAxes; i < rank; ++i) {
72
const computeOutAndReduceShapes = (shape: readonly number[], axes: readonly number[]): [number[], number[]] => {
73
const outputShape = [];
74
const rank = shape.length;
75
for (let dim = 0; dim < rank; dim++) {
76
if (axes.indexOf(dim) === -1) {
77
outputShape.push(shape[dim]);
80
const reduceShape = axes.map((dim) => shape[dim]);
81
return [outputShape, reduceShape];
84
const expandShapeToKeepDim = (shape: number[], axes: number[]): number[] => {
85
const rank = shape.length + axes.length;
86
const expandShape = [];
88
for (let dim = 0; dim < rank; dim++) {
89
if (axes.indexOf(dim) === -1) {
90
expandShape.push(shape[shapeIdx++]);
98
const areAxesInnerMostDims = (axes: number[], rank: number): boolean => {
99
for (let i = 0; i < axes.length; ++i) {
100
if (axes[axes.length - i - 1] !== rank - 1 - i) {
107
const getAxesPermutation = (axes: number[], rank: number): number[] => {
109
if (!areAxesInnerMostDims(axes, rank)) {
110
for (let i = 0; i < rank; ++i) {
111
if (axes.indexOf(i) === -1) {
115
axes.forEach((axis) => res.push(axis));
120
// Builds the ProgramInfo for a workgroup-shared-memory reduction: each
// workgroup computes exactly one output element (see the dispatch comment
// below — `dispatchGroup: { x: outputSize }`); its 32 threads strided-read
// `reduceSize` input elements, fold them with `reduceOps[reduceType]`, then
// tree-reduce the per-thread partials through the `aBestValues` workgroup
// array using `reduceSharedOps[reduceType]`. Thread 0 finalizes via
// `reduceOutputValues[reduceType]` (mean is special-cased as
// bestValue / reduceSize).
// NOTE(review): this extraction is garbled — stray original line numbers are
// interleaved below, and several lines are missing (the signature's `name`
// and `reduceType` parameters, the `): ProgramInfo => {` line, template
// literal terminators, `workgroupBarrier()` calls, closing braces, and the
// trailing `return { name, shaderCache, getShaderSource, getRunData }`
// scaffolding). Do not hand-edit further; restore this block from upstream.
export const createReduceSharedProgramInfo = (
122
shaderCache: ProgramShaderCacheInfo,
123
inputs: readonly TensorView[],
125
outputDataType: DataType,
126
outputShape: number[],
127
reduceShape: number[],
129
const inputShape = inputs[0].dims;
131
const outputSize = ShapeUtil.size(outputShape);
132
const reduceSize = ShapeUtil.size(reduceShape);
134
const input = inputVariable('_A', inputs[0].dataType, inputShape);
135
const output = outputVariable('output', outputDataType, outputShape);
137
// Fixed per-workgroup thread count; also the stride of the read loop and the
// length of the shared-memory partials array.
const workgroupSize = 32;
139
// NOTE(review): everything from here on is (or feeds) WGSL template-literal
// text; no further comments are inserted because the template terminators are
// missing and a comment could land inside the shader string.
const sharedMemorySnippet = `
140
var<workgroup> aBestValues : array<f32, ${workgroupSize}>;
143
const getShaderSource = (shaderHelper: ShaderHelper) => `
144
${shaderHelper.registerUniform('reduceSize', 'u32').declareVariables(input, output)}
145
${sharedMemorySnippet}
146
fn DIV_CEIL(a : u32, b : u32) -> u32 {
147
return ((a - 1u) / b + 1u);
149
${shaderHelper.mainStart(workgroupSize)}
151
let outputIndex = global_idx / ${workgroupSize};
152
let offset = outputIndex * uniforms.reduceSize;
154
var bestValue = f32(${reduceInitValues[reduceType]});
155
let Length = uniforms.reduceSize;
156
for (var k = local_idx; k < Length; k = k + ${workgroupSize}) {
157
let candidate = f32(${input.getByOffset('offset + k')});
158
bestValue = ${reduceOps[reduceType]};
160
aBestValues[local_idx] = bestValue;
163
var reduceSize = min(Length, ${workgroupSize}u);
164
for (var currentSize = reduceSize / 2u; reduceSize > 1u;
165
currentSize = reduceSize / 2u) {
166
let interval = DIV_CEIL(reduceSize, 2u);
167
if (local_idx < currentSize) {
168
let candidate = aBestValues[local_idx + interval];
169
bestValue = ${reduceSharedOps[reduceType]};
170
aBestValues[local_idx] = bestValue;
172
reduceSize = interval;
176
if (local_idx == 0u) {
177
${output.setByOffset(
180
reduceType === 'mean'
181
? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))`
182
: `${output.type.storage}(${reduceOutputValues[reduceType]})`
188
// One work group is responsible for only one element of output.
194
outputs: [{ dims: outputShape, dataType: outputDataType }],
195
dispatchGroup: { x: outputSize },
196
programUniforms: [{ type: DataType.uint32, data: reduceSize }],
201
// Shared driver for every Reduce*Shared operator: resolves attributes/axes,
// transposes the input (when needed) so the reduced axes become innermost,
// computes the output and reduce shapes, and launches the shared-memory
// reduce program built by createReduceSharedProgramInfo.
// NOTE(review): this extraction is garbled — stray original line numbers are
// interleaved below, and several lines are missing (the `name: string`
// parameter, the `): void => {` line, closing braces, the options object of
// the transpose `context.compute(...)` call, and the tail of the final
// `context.compute(createReduceSharedProgramInfo(...))` argument list).
// Restore this block from upstream before building.
const reduceCommon = (
202
context: ComputeContext,
204
attributes: ReduceAttributes,
205
reduceType: 'sum' | 'sumSquare' | 'prod' | 'min' | 'max' | 'mean' | 'logSumExp' | 'l1' | 'l2' | 'logSum',
207
// With a second input tensor present, axes come from that input rather than
// from the node attributes.
const updatedAttributes: ReduceAttributes =
208
context.inputs.length === 1 ? attributes : createReduceAttributesFromInputs(context.inputs, attributes);
210
let updatedAxes = updatedAttributes.axes;
211
// Empty axes + noopWithEmptyAxes=false means "reduce over all dimensions".
if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) {
212
updatedAxes = context.inputs[0].dims.map((_dim, i) => i);
214
// Map negative axes to their non-negative equivalents.
const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length);
216
let axes = normalizeAxes;
217
let input = context.inputs[0];
218
const permutedAxes = getAxesPermutation(axes, context.inputs[0].dims.length);
219
// A non-empty permutation means the reduce axes are not innermost yet; run a
// transpose first, then reduce over the (now innermost) trailing axes.
if (permutedAxes.length > 0) {
220
input = context.compute(createTransposeProgramInfo(context.inputs[0], permutedAxes), {
224
axes = getInnerMostAxes(axes.length, input.dims.length);
227
const [outputShape, reduceShape] = computeOutAndReduceShapes(input.dims, axes);
228
let finalOutputShape = outputShape;
229
// keepDims: reinsert size-1 dims at the reduced (original, normalized) axes.
if (updatedAttributes.keepDims) {
230
finalOutputShape = expandShapeToKeepDim(outputShape, normalizeAxes);
234
createReduceSharedProgramInfo(
236
{ hint: updatedAttributes.cacheKey, inputDependencies: ['type'] },
239
context.inputs[0].dataType,
247
export const reduceMeanShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
248
reduceCommon(context, 'ReduceMeanShared', attributes, 'mean');
251
export const reduceL1Shared = (context: ComputeContext, attributes: ReduceAttributes): void => {
252
reduceCommon(context, 'ReduceL1Shared', attributes, 'l1');
255
export const reduceL2Shared = (context: ComputeContext, attributes: ReduceAttributes): void => {
256
reduceCommon(context, 'ReduceL2Shared', attributes, 'l2');
259
export const reduceLogSumExpShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
260
reduceCommon(context, 'ReduceLogSumExpShared', attributes, 'logSumExp');
263
export const reduceMaxShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
264
reduceCommon(context, 'ReduceMaxShared', attributes, 'max');
267
export const reduceMinShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
268
reduceCommon(context, 'ReduceMinShared', attributes, 'min');
271
export const reduceProdShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
272
reduceCommon(context, 'ReduceProdShared', attributes, 'prod');
275
export const reduceSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
276
reduceCommon(context, 'ReduceSumShared', attributes, 'sum');
279
export const reduceSumSquareShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
280
reduceCommon(context, 'ReduceSumSquareShared', attributes, 'sumSquare');
283
export const reduceLogSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
284
reduceCommon(context, 'ReduceLogSumShared', attributes, 'logSum');