1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { Tensor } from '../../../tensor';
5
import { BroadcastUtil, ShapeUtil } from '../../../util';
6
import { FunctionType, GlslValueFunction } from '../glsl-definitions';
7
import { getGlsl } from '../glsl-source';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
11
export function glslAdd(): GlslValueFunction {
14
float ${name}(float a, float b) {
17
vec4 ${name}(vec4 v1, vec4 v2) {
21
return { body, name, type: FunctionType.ValueBased };
23
export function glslDiv(): GlslValueFunction {
26
float ${name}(float a, float b) {
29
vec4 ${name}(vec4 v1, vec4 v2) {
33
return { body, name, type: FunctionType.ValueBased };
35
export function glslMul(): GlslValueFunction {
38
float ${name}(float a, float b) {
41
vec4 ${name}(vec4 v1, vec4 v2) {
45
return { body, name, type: FunctionType.ValueBased };
47
export function glslSub(): GlslValueFunction {
50
float ${name}(float a, float b) {
53
vec4 ${name}(vec4 v1, vec4 v2) {
57
return { body, name, type: FunctionType.ValueBased };
59
export function glslEqual(): GlslValueFunction {
60
const name = 'equal_';
62
float ${name}(float a, float b) {
65
vec4 ${name}(vec4 v1, vec4 v2) {
66
return vec4(equal(v1, v2));
69
return { body, name, type: FunctionType.ValueBased };
71
export function glslGreater(): GlslValueFunction {
72
const name = 'greater_';
74
float ${name}(float a, float b) {
77
vec4 ${name}(vec4 v1, vec4 v2) {
78
return vec4( v1.r > v2.r ,
84
return { body, name, type: FunctionType.ValueBased };
86
export function glslLess(): GlslValueFunction {
89
float ${name}(float a, float b) {
92
vec4 ${name}(vec4 v1, vec4 v2) {
93
return vec4( v1.r < v2.r ,
99
return { body, name, type: FunctionType.ValueBased };
101
export function glslAnd(): GlslValueFunction {
104
float ${name}(float a, float b) {
105
return float( bool(a) && bool(b) );
107
vec4 ${name}(vec4 v1, vec4 v2) {
108
bvec4 b1 = bvec4(v1);
109
bvec4 b2 = bvec4(v2);
110
return vec4( b1.r && b2.r ,
116
return { body, name, type: FunctionType.ValueBased };
118
export function glslOr(): GlslValueFunction {
121
float ${name}(float a, float b) {
122
return float( bool(a) || bool(b) );
124
vec4 ${name}(vec4 v1, vec4 v2) {
125
bvec4 b1 = bvec4(v1);
126
bvec4 b2 = bvec4(v2);
127
return vec4( b1.r || b2.r ,
133
return { body, name, type: FunctionType.ValueBased };
135
export function glslXor(): GlslValueFunction {
138
float ${name}(float a, float b) {
139
return float( bool(a) ^^ bool(b) );
141
vec4 ${name}(vec4 v1, vec4 v2) {
142
bvec4 b1 = bvec4(v1);
143
bvec4 b2 = bvec4(v2);
144
return vec4( b1.r ^^ b2.r ,
150
return { body, name, type: FunctionType.ValueBased };
152
export function glslPow(): GlslValueFunction {
153
return glslBuiltinBinary('pow');
155
export function glslPRelu(): GlslValueFunction {
156
const name = 'prelu_';
158
float ${name}(float a, float b) {
159
return a < 0.0 ? a * b: a;
161
vec4 ${name}(vec4 v1, vec4 v2) {
163
v1.r < 0.0 ? v1.r * v2.r: v1.r,
164
v1.g < 0.0 ? v1.g * v2.g: v1.g,
165
v1.b < 0.0 ? v1.b * v2.b: v1.b,
166
v1.a < 0.0 ? v1.a * v2.a: v1.a
170
return { body, name, type: FunctionType.ValueBased };
173
function glslBuiltinBinary(fname: string): GlslValueFunction {
174
const name = `${fname}_`;
176
float ${name}(float a, float b) {
177
return ${fname}(a, b);
179
vec4 ${name}(vec4 v1, vec4 v2) {
180
return ${fname}(v1, v2);
183
return { body, name, type: FunctionType.ValueBased };
186
const createBinaryProgramInfoLoader = (
187
handler: WebGLInferenceHandler,
189
glslFunc: GlslValueFunction,
190
outputTensorType: Tensor.DataType = inputs[0].type,
192
): ProgramInfoLoader => {
193
const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked;
196
inputNames: ['A', 'B'],
197
inputTypes: [textureType, textureType],
199
get: () => createBinaryProgramInfo(handler, inputs, glslFunc, outputTensorType),
203
const createBinaryProgramInfo = (
204
handler: WebGLInferenceHandler,
206
glslFunc: GlslValueFunction,
207
outputTensorType: Tensor.DataType = inputs[0].type,
209
const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked;
210
const isBroadcast = !ShapeUtil.areEqual(inputs[0].dims, inputs[1].dims);
211
let outputShape = inputs[0].dims;
213
const usePackedTexture = handler.session.pack;
216
const calculatedShape = BroadcastUtil.calcShape(inputs[0].dims, inputs[1].dims, false);
217
if (!calculatedShape) {
218
throw new Error("Can't perform binary op on the given tensors");
220
outputShape = calculatedShape;
221
const outputRank = outputShape.length;
222
const aRank = inputs[0].dims.length !== 0 ? inputs[0].dims.length : 1;
223
const bRank = inputs[1].dims.length !== 0 ? inputs[1].dims.length : 1;
224
const aBcast = inputs[0].dims.length !== 0 ? 'bcastIndices_A(indices, aindices);' : 'aindices[0] = 0;';
225
const bBcast = inputs[1].dims.length !== 0 ? 'bcastIndices_B(indices, bindices);' : 'bindices[0] = 0;';
227
const glsl = getGlsl(handler.session.backend.glContext.version);
228
const shaderSource = usePackedTexture
232
vec4 a = getAAtOutCoords();
233
vec4 b = getBAtOutCoords();
234
vec4 result = ${glslFunc.name}(a, b);
235
${glsl.output} = result;
239
float process(int indices[${outputRank}]) {
240
int aindices[${aRank}];
241
int bindices[${bRank}];
244
return ${glslFunc.name}(_A(aindices), _B(bindices));
249
inputNames: ['A', 'B'],
250
inputTypes: [textureType, textureType],
251
output: { dims: outputShape, type: outputTensorType, textureType },
253
hasMain: usePackedTexture,
256
const glsl = getGlsl(handler.session.backend.glContext.version);
257
const shaderSource = `
260
vec4 v1 = ${glsl.texture2D}(A, TexCoords);
261
vec4 v2 = ${glsl.texture2D}(B, TexCoords);
262
vec4 result = ${glslFunc.name}(v1, v2);
263
${glsl.output} = result;
269
inputNames: ['A', 'B'],
270
inputTypes: [textureType, textureType],
271
output: { dims: inputs[0].dims, type: outputTensorType, textureType },
277
export const add = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
278
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslAdd()), inputs),
281
export const and = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
282
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslAnd(), 'bool'), inputs),
285
export const div = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
286
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslDiv()), inputs),
289
export const equal = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
290
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslEqual(), 'bool'), inputs),
293
export const greater = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
294
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslGreater(), 'bool'), inputs),
297
export const less = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
298
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslLess(), 'bool'), inputs),
301
export const mul = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
302
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslMul()), inputs),
305
export const or = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
306
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslOr(), 'bool'), inputs),
309
export const pow = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
310
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslPow()), inputs),
313
export const pRelu = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
314
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslPRelu()), inputs),
317
export const sub = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
318
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslSub()), inputs),
321
export const xor = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
322
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslXor(), 'bool'), inputs),