1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
/**
 * Attributes of the Gather operator.
 */
interface GatherAttributes extends AttributeWithCacheKey {
  /**
   * Axis of the data tensor to gather along. `validateInputs` accepts values in `[-rank, rank - 1]`;
   * negative values are normalized against the data rank when the program is built.
   */
  readonly axis: number;
}
/**
 * WebGL implementation of the Gather operator.
 *
 * Validates the inputs against `attributes.axis`, then runs the Gather shader program
 * and returns its single output tensor.
 *
 * @param inferenceHandler - WebGL inference handler used to build and run the program.
 * @param inputs - `[data, indices]`.
 * @param attributes - parsed Gather attributes.
 * @returns a one-element array holding the gathered tensor.
 * @throws Error if the inputs fail validation.
 */
export const gather: OperatorImplementation<GatherAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: GatherAttributes,
): Tensor[] => {
  validateInputs(inputs, attributes.axis);
  const output = inferenceHandler.run(createGatherProgramInfoLoader(inferenceHandler, inputs, attributes), inputs);
  return [output];
};
/**
 * Reads the Gather attributes from a graph node.
 *
 * @param node - graph node of the Gather operator.
 * @returns attributes with `axis` taken from the node's `axis` int attribute (0 when absent).
 */
export const parseGatherAttributes: OperatorInitialization<GatherAttributes> = (node: Graph.Node): GatherAttributes =>
  createAttributeWithCacheKey({ axis: node.attributes.getInt('axis', 0) });
/**
 * Static metadata shared by every Gather program. Input `A` is the data tensor (`inputs[0]`) and
 * input `B` is the index tensor (`inputs[1]`); both are read as unpacked textures.
 */
const gatherProgramMetadata = {
  name: 'Gather',
  inputNames: ['A', 'B'],
  inputTypes: [TextureType.unpacked, TextureType.unpacked],
};
/**
 * Builds the WebGL program that gathers slices of the data tensor along `axis`.
 *
 * The output shape is `data.dims[0..axis) ++ indices.dims ++ data.dims(axis..end)`.
 * For each output element, the shader first splits the output index into three parts (A | B | C below).
 * It copies A and C into the data index and uses B to read an index value from the index tensor.
 * A negative index value is wrapped by adding `data.dims[axis]`, and the data tensor is then
 * sampled at the resulting position.
 *
 * @param _handler - unused; kept so the loader can pass its handler through.
 * @param metadata - program metadata spread into the returned program info.
 * @param inputs - `[data, indices]`.
 * @param axis - gather axis; may be negative and is normalized against the rank of `data`.
 * @returns program info with the output layout and GLSL source.
 */
const createGatherProgramInfo = (
  _handler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  inputs: Tensor[],
  axis: number,
): ProgramInfo => {
  const inputShape = inputs[0].dims.slice();
  const indexDataShape = inputs[1].dims.slice();
  const outputShape = new Array<number>(inputShape.length + indexDataShape.length - 1);

  // Normalize into a local instead of reassigning the parameter.
  const gatherAxis = ShapeUtil.normalizeAxis(axis, inputShape.length);
  const indexEnd = gatherAxis + indexDataShape.length;
  const indexCopyOps: string[] = [];
  for (let i = 0; i < outputShape.length; i++) {
    // outputShape is divided into three parts: A, B, C
    // |0        axis|  axis + indexDataShape.length |          end|
    // |   A         |             B                 |      C      |
    //
    // inputIdx: [A, inputs[1][B], C]
    if (i < gatherAxis) {
      // A: leading data dims, copied through unchanged.
      outputShape[i] = inputShape[i];
      indexCopyOps.push(`inputIdx[${i}] = outputIdx[${i}];`);
    } else if (i < indexEnd) {
      // B: index-tensor dims; they select which index value to gather.
      outputShape[i] = indexDataShape[i - gatherAxis];
      indexCopyOps.push(`indexDataIdx[${i - gatherAxis}] = outputIdx[${i}];`);
    } else {
      // C: trailing data dims; skip 1 because the gathered axis itself is replaced by B.
      const inputDim = i - indexDataShape.length + 1;
      outputShape[i] = inputShape[inputDim];
      indexCopyOps.push(`inputIdx[${inputDim}] = outputIdx[${i}];`);
    }
  }

  // `|| 1` keeps the GLSL index arrays non-empty for rank-0 shapes (GLSL forbids zero-length arrays).
  const orank = outputShape.length || 1;
  const irank = inputShape.length;
  const iDrank = indexDataShape.length || 1;
  // indexDataIdx[0] is initialized because a rank-0 index tensor produces no B copy op,
  // which would otherwise leave the array element undefined when _B reads it.
  const shaderSource = `
      float process(int outputIdx[${orank}]) {
        int inputIdx[${irank}];
        int indexDataIdx[${iDrank}];
        indexDataIdx[0] = 0;
        ${indexCopyOps.join('\n        ')}
        int idx = int(_B(indexDataIdx));
        inputIdx[${gatherAxis}] = idx < 0 ? idx + ${inputShape[gatherAxis]} : idx;
        return _A(inputIdx);
      }`;
  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
/**
 * Wraps `createGatherProgramInfo` in a program-info loader whose metadata carries the
 * attributes' cache key as its cache hint.
 *
 * @param handler - WebGL inference handler forwarded to the program builder.
 * @param inputs - `[data, indices]`, captured for when the program is built.
 * @param attributes - Gather attributes; `cacheKey` becomes the cache hint and `axis` is forwarded.
 * @returns a loader whose `get()` builds the program info when called.
 */
const createGatherProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: GatherAttributes,
): ProgramInfoLoader => {
  const metadata = { ...gatherProgramMetadata, cacheHint: attributes.cacheKey };
  return { ...metadata, get: () => createGatherProgramInfo(handler, metadata, inputs, attributes.axis) };
};
99
const validateInputs = (inputs: Tensor[], axis: number): void => {
100
if (!inputs || inputs.length !== 2) {
101
throw new Error('Gather requires 2 inputs.');
103
const tensorRank = inputs[0].dims.length;
104
if (tensorRank < 1) {
105
throw new Error('Invalid input shape.');
107
if (axis < -tensorRank || axis > tensorRank - 1) {
108
throw new Error('Invalid axis.');
110
if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
111
throw new Error('Invaid input type.');
113
if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') {
114
throw new Error('Invaid input type.');