// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo } from '../types';
10
import { createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper } from './common';
12
/**
 * Attributes of the CumSum operator, as produced by `parseCumSumAttributes`.
 */
export interface CumSumAttributes extends AttributeWithCacheKey {
  /** When true, the element at the current position is left out of its own running sum. */
  readonly exclusive: boolean;
  /** When true, each running sum covers the current position through the end of the axis instead of the start. */
  readonly reverse: boolean;
}
/**
 * Builds the program that computes CumSum along a single axis.
 *
 * The generated WGSL shader runs one invocation per output element (the output has the same dims as the
 * input). Each invocation walks the input along `axis` and accumulates the elements whose axis coordinate
 * lies in a half-open range `[first, last)`. Here `k` is the element's own coordinate on that axis and `n`
 * is the axis length:
 *
 * | reverse | exclusive | range       |
 * |---------|-----------|-------------|
 * | false   | false     | [0, k + 1)  |
 * | false   | true      | [0, k)      |
 * | true    | false     | [k, n)      |
 * | true    | true      | [k + 1, n)  |
 *
 * Each invocation therefore loops up to `n` times.
 *
 * @param inputType - data type of the input tensor; the output uses the same data type.
 * @param inputShape - dims of the input tensor; the output uses the same dims.
 * @param axisInput - tensor whose first element is the scan axis. It is read as int32 when its data type is
 *   int32, otherwise as int64 (converted with `Number`).
 * @param attributes - `exclusive` / `reverse` flags; `attributes.cacheKey` is used as the shader cache hint.
 * @returns the program description passed to `ComputeContext.compute`.
 */
const createCumsumProgramInfo = (
  inputType: number,
  inputShape: readonly number[],
  axisInput: TensorView,
  attributes: CumSumAttributes,
): ProgramInfo => {
  const outputSize = ShapeUtil.size(inputShape); // outputShape is same as inputShape.
  const rank = inputShape.length; // input/output rank
  const input = inputVariable('input', inputType, rank);
  const output = outputVariable('output', inputType, rank);
  // The axis value is read here and sent to the shader as the `axis` uniform, so the shader text does not
  // depend on it.
  const axisValue =
    axisInput.dataType === DataType.int32 ? axisInput.getInt32Array()[0] : Number(axisInput.getBigInt64Array()[0]);
  // NOTE(review): presumably maps a negative axis into [0, rank) — confirm against ShapeUtil.normalizeAxis.
  const axis = ShapeUtil.normalizeAxis(axisValue, rank);
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    // Coordinate of the current element along the scan axis, as i32.
    const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `;
    // Length of the scan axis. It initializes `let last : i32` below, and WGSL has no implicit u32 -> i32
    // conversion, so convert explicitly (i32() is an identity if the shape element is already i32).
    const max = `i32(${getElementAt('uniforms.input_shape', 'uniforms.axis', rank)})`;
    const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0';
    const upperLimit = attributes.reverse ? max : index + (attributes.exclusive ? '' : ' + 1');
    return `
  ${shaderHelper
    .registerUniform('outputSize', 'u32')
    .registerUniform('axis', 'u32')
    .declareVariables(input, output)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
    var inputIndices = ${output.offsetToIndices('global_idx')};
    var sum = ${output.type.value}(0);
    let first : i32 = ${lowerLimit};
    let last : i32 = ${upperLimit};
    for (var i : i32 = first; i < last; i++) {
      ${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(i)')};
      sum = sum + ${input.getByIndices('inputIndices')};
    }
    ${output.setByOffset('global_idx', 'sum')};
  }`;
  };
  return {
    name: 'CumSum',
    shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank'] },
    getRunData: () => ({
      outputs: [{ dims: inputShape, dataType: inputType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms: [
        { type: DataType.uint32, data: outputSize },
        { type: DataType.uint32, data: axis },
        ...createTensorShapeVariables(inputShape, inputShape),
      ],
    }),
    getShaderSource,
  };
};
/**
 * CumSum kernel entry point.
 *
 * Takes the data tensor from `context.inputs[0]` and the axis tensor from `context.inputs[1]`, builds the
 * CumSum program and runs it. Only input 0 is listed in the compute call's `inputs`; the axis tensor is
 * consumed while building the program.
 *
 * @param context - compute context holding the operator inputs.
 * @param attributes - parsed `exclusive` / `reverse` attributes.
 */
export const cumsum = (context: ComputeContext, attributes: CumSumAttributes): void => {
  const inputShape = context.inputs[0].dims;
  const inputType = context.inputs[0].dataType;
  const axis = context.inputs[1];
  context.compute(createCumsumProgramInfo(inputType, inputShape, axis, attributes), { inputs: [0] });
};
/**
 * Converts raw operator attributes into `CumSumAttributes`.
 *
 * Each flag is true only when the raw value is exactly the number `1`. Any other value, including a
 * missing attribute, yields false.
 *
 * @param attributes - raw attribute map; `exclusive` and `reverse` are read from it.
 * @returns the boolean flags wrapped with a cache key.
 */
export const parseCumSumAttributes = (attributes: Record<string, unknown>): CumSumAttributes => {
  // Comparing `unknown` with `===` needs no type assertion and behaves identically.
  const exclusive = attributes.exclusive === 1;
  const reverse = attributes.reverse === 1;
  return createAttributeWithCacheKey({ exclusive, reverse });
};