1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { Tensor } from '../../../tensor';
7
import { MAX_CLIP, MIN_CLIP } from '../../../util';
8
import { FunctionType, GlslValueFunction } from '../glsl-definitions';
9
import { getGlsl } from '../glsl-source';
10
import { WebGLInferenceHandler } from '../inference-handler';
11
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
/** GLSL value function forwarding to the built-in `abs`. */
export function glslAbs(): GlslValueFunction {
  return glslBuiltinUnary('abs');
}

/** GLSL value function forwarding to the built-in `acos`. */
export function glslAcos(): GlslValueFunction {
  return glslBuiltinUnary('acos');
}

/** GLSL value function forwarding to the built-in `asin`. */
export function glslAsin(): GlslValueFunction {
  return glslBuiltinUnary('asin');
}

/** GLSL value function forwarding to the built-in `atan`. */
export function glslAtan(): GlslValueFunction {
  return glslBuiltinUnary('atan');
}

/** GLSL value function forwarding to the built-in `ceil`. */
export function glslCeil(): GlslValueFunction {
  return glslBuiltinUnary('ceil');
}

/** GLSL value function forwarding to the built-in `cos`. */
export function glslCos(): GlslValueFunction {
  return glslBuiltinUnary('cos');
}

/**
 * GLSL value function for ELU: `a` when `a >= 0`, otherwise `(exp(a) - 1) * alpha`.
 *
 * @param alpha scale for the negative branch; inlined into the shader as a `const float`.
 */
