1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';
9
/**
 * WebGL kernel for the `Sum` operator: element-wise addition of every tensor in `inputs`.
 *
 * All inputs must have identical dims and an identical float32/float64 type (enforced by
 * `validateInputs`). NOTE(review): ONNX Sum (opset 8+) also allows multidirectional
 * broadcasting; this kernel rejects differing shapes instead of broadcasting them.
 *
 * @param inferenceHandler - handler that builds and executes the generated WebGL program.
 * @param inputs - tensors to add together; at least one is required.
 * @returns a single-element array holding the summed tensor.
 * @throws Error when the inputs fail validation (empty, shape mismatch, or type mismatch).
 */
export const sum = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  validateInputs(inputs);

  const sumProgramMetadata = {
    name: 'Sum',
    // One sampler per input, named X0, X1, ... to match the names used in the generated shader.
    inputNames: inputs.map((_v, i) => `X${i}`),
    // Every input is read as an unpacked texture.
    inputTypes: inputs.map(() => TextureType.unpacked),
  };

  // `get` is a thunk, so the program info is only built when the handler asks for it.
  const output = inferenceHandler.run(
    { ...sumProgramMetadata, get: () => createSumProgramInfo(inferenceHandler, inputs, sumProgramMetadata) },
    inputs,
  );
  return [output];
};
/**
 * Builds the WebGL program that adds all inputs texel by texel.
 *
 * @param inferenceHandler - supplies the WebGL context version used to choose GLSL syntax.
 * @param inputs - already-validated inputs; the output copies the dims and type of `inputs[0]`.
 * @param sumProgramMetadata - name, sampler names and texture types shared with `sum`.
 * @returns program info whose shader writes `X0 + X1 + ... + X(n-1)` at every texture coordinate.
 */
const createSumProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  sumProgramMetadata: ProgramMetadata,
): ProgramInfo => {
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
  // Inputs are required to share one shape, so the output shape is a copy of the first input's dims.
  const outputShape = inputs[0].dims.slice();
  // Assumes identical shapes and unpacked layout mean every input can be sampled at the same TexCoords.
  const sumLine = inputs.map((_v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + ');
  const shaderSource = `
      void main() {
        vec4 result = ${sumLine};
        ${glsl.output} = result;
      }
    `;
  return {
    ...sumProgramMetadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    // The shader source above defines its own main().
    hasMain: true,
    shaderSource,
  };
};
/**
 * Validates the inputs of the `Sum` kernel.
 *
 * All shape checks run before any type check. If an input set has both a shape problem and a
 * type problem, the shape error is the one thrown.
 *
 * @param inputs - candidate input tensors; only `dims` and `type` are inspected.
 * @throws Error('Sum requires inputs.') when `inputs` is missing or empty.
 * @throws Error('Input shapes are mismatched.') when an input's rank differs from `inputs[0]`'s.
 * @throws Error('Input shapes are not matched.') when an input's dims differ from `inputs[0]`'s.
 * @throws Error('Invalid input type.') when `inputs[0]` is neither float32 nor float64.
 * @throws Error('Input types are not matched.') when an input's type differs from `inputs[0]`'s.
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length === 0) {
    throw new Error('Sum requires inputs.');
  }

  const first = inputs[0];
  const rank = first.dims.length;

  // Shape checks: same rank first, then the same size along every axis.
  for (let i = 1; i < inputs.length; i++) {
    if (inputs[i].dims.length !== rank) {
      throw new Error('Input shapes are mismatched.');
    }
    for (let j = 0; j < rank; j++) {
      if (first.dims[j] !== inputs[i].dims[j]) {
        throw new Error('Input shapes are not matched.');
      }
    }
  }

  // Type checks: only floating-point inputs are accepted, and every input must share the first's type.
  if (first.type !== 'float32' && first.type !== 'float64') {
    throw new Error('Invalid input type.');
  }
  for (let i = 1; i < inputs.length; i++) {
    if (inputs[i].type !== first.type) {
      throw new Error('Input types are not matched.');
    }
  }
};