4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';
12
export interface ReduceAttributes extends AttributeWithCacheKey {
13
readonly axes: number[];
14
readonly keepDims: boolean;
18
/**
 * Callback that supplies the GLSL snippets which specialise the shared reduce shader
 * (built by `createReduceProgramInfo`) for one particular Reduce* operator.
 *
 * @param inputs - the operator inputs; `reduce` runs `validateInputs` first, so this holds exactly one tensor.
 * @param axes - the reduction axes after `ShapeUtil.normalizeAxes`; an empty array is treated by the
 *   shader builder as "reduce over every dimension".
 * @returns three GLSL snippets, in this order:
 *   - `[0]` initialisation of the float accumulator `value` (e.g. `value = 0.0;`, or seeding it from the
 *     first element for ReduceMax/ReduceMin);
 *   - `[1]` the accumulation statement applied to the input element addressed by `inputIdx`
 *     (e.g. `value += _A(inputIdx);`);
 *   - `[2]` a final computation on `value` (e.g. the division in ReduceMean, `log` in ReduceLogSum),
 *     or an empty string when none is needed.
 * The snippets may reference the shader locals `value` (float) and `inputIdx` (int array) declared in the
 * generated `process` function. NOTE(review): `_A(...)` presumably samples input A at the given index —
 * confirm against the shader helper that defines it.
 */
type ReduceOp = (inputs: Tensor[], axes: number[]) => string[];
21
inferenceHandler: WebGLInferenceHandler,
23
attributes: ReduceAttributes,
27
validateInputs(inputs);
29
const reduceProgramMetadata = {
32
inputTypes: [TextureType.unpacked],
35
const output = inferenceHandler.run(
37
...reduceProgramMetadata,
38
cacheHint: attributes.cacheKey,
39
get: () => createReduceProgramInfo(inferenceHandler, inputs, attributes, name, reduceOp, reduceProgramMetadata),
46
export const parseReduceAttributes: OperatorInitialization<ReduceAttributes> = (node: Graph.Node): ReduceAttributes => {
47
const axes = node.attributes.getInts('axes', []);
48
const keepDims = node.attributes.getInt('keepdims', 1) === 1;
49
return createAttributeWithCacheKey({ axes, keepDims });
52
const createReduceProgramInfo = (
53
_handler: WebGLInferenceHandler,
55
attributes: ReduceAttributes,
58
reduceProgramMetadata: ProgramMetadata,
60
const outputShape: number[] = [];
61
const iRank = inputs[0].dims.length || 1;
65
const axes = ShapeUtil.normalizeAxes(attributes.axes, inputs[0].dims.length);
66
const ops = reduceOp(inputs, axes);
67
let reduceOps = ops[1];
69
for (let k = 0; k < inputs[0].dims.length; k++) {
71
if (axes.indexOf(k) >= 0 || axes.length === 0) {
72
if (attributes.keepDims) {
78
for(int j${k} = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) {
79
inputIdx[${k}] = j${k};
83
idxCopy.push(`inputIdx[${k}] = outputIdx[${outputShape.length}];`);
85
outputShape.push(inputs[0].dims[k]);
89
const oRank = outputShape.length || 1;
91
const shaderSource = `
92
float process(int outputIdx[${oRank}]) {
93
float value; // final result
94
int inputIdx[${iRank}]; // addressing input data
96
${ops[0]} // init ops for reduce max/min
98
${ops[2]} // final computation for reduce mean
103
...reduceProgramMetadata,
104
output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
109
const validateInputs = (inputs: Tensor[]): void => {
111
if (!inputs || inputs.length !== 1) {
112
throw new Error('Reduce op requires 1 input.');
115
if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
116
throw new Error('Invalid input type.');
120
export const reduceSum: OperatorImplementation<ReduceAttributes> = (
121
inferenceHandler: WebGLInferenceHandler,
123
attributes: ReduceAttributes,
125
const reduceOp: ReduceOp = (): string[] => ['value = 0.0;', 'value += _A(inputIdx);', ''];
126
return reduce(inferenceHandler, inputs, attributes, 'ReduceSum', reduceOp);
129
export const reduceMean: OperatorImplementation<ReduceAttributes> = (
130
inferenceHandler: WebGLInferenceHandler,
132
attributes: ReduceAttributes,
134
const reduceOp: ReduceOp = (inputs: Tensor[], axes: number[]): string[] => {
136
for (let k = 0; k < inputs[0].dims.length; k++) {
137
if (axes.indexOf(k) >= 0 || axes.length === 0) {
138
size *= inputs[0].dims[k];
142
return ['value = 0.0;', 'value += _A(inputIdx);', `value /= ${size}.;`];
144
return reduce(inferenceHandler, inputs, attributes, 'ReduceMean', reduceOp);
147
export const reduceMax: OperatorImplementation<ReduceAttributes> = (
148
inferenceHandler: WebGLInferenceHandler,
150
attributes: ReduceAttributes,
152
const reduceOp: ReduceOp = (inputs: Tensor[], axes: number[]): string[] => {
154
for (let k = 0; k < inputs[0].dims.length; k++) {
155
if (axes.indexOf(k) >= 0 || axes.length === 0) {
156
idxZero.push(`inputIdx[${k}] = 0;`);
160
return [`${idxZero.join('\n')}\nvalue = _A(inputIdx);`, 'value = max(value, _A(inputIdx));', ''];
162
return reduce(inferenceHandler, inputs, attributes, 'ReduceMax', reduceOp);
165
export const reduceMin: OperatorImplementation<ReduceAttributes> = (
166
inferenceHandler: WebGLInferenceHandler,
168
attributes: ReduceAttributes,
170
const reduceOp: ReduceOp = (inputs: Tensor[], axes: number[]): string[] => {
172
for (let k = 0; k < inputs[0].dims.length; k++) {
173
if (axes.indexOf(k) >= 0 || axes.length === 0) {
174
idxZero.push(`inputIdx[${k}] = 0;`);
178
return [`${idxZero.join('\n')}\nvalue = _A(inputIdx);`, 'value = min(value, _A(inputIdx));', ''];
180
return reduce(inferenceHandler, inputs, attributes, 'ReduceMin', reduceOp);
183
export const reduceProd: OperatorImplementation<ReduceAttributes> = (
184
inferenceHandler: WebGLInferenceHandler,
186
attributes: ReduceAttributes,
188
const reduceOp: ReduceOp = (): string[] => ['value = 1.0;', 'value *= _A(inputIdx);', ''];
189
return reduce(inferenceHandler, inputs, attributes, 'ReduceProd', reduceOp);
192
export const reduceLogSum: OperatorImplementation<ReduceAttributes> = (
193
inferenceHandler: WebGLInferenceHandler,
195
attributes: ReduceAttributes,
197
const reduceOp: ReduceOp = (): string[] => ['value = 0.0;', 'value += _A(inputIdx);', 'value = log(value);'];
198
return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSum', reduceOp);
201
export const reduceLogSumSquare: OperatorImplementation<ReduceAttributes> = (
202
inferenceHandler: WebGLInferenceHandler,
204
attributes: ReduceAttributes,
206
const reduceOp: ReduceOp = (): string[] => ['float t; value = 0.0;', 't = _A(inputIdx); value += t * t;', ''];
207
return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSumSquare', reduceOp);