1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform, TensorInfo } from '../types';
11
createTensorShapeVariables,
19
/**
 * Attributes of the Split operator as consumed by the WebGPU kernel.
 */
export interface SplitAttributes extends AttributeWithCacheKey {
  /** Axis to split along; normalized against the input rank (ShapeUtil.normalizeAxis) before use. */
  readonly axis: number;
  /** Number of output tensors to produce. */
  readonly numOutputs: number;
  /** Extent of each output along `axis`, one entry per output. */
  readonly splitSizes: number[];
}
25
/**
 * Checks that the Split kernel received at least one input (the tensor to split).
 *
 * @param inputs kernel inputs.
 * @throws Error('too few inputs') when `inputs` is missing or empty.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }
};
31
/**
 * Builds Split attributes when the split sizes may come from the second kernel input.
 *
 * The second input is read as int64 values. When it is non-empty, its values become
 * `splitSizes` and its length becomes `numOutputs`. When it is empty, the attribute-provided
 * `numOutputs` and `splitSizes` are kept unchanged.
 *
 * @param inputs kernel inputs; `inputs[1]` must exist (the caller only invokes this when more
 *   than one input is present).
 * @param attributes attributes parsed from the node.
 * @returns a new attribute object with a freshly computed cache key.
 */
const createSplitAttributesFromInputs = (
  inputs: readonly TensorView[],
  attributes: SplitAttributes,
): SplitAttributes => {
  const splitInput = inputs[1];
  if (splitInput.dims[0] > 0) {
    const splitSizes = Array.from(splitInput.getBigInt64Array(), (v) => Number(v));
    return createAttributeWithCacheKey({ numOutputs: splitSizes.length, axis: attributes.axis, splitSizes });
  }
  // Empty `split` input: keep the attribute-provided sizes. Returning an empty array here
  // (while keeping attributes.numOutputs) would leave every output size undefined.
  return createAttributeWithCacheKey({
    numOutputs: attributes.numOutputs,
    axis: attributes.axis,
    splitSizes: attributes.splitSizes,
  });
};
44
/**
 * Emits the WGSL helper `calculateOutputIndex(index)`, used by the Split shader to pick the
 * output tensor (`output_number`) that receives an element, given its coordinate along the
 * split axis.
 *
 * The generated loop compares `index` against each of the `numberOfTensors` entries of the
 * `uniforms.size_in_split_axis` uniform (read through `getElementAt`). In
 * `createSplitProgramInfo` those entries are running sums of the split sizes, i.e. the
 * exclusive end offset of each output along the axis. If `index` is not below any entry, the
 * helper returns `numberOfTensors`.
 *
 * @param numberOfTensors number of outputs (the length of `size_in_split_axis`).
 * @returns WGSL source text for the helper function.
 */
