1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil, SplitUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, TextureType } from '../types';
12
export interface SplitAttributes extends AttributeWithCacheKey {
13
readonly axis: number;
14
readonly split: number[];
15
readonly numOutputs: number;
18
const splitProgramMetadata = {
21
inputTypes: [TextureType.unpacked],
24
export const split: OperatorImplementation<SplitAttributes> = (
25
inferenceHandler: WebGLInferenceHandler,
27
attributes: SplitAttributes,
29
validateInputs(inputs);
31
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
32
const count = getProgramCount(inferenceHandler, inputs, axis, attributes);
33
const output: Tensor[] = [];
34
for (let i = 0; i < count; ++i) {
38
...splitProgramMetadata,
39
cacheHint: `${attributes.cacheKey};${i}`,
40
get: () => createSplitProgramInfo(inferenceHandler, inputs[0], attributes, axis, i),
50
export const parseSplitAttributes: OperatorInitialization<SplitAttributes> = (node: Graph.Node): SplitAttributes => {
51
const axis = node.attributes.getInt('axis', 0);
52
const split = node.attributes.getInts('split', []);
53
const numOutputs = node.outputs.length;
54
return createAttributeWithCacheKey({ axis, split, numOutputs });
57
const getProgramCount = (
58
_inferenceHandler: WebGLInferenceHandler,
61
attributes: SplitAttributes,
63
const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs);
64
return offsets.length;
67
const createSplitProgramInfo = (
68
_inferenceHandler: WebGLInferenceHandler,
70
attributes: SplitAttributes,
74
const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs);
75
const offset = offsets[index];
76
const outputShape = shapes[index];
77
const rank = outputShape.length;
78
const shaderSource = `
79
float process(int indices[${rank}]) {
80
indices[${axis}] += ${offset};
85
...splitProgramMetadata,
86
cacheHint: `${attributes.cacheKey}:${index}`,
87
output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
92
const validateInputs = (inputs: Tensor[]): void => {
93
if (!inputs || inputs.length !== 1) {
94
throw new Error('Split requires one input.');
98
inputs[0].type !== 'int8' &&
99
inputs[0].type !== 'uint8' &&
100
inputs[0].type !== 'int16' &&
101
inputs[0].type !== 'uint16' &&
102
inputs[0].type !== 'int32' &&
103
inputs[0].type !== 'uint32' &&
104
inputs[0].type !== 'float32' &&
105
inputs[0].type !== 'float64' &&
106
inputs[0].type !== 'bool'
108
throw new Error('Invalid input type.');