export function glslElu(alpha: number): GlslValueFunction {
  const name = 'elu';
  const body = `
  const float alpha = float(${alpha});

  float ${name}_(float a) {
    return a >= 0.0 ? a: (exp(a) - 1.0) * alpha;
  }
  vec4 ${name}_(vec4 v) {
    // The scalar overload has a data-dependent branch, so apply it per component.
    return vec4(${name}_(v.x), ${name}_(v.y), ${name}_(v.z), ${name}_(v.w));
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/** GLSL value function forwarding to the built-in `exp`. */
export function glslExp(): GlslValueFunction {
  return glslBuiltinUnary('exp');
}

/** GLSL value function forwarding to the built-in `floor`. */
export function glslFloor(): GlslValueFunction {
  return glslBuiltinUnary('floor');
}

/**
 * GLSL value function clamping every value into `[min, max]`.
 *
 * The bounds are emitted as `${name}_min` / `${name}_max` instead of `min` / `max`. A global GLSL
 * variable named `min` or `max` would hide the built-in `min()` / `max()` functions for any code
 * that follows it in the same shader.
 *
 * @param min lower bound, inlined into the shader as a `const float`.
 * @param max upper bound, inlined into the shader as a `const float`.
 */
export function glslClip(min: number, max: number): GlslValueFunction {
  const name = 'clip';
  const body = `
  const float ${name}_min = float(${min});
  const float ${name}_max = float(${max});

  float ${name}_(float a) {
    return clamp(a, ${name}_min, ${name}_max);
  }
  vec4 ${name}_(vec4 v) {
    return clamp(v, ${name}_min, ${name}_max);
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/** GLSL value function returning its argument unchanged. */
export function glslIdentity(): GlslValueFunction {
  // Was misspelled 'indentity'; the name is only used for the generated GLSL
  // function (`identity_`) and the program metadata name, both produced from this value.
  const name = 'identity';
  const body = `
  float ${name}_(float a) {
    return a;
  }
  vec4 ${name}_(vec4 v) {
    return v;
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/**
 * GLSL value function for LeakyRelu: `a * alpha` when `a < 0`, otherwise `a`.
 *
 * @param alpha slope for negative inputs; inlined into the shader as a `const float`.
 */
export function glslLeakyRelu(alpha: number): GlslValueFunction {
  const name = 'leakyRelu';
  const body = `
  const float alpha = float(${alpha});

  float ${name}_(float a) {
    return a < 0.0 ? a * alpha : a;
  }
  vec4 ${name}_(vec4 v) {
    // The scalar overload has a data-dependent branch, so apply it per component.
    return vec4(${name}_(v.x), ${name}_(v.y), ${name}_(v.z), ${name}_(v.w));
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/** GLSL value function forwarding to the built-in `log`. */
export function glslLog(): GlslValueFunction {
  return glslBuiltinUnary('log');
}

/** GLSL value function negating its argument. */
export function glslNeg(): GlslValueFunction {
  const name = 'neg';
  const body = `
  float ${name}_(float a) {
    return -a;
  }
  vec4 ${name}_(vec4 v) {
    return -v;
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/**
 * GLSL value function for logical NOT.
 *
 * Provides `float`/`vec4` overloads (results are 1.0 for true and 0.0 for false) and
 * `bool`/`bvec4` overloads.
 */
export function glslNot(): GlslValueFunction {
  const name = 'not';
  const body = `
  float ${name}_(float a) {
    return float( ! bool(a) );
  }
  bool ${name}_(bool a) {
    return !a;
  }
  vec4 ${name}_(vec4 v) {
    return vec4(!bool(v.x), !bool(v.y), !bool(v.z), !bool(v.w));
  }
  bvec4 ${name}_(bvec4 v) {
    return bvec4(!v.x, !v.y, !v.z, !v.w);
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/** GLSL value function forwarding to the built-in `sin`. */
export function glslSin(): GlslValueFunction {
  return glslBuiltinUnary('sin');
}

/** GLSL value function for Relu: `max(a, 0)`, applied component-wise for `vec4`. */
export function glslRelu(): GlslValueFunction {
  const name = 'relu';
  const body = `
  float ${name}_(float a) {
    return max( a, 0.0 );
  }
  vec4 ${name}_(vec4 v) {
    return max( v, 0.0 );
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/** GLSL value function for the logistic sigmoid `1 / (1 + exp(-a))`. */
export function glslSigmoid(): GlslValueFunction {
  const name = 'sigmoid';
  const body = `
  float ${name}_(float a) {
    return 1.0 / (1.0 + exp(-a));
  }
  vec4 ${name}_(vec4 v) {
    return 1.0 / (1.0 + exp(-v));
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/** GLSL value function forwarding to the built-in `sqrt`. */
export function glslSqrt(): GlslValueFunction {
  return glslBuiltinUnary('sqrt');
}

/** GLSL value function forwarding to the built-in `tan`. */
export function glslTan(): GlslValueFunction {
  return glslBuiltinUnary('tan');
}

/**
 * GLSL value function for tanh, computed as `(e^(2a) - 1) / (e^(2a) + 1)`.
 *
 * The input is clamped to [-10, 10] first so that `exp(2a)` stays finite. Without the clamp,
 * large inputs would give inf/inf = NaN. At |a| = 10, tanh is already 1 to float32 precision.
 */
export function glslTanh(): GlslValueFunction {
  const name = 'tanh';
  const body = `
  float ${name}_(float a) {
    a = clamp(a, -10., 10.);
    a = exp(2.*a);
    return (a - 1.) / (a + 1.);
  }
  vec4 ${name}_(vec4 v) {
    v = clamp(v, -10., 10.);
    v = exp(2.*v);
    return (v - 1.) / (v + 1.);
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/**
 * Builds a value function whose `float` and `vec4` overloads (`${name}_`) forward to the GLSL
 * built-in called `name`.
 *
 * @param name name of a GLSL built-in unary function such as `abs` or `sqrt`.
 */
function glslBuiltinUnary(name: string): GlslValueFunction {
  const body = `
  float ${name}_(float a) {
    return ${name}(a);
  }
  vec4 ${name}_(vec4 v) {
    return ${name}(v);
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/**
 * Builds the WebGL program for a unary element-wise op.
 *
 * The shader samples input texture `A` at `TexCoords`, applies `${glslFunc.name}_` to that `vec4`
 * and writes the result. The output keeps the input's dims and type. Its texture type is packed
 * or unpacked according to `handler.session.pack`.
 *
 * @param handler inference handler supplying the session (packing mode, GL context version).
 * @param metadata program metadata (name, input names/types, cache hint) spread into the result.
 * @param input the single input tensor; its dims and type define the output.
 * @param glslFunc value function whose GLSL body is prepended to the shader.
 */
const createElementwiseProgramInfo = (
  handler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  input: Tensor,
  glslFunc: GlslValueFunction,
): ProgramInfo => {
  const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked;
  const glsl = getGlsl(handler.session.backend.glContext.version);
  return {
    ...metadata,
    output: { dims: input.dims, type: input.type, textureType },
    shaderSource: `
     ${glslFunc.body}
     void main() {
       vec4 v = ${glsl.texture2D}(A, TexCoords);
       v = ${glslFunc.name}_(v);
       ${glsl.output} = v;
     }
     `,
    // shaderSource above supplies its own main().
    hasMain: true,
  };
};

/**
 * Wraps {@link createElementwiseProgramInfo} in a loader that builds the program only when `get()` is called.
 *
 * @param handler inference handler supplying the session.
 * @param input the single input tensor, bound as shader input `A`.
 * @param glslFunc value function applied to every element.
 * @param cacheKey optional value stored as the metadata `cacheHint`. Ops whose GLSL body embeds
 *   attribute values pass their attributes' `cacheKey` here. The metadata `name` (e.g. `elu`) alone
 *   does not distinguish different constants.
 */
const createElementwiseProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  input: Tensor,
  glslFunc: GlslValueFunction,
  cacheKey?: string,
): ProgramInfoLoader => {
  const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked;
  const metadata = { name: glslFunc.name, inputTypes: [textureType], inputNames: ['A'], cacheHint: cacheKey };
  return { ...metadata, get: () => createElementwiseProgramInfo(handler, metadata, input, glslFunc) };
};

/** Element-wise Abs of `inputs[0]`. */
export const abs = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAbs()), inputs),
];

/** Element-wise Acos of `inputs[0]`. */
export const acos = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAcos()), inputs),
];

/** Element-wise Asin of `inputs[0]`. */
export const asin = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAsin()), inputs),
];