const calculateOutputIndexImpl = (numberOfTensors: number): string => `
45
fn calculateOutputIndex(index: u32) -> u32 {
46
for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) {
47
if (index < ${getElementAt('uniforms.size_in_split_axis', 'i', numberOfTensors)}) {
51
return ${numberOfTensors}u;
53
/**
 * Emits the WGSL helper `writeBufferData(output_number, indices, global_idx)`, which stores
 * `input[global_idx]` into the output selected by `output_number` at position `indices`.
 *
 * One store statement is generated per output and the statements are chained into an
 * if / else-if / else dispatch on `output_number`. With a single output the store is emitted
 * unconditionally.
 *
 * @param outputs IndicesHelpers for every output tensor; `outputs[0].type.indices` is used as
 *   the type of the `indices` parameter.
 * @returns WGSL source text for the helper function.
 */
const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
54
const numberOfTensors = outputs.length;
55
const codeLines: string[] = [];
56
for (let i = 0; i < numberOfTensors; ++i) {
57
// Store snippet from this output's IndicesHelper; presumably writes input[global_idx] at
// `indices` in output i (confirm against IndicesHelper.setByIndices).
const returnSnippet = outputs[i].setByIndices('indices', 'input[global_idx]');
58
// A single output needs no dispatch: the store is unconditional.
if (numberOfTensors === 1) {
59
codeLines.push(returnSnippet);
61
// Opens the if/else chain keyed on output_number.
codeLines.push(`if (output_number == ${i}u) { ${returnSnippet} }`);
62
} else if (i === numberOfTensors - 1) {
63
// The last output is a bare `else`: any output_number not matched by an earlier branch
// (including the fall-through value numberOfTensors from calculateOutputIndex) lands here.
codeLines.push(`else { ${returnSnippet} }`);
65
// NOTE(review): unlike the opening `if`, this comparison uses an unsuffixed literal (no `u`);
// WGSL should convert the abstract int to u32, but confirm and consider matching the `u` form.
codeLines.push(`else if (output_number == ${i}) { ${returnSnippet} }`);
69
fn writeBufferData(output_number: u32, indices: ${outputs[0].type.indices}, global_idx: u32) {
70
${codeLines.join('\n')}
74
/**
 * Builds the WebGPU program that splits `inputs[0]` into `attributes.numOutputs` tensors
 * along `attributes.axis`.
 *
 * Output i has the input's shape except along the axis, where its extent is
 * `attributes.splitSizes[i]`. The generated shader is guarded by `uniforms.input_size` and
 * indexed by `global_idx`, so each invocation handles one input element. It:
 *   1. converts `global_idx` to input indices;
 *   2. finds the target output with `calculateOutputIndex` on the split-axis coordinate;
 *   3. for outputs after the first, rebases that coordinate by subtracting the previous
 *      output's end offset (`size_in_split_axis[output_number - 1]`);
 *   4. writes the element through `writeBufferData`.
 * Dispatch uses ceil(inputSize / 64) workgroups.
 *
 * @param inputs kernel inputs; only `inputs[0]` (the tensor being split) is read here.
 * @param attributes resolved Split attributes (axis, numOutputs, splitSizes, cacheKey).
 * @returns the ProgramInfo describing shader source, cache hint, outputs and uniforms.
 */
const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
75
const inputShape = inputs[0].dims;
76
const inputSize = ShapeUtil.size(inputShape);
77
const dataType = inputs[0].dataType;
78
// Resolve the axis against the input rank (presumably maps negative axes to positive ones).
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
79
const outputs = new Array<IndicesHelper>(attributes.numOutputs);
80
const input = inputVariable('input', dataType, inputShape.length);
81
const sizeInSplitAxis = new Array<number>(attributes.numOutputs);
82
const outputsTensorInfo: TensorInfo[] = [];
83
const outputShapes: number[][] = [];
85
// First uniform: total input element count (registered as `input_size` in the shader).
const programUniforms: ProgramUniform[] = [{ type: DataType.uint32, data: inputSize }];
86
for (let i = 0; i < attributes.numOutputs; i++) {
87
// sizeInSplitAxis[i] = running total of split sizes through output i, i.e. the exclusive end
// offset of output i along the axis. NOTE(review): assumes splitSizes has at least numOutputs
// entries; otherwise the sums become NaN.
previousSum += attributes.splitSizes[i];
88
sizeInSplitAxis[i] = previousSum;
89
// Each output keeps the input shape except along the split axis.
const outputShape = inputShape.slice();
90
outputShape[axis] = attributes.splitSizes[i];
91
outputShapes.push(outputShape);
92
outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length);
93
outputsTensorInfo.push({ dims: outputShapes[i], dataType: inputs[0].dataType });
96
// Remaining uniforms: the cumulative split offsets, then the shape variables produced by
// createTensorShapeVariables for the input and every output.
{ type: DataType.uint32, data: sizeInSplitAxis },
97
...createTensorShapeVariables(inputShape, ...outputShapes),
99
const getShaderSource = (shaderHelper: ShaderHelper) => `
101
.registerUniform('input_size', 'u32')
102
.registerUniform('size_in_split_axis', 'u32', sizeInSplitAxis.length)
103
.declareVariables(input, ...outputs)}
104
${calculateOutputIndexImpl(sizeInSplitAxis.length)}
105
${writeBufferDataImpl(outputs)}
107
${shaderHelper.mainStart()}
108
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.input_size')}
110
var indices = ${input.offsetToIndices('global_idx')};
111
var index = ${input.indicesGet('indices', axis)};
112
let output_number = calculateOutputIndex(index);
113
if (output_number != 0) {
114
index -= ${getElementAt('uniforms.size_in_split_axis', 'output_number - 1u', sizeInSplitAxis.length)};
115
${input.indicesSet('indices', axis, 'index')};
117
writeBufferData(output_number, indices, global_idx);
121
// Cache hint is the attribute cache key; inputDependencies ['rank'] presumably adds the input
// rank to the shader cache key (confirm against the program manager).
shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank'] },
124
outputs: outputsTensorInfo,
125
dispatchGroup: { x: Math.ceil(inputSize / 64 /* workgroup size */) },
131
/**
 * Entry point for the Split operator.
 *
 * Validates the inputs and resolves the split attributes. When more than one input is present,
 * the split sizes are taken from the second input. It then runs the program built by
 * `createSplitProgramInfo`. Only input 0 (the tensor being split) is bound to the program.
 *
 * @param context compute context of the current kernel invocation.
 * @param attributes attributes parsed by `parseSplitAttributes`.
 */
export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
  validateInputs(context.inputs);
  const updatedAttributes =
    context.inputs.length === 1 ? attributes : createSplitAttributesFromInputs(context.inputs, attributes);
  context.compute(createSplitProgramInfo(context.inputs, updatedAttributes), { inputs: [0] });
};
138
/**
 * Converts the raw attribute record passed from the native kernel into `SplitAttributes`.
 *
 * A negative `numOutputs` means "derive it from `splitSizes`". Otherwise it must equal the
 * number of split sizes.
 *
 * @param attributes raw record with `axis`, `numOutputs` and `splitSizes` entries.
 * @returns the typed attributes with a cache key attached.
 * @throws Error when `numOutputs` and the length of `splitSizes` differ.
 */
export const parseSplitAttributes = (attributes: Record<string, unknown>): SplitAttributes => {
  const axis = attributes.axis as number;
  const splitSizes = attributes.splitSizes as number[];
  const requestedOutputs = attributes.numOutputs as number;
  const numOutputs = requestedOutputs < 0 ? splitSizes.length : requestedOutputs;
  if (numOutputs !== splitSizes.length) {
    throw new Error('numOutputs and splitSizes length must be equal');
  }
  return createAttributeWithCacheKey({ axis, numOutputs, splitSizes });
};