1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
11
import { createPackedConcatProgramInfoLoader } from './concat-packed';
/**
 * Attributes of the Concat operator.
 */
export interface ConcatAttributes extends AttributeWithCacheKey {
  /**
   * Axis along which the inputs are joined. May be negative, in which case it counts back
   * from the last dimension.
   */
  readonly axis: number;
}
/**
 * Concatenates `inputs` along `attributes.axis` on the WebGL backend.
 *
 * When the session has packing enabled (`session.pack`) and the inputs have rank > 1, the packed
 * Concat program is used; otherwise the unpacked program built in this file is used.
 *
 * @param inferenceHandler handler that compiles and runs the program.
 * @param inputs tensors to concatenate; validated by `validateInputs` before any program is built.
 * @param attributes parsed Concat attributes (axis + cache key).
 * @returns a single-element array holding the concatenated tensor.
 * @throws Error when the inputs fail validation.
 */
export const concat: OperatorImplementation<ConcatAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ConcatAttributes,
): Tensor[] => {
  validateInputs(inputs);
  // Only the program choice differs between the two paths; running it is identical.
  const programInfoLoader =
    inferenceHandler.session.pack && inputs[0].dims.length > 1
      ? createPackedConcatProgramInfoLoader(inferenceHandler, inputs, attributes)
      : createUnpackedConcatProgramInfoLoader(inferenceHandler, inputs, attributes);
  const output = inferenceHandler.run(programInfoLoader, inputs);
  return [output];
};
/**
 * Builds the metadata of the unpacked Concat program: one unpacked input texture per tensor,
 * named `X0`, `X1`, ... in input order.
 *
 * @param inputCount number of tensors being concatenated.
 * @param cacheHint cache hint stored in the metadata (callers pass the attributes' cache key).
 */
const createUnpackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({
  name: 'Concat',
  inputNames: Array.from({ length: inputCount }, (_v, i) => `X${i}`),
  // Typed array construction so the field is TextureType[] rather than any[].
  inputTypes: new Array<TextureType>(inputCount).fill(TextureType.unpacked),
  cacheHint,
});
/**
 * Builds the unpacked WebGL program that concatenates `inputs` along `axis`.
 *
 * For each output element the generated shader finds which input texture holds the requested
 * coordinate along the concat axis, shifts that coordinate into the input's own index space,
 * and samples that input.
 *
 * @param _handler unused by this builder.
 * @param metadata program metadata; spread into the returned program info.
 * @param inputs tensors to concatenate. In this file `concat` runs `validateInputs` before the
 *   program is built, so all inputs share the same rank and type.
 * @param axis concat axis; negative values count back from the last dimension.
 * @returns program info whose output has the concatenated shape and the first input's type.
 * @throws Error if `axis` is outside `[-rank, rank)` or a non-concat dimension differs.
 */