/** Element-wise Atan of `inputs[0]`. */
export const atan = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAtan()), inputs),
];

/**
 * Clamp bounds for Clip. The inherited `cacheKey` is passed as the program cache hint, because the
 * bounds are baked into the shader.
 */
export interface ClipAttributes extends AttributeWithCacheKey {
  readonly min: number;
  readonly max: number;
}

/**
 * Element-wise Clip of `inputs[0]` into `[attributes.min, attributes.max]`.
 *
 * Bounds come from `attributes`: node attributes via `parseClipAttributes`, or constant inputs
 * via `clipV11`.
 */
export const clip = (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ClipAttributes): Tensor[] => [
  handler.run(
    createElementwiseProgramInfoLoader(
      handler,
      inputs[0],
      glslClip(attributes.min, attributes.max),
      attributes.cacheKey,
    ),
    inputs,
  ),
];

/** Reads Clip's `min` / `max` float node attributes, falling back to `MIN_CLIP` / `MAX_CLIP`. */
export const parseClipAttributes = (node: Graph.Node): ClipAttributes =>
  createAttributeWithCacheKey({
    min: node.attributes.getFloat('min', MIN_CLIP),
    max: node.attributes.getFloat('max', MAX_CLIP),
  });

/**
 * Clip variant whose bounds arrive as inputs 1 and 2 rather than as attributes.
 *
 * The bounds are read once into attributes, and only the data tensor `inputs[0]` is bound to the program.
 */
export const clipV11 = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const attributes = generateClipAttributesFromInputs(handler, inputs);
  return clip(handler, [inputs[0]], attributes);
};

/**
 * Converts Clip's optional `min` (input 1) and `max` (input 2) tensors into constant clamp bounds.
 *
 * Each bound is optional on its own. A node may supply only `min`, so a missing bound falls back
 * to `MIN_CLIP` / `MAX_CLIP`. The old `inputs.length >= 3` test silently dropped a lone `min`.
 *
 * @throws Error if a supplied bound is not an initializer. The bounds are baked into the shader,
 *   so they must be known when the program is built.
 */
const generateClipAttributesFromInputs = (handler: WebGLInferenceHandler, inputs: Tensor[]): ClipAttributes => {
  const minTensor: Tensor | undefined = inputs.length >= 2 ? inputs[1] : undefined;
  const maxTensor: Tensor | undefined = inputs.length >= 3 ? inputs[2] : undefined;

  if (
    (minTensor && !handler.session.isInitializer(minTensor.dataId)) ||
    (maxTensor && !handler.session.isInitializer(maxTensor.dataId))
  ) {
    throw new Error('dynamic clip attributes are not allowed');
  }

  const min = minTensor ? minTensor.numberData[0] : MIN_CLIP;
  const max = maxTensor ? maxTensor.numberData[0] : MAX_CLIP;
  return createAttributeWithCacheKey({ min, max });
};

/** Element-wise Ceil of `inputs[0]`. */
export const ceil = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslCeil()), inputs),
];

/** Element-wise Cos of `inputs[0]`. */
export const cos = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslCos()), inputs),
];

/** Elu's `alpha`. The inherited `cacheKey` is the program cache hint, because `alpha` is baked into the shader. */
export interface EluAttributes extends AttributeWithCacheKey {
  readonly alpha: number;
}

/** Element-wise Elu of `inputs[0]` with `attributes.alpha`. */
export const elu = (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: EluAttributes): Tensor[] => [
  handler.run(
    createElementwiseProgramInfoLoader(handler, inputs[0], glslElu(attributes.alpha), attributes.cacheKey),
    inputs,
  ),
];

/** Reads Elu's `alpha` float attribute, falling back to 1.0. */
export const parseEluAttributes = (node: Graph.Node): EluAttributes =>
  createAttributeWithCacheKey({ alpha: node.attributes.getFloat('alpha', 1.0) });

/** Element-wise Exp of `inputs[0]`. */
export const exp = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslExp()), inputs),
];

/** Element-wise Floor of `inputs[0]`. */
export const floor = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslFloor()), inputs),
];

/** Element-wise Identity: copies `inputs[0]` through a pass-through shader. */
export const identity = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslIdentity()), inputs),
];

/** LeakyRelu's `alpha`. The inherited `cacheKey` is the program cache hint, because `alpha` is baked into the shader. */
export interface LeakyReluAttributes extends AttributeWithCacheKey {
  readonly alpha: number;
}

/** Element-wise LeakyRelu of `inputs[0]` with slope `attributes.alpha` for negative values. */
export const leakyRelu = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: LeakyReluAttributes,
): Tensor[] => [
  handler.run(
    createElementwiseProgramInfoLoader(handler, inputs[0], glslLeakyRelu(attributes.alpha), attributes.cacheKey),
    inputs,
  ),
];

/** Reads LeakyRelu's `alpha` float attribute, falling back to 0.01. */
export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes =>
  createAttributeWithCacheKey({ alpha: node.attributes.getFloat('alpha', 0.01) });

/** Element-wise Log of `inputs[0]`. */
export const log = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslLog()), inputs),
];

/** Element-wise Neg of `inputs[0]`. */
export const neg = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNeg()), inputs),
];

/** Element-wise logical Not of `inputs[0]`. */
export const not = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNot()), inputs),
];

/** Element-wise Relu of `inputs[0]`. */
export const relu = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslRelu()), inputs),
];

/** Element-wise Sigmoid of `inputs[0]`. */
export const sigmoid = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSigmoid()), inputs),
];

/** Element-wise Sin of `inputs[0]`. */
export const sin = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSin()), inputs),
];

/** Element-wise Sqrt of `inputs[0]`. */
export const sqrt = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSqrt()), inputs),
];

/** Element-wise Tan of `inputs[0]`. */
export const tan = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTan()), inputs),
];

/** Element-wise Tanh of `inputs[0]`. */
export const tanh = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [
  handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTanh()), inputs),
];