const createUnpackedConcatProgramInfo = (
  _handler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  inputs: Tensor[],
  axis: number,
): ProgramInfo => {
  const inputShape = inputs[0].dims.slice();
  if (axis >= inputShape.length || axis < -1 * inputShape.length) {
    throw new Error("axis specified for concat doesn't match input dimensionality");
  }
  // Normalize a negative axis without reassigning the parameter.
  const concatAxis = axis < 0 ? inputShape.length + axis : axis;

  // Ensure all non-concatenated axes match each other and, while doing so, accumulate the
  // concat-axis extent of every input into the output shape.
  const outputShape = inputShape.slice();
  for (let i = 1; i < inputs.length; i++) {
    const dataNShape = inputs[i].dims;
    for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
      if (axisIndex === concatAxis) {
        outputShape[concatAxis] += dataNShape[axisIndex];
      } else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
        throw new Error('non concat dimensions must match');
      }
    }
  }
  const rank = outputShape.length;

  // sizeInConcatAxis[i] is the running sum of concat-axis extents of inputs 0..i, i.e. the
  // exclusive end offset of input i along the concat axis of the output.
  const sizeInConcatAxis = new Array<number>(inputs.length);
  let previousSum = 0;
  for (let i = 0; i < inputs.length; ++i) {
    previousSum += inputs[i].dims[concatAxis];
    sizeInConcatAxis[i] = previousSum;
  }

  // In most cases linear search is sufficient, as usually only 2 tensors are concatenated.
  const getTextureIndexWhereDataResidesMethod =
    inputs.length < 5
      ? getTextureIndexWhereDataResidesLinearSearch(sizeInConcatAxis)
      : getTextureIndexWhereDataResidesBinarySearch(sizeInConcatAxis);
  const fetchDataFromCorrectTextureMethod = getFetchDataFromCorrectTextureMethod(inputs.length, rank);
  const getSizeInConcatAxisValueFromIndexMethod = getGetSizeInConcatAxisValueFromIndexMethod(sizeInConcatAxis);

  // For any texture after the first, subtract the combined extent of the preceding inputs so the
  // concat-axis coordinate becomes relative to the selected input.
  const shaderSource = `
        ${fetchDataFromCorrectTextureMethod}
        ${getSizeInConcatAxisValueFromIndexMethod}
        ${getTextureIndexWhereDataResidesMethod}
        float process(int indices[${rank}]) {
          int textureIndex = getTextureWhereDataResides(indices[${concatAxis}]);

          if (textureIndex != 0) {
            indices[${concatAxis}] = indices[${concatAxis}] - getSizeInConcatAxisValueFromIndex(textureIndex - 1);
          }

          return fetchDataFromCorrectTexture(textureIndex, indices);
        }`;
  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
/**
 * Returns a loader for the unpacked Concat program. The metadata is computed immediately; the
 * full program info (including the shader source) is only built when `get()` is invoked.
 *
 * @param handler inference handler forwarded to the program-info builder.
 * @param inputs tensors to concatenate.
 * @param attributes Concat attributes; `cacheKey` feeds the metadata and `axis` the program.
 */
const createUnpackedConcatProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ConcatAttributes,
): ProgramInfoLoader => {
  const metadata = createUnpackedConcatProgramMetadata(inputs.length, attributes.cacheKey);
  return { ...metadata, get: () => createUnpackedConcatProgramInfo(handler, metadata, inputs, attributes.axis) };
};
/**
 * Generates the GLSL function `int getTextureWhereDataResides(int index)`, which maps a
 * coordinate along the concat axis to the index of the input that holds it.
 *
 * It tests the running-sum boundaries in order and returns the first input `i` with
 * `index < sizeInConcatAxis[i]`.
 *
 * @param sizeInConcatAxis running sums of the inputs' extents along the concat axis.
 * @returns GLSL source of the lookup function.
 */
const getTextureIndexWhereDataResidesLinearSearch = (sizeInConcatAxis: number[]): string => {
  const searchAxis = sizeInConcatAxis.map((size, i) => `if(index<${size}) {return ${i};}\n`);
  // Fallback so every control path of the non-void GLSL function returns a value. For an index
  // below the total concat-axis extent one of the checks above always matches first.
  return `int getTextureWhereDataResides(int index) {
      ${searchAxis.join('')}
      return ${sizeInConcatAxis.length - 1};
    }`;
};
/**
 * Generates the GLSL function `int getTextureWhereDataResides(int index)` as a binary search.
 *
 * The search is unrolled at code-generation time into a balanced `if`/`else` tree over the
 * running-sum boundaries, so the shader does about log2(n) comparisons and needs no loops. It
 * returns the first input `i` with `index < sizeInConcatAxis[i]`; the boundaries are running sums
 * and therefore non-decreasing. Every control path ends in a `return`.
 *
 * @param sizeInConcatAxis running sums of the inputs' extents along the concat axis.
 * @returns GLSL source of the lookup function.
 */
const getTextureIndexWhereDataResidesBinarySearch = (sizeInConcatAxis: number[]): string => {
  // Emits the search over candidate inputs [lo, hi]; the answer is known to lie in that range.
  const search = (lo: number, hi: number): string => {
    if (lo >= hi) {
      return `return ${lo};`;
    }
    const mid = (lo + hi) >> 1;
    return `if(index<${sizeInConcatAxis[mid]}) {${search(lo, mid)}} else {${search(mid + 1, hi)}}`;
  };
  return `int getTextureWhereDataResides(int index) {
      ${search(0, sizeInConcatAxis.length - 1)}
    }`;
};
/**
 * Generates the GLSL function `float fetchDataFromCorrectTexture(int textureIndex, int indices[rank])`,
 * which dispatches to the sampler function `_X<i>(indices)` of the selected input.
 *
 * With two or more inputs the last input is handled by an unconditional `else`, so every path
 * returns. A single input gets an unconditional `return`; a lone `if` would leave a path without
 * a return value.
 *
 * @param numberOfTensors number of concatenated inputs.
 * @param tensorRank rank of the index array passed to the function.
 * @returns GLSL source, one statement per line.
 */
const getFetchDataFromCorrectTextureMethod = (numberOfTensors: number, tensorRank: number): string => {
  const codeLines: string[] = [`float fetchDataFromCorrectTexture(int textureIndex, int indices[${tensorRank}]) {`];
  if (numberOfTensors === 1) {
    codeLines.push('\t' + 'return _X0(indices);');
  } else {
    for (let i = 0; i < numberOfTensors; ++i) {
      if (i === 0) {
        codeLines.push('\t' + `if (textureIndex == ${i}) { return _X${i}(indices); }`);
      } else if (i === numberOfTensors - 1) {
        codeLines.push('\t' + `else { return _X${i}(indices); }`);
      } else {
        codeLines.push('\t' + `else if (textureIndex == ${i}) { return _X${i}(indices); }`);
      }
    }
  }
  codeLines.push('\t' + '}');
  return codeLines.join('\n');
};
/**
 * Generates the GLSL function `int getSizeInConcatAxisValueFromIndex(int index)`, a lookup table
 * returning `sizeInConcatAxis[index]`.
 *
 * With two or more entries the last entry is handled by an unconditional `else`, so every path
 * returns. A single entry gets an unconditional `return`; a lone `if` would leave a path without
 * a return value.
 *
 * @param sizeInConcatAxis running sums of the inputs' extents along the concat axis.
 * @returns GLSL source, one statement per line.
 */
const getGetSizeInConcatAxisValueFromIndexMethod = (sizeInConcatAxis: number[]): string => {
  const codeLines: string[] = ['int getSizeInConcatAxisValueFromIndex(int index) {'];
  const lastIndex = sizeInConcatAxis.length - 1;
  if (sizeInConcatAxis.length === 1) {
    codeLines.push('\t' + `return ${sizeInConcatAxis[0]};`);
  } else {
    sizeInConcatAxis.forEach((size, i) => {
      if (i === 0) {
        codeLines.push('\t' + `if (index == ${i}) { return ${size}; }`);
      } else if (i === lastIndex) {
        codeLines.push('\t' + `else { return ${size}; }`);
      } else {
        codeLines.push('\t' + `else if (index == ${i}) { return ${size}; }`);
      }
    });
  }
  codeLines.push('\t' + '}');
  return codeLines.join('\n');
};
/**
 * Reads the integer `axis` attribute of a Concat graph node and wraps it with a cache key via
 * `createAttributeWithCacheKey`.
 *
 * @param node graph node of the Concat operator.
 * @returns the parsed Concat attributes.
 */
export const parseConcatAttributes: OperatorInitialization<ConcatAttributes> = (node: Graph.Node): ConcatAttributes =>
  createAttributeWithCacheKey({ axis: node.attributes.getInt('axis') });
171
const validateInputs = (inputs: Tensor[]): void => {
172
if (!inputs || inputs.length < 1) {
173
throw new Error('too few inputs');
176
const inputType = inputs[0].type;
177
const inputDimensionality = inputs[0].dims.length;
179
// TODO: Support string concat
180
if (inputType === 'string') {
181
throw new Error('string tensor is not supported yet');
184
for (const input of inputs) {
185
// make sure types of all inputs match
186
if (input.type !== inputType) {
187
throw new Error('input tensors should be one type');
190
// make sure the dimensionality of all inputs are the same
191
if (input.dims.length !== inputDimensionality) {
192
throw new Error('input tensors should have the same shape');