1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { MAX_CLIP, MIN_CLIP } from '../../util';
6
import { ProgramUniform } from '../types';
8
import { UniformsArrayType } from './common';
10
/**
 * Settings for an activation that is fused into another kernel.
 *
 * `activation` names the activation (it may be the empty string). Which optional
 * fields are meaningful depends on that name:
 * - `Clip`: `clipMin` and `clipMax`
 * - `HardSigmoid`: `alpha` and `beta`
 * - `LeakyRelu`: `alpha`
 */
export interface InternalActivationAttributes {
  /** Activation name, e.g. 'Clip', 'HardSigmoid', 'LeakyRelu'; may be ''. */
  readonly activation: string;
  /** Lower clamp bound; used by `Clip`. */
  readonly clipMin?: number;
  /** Upper clamp bound; used by `Clip`. */
  readonly clipMax?: number;
  /** Multiplier applied to the input value; used by `HardSigmoid` and `LeakyRelu`. */
  readonly alpha?: number;
  /** Second `HardSigmoid` parameter. */
  readonly beta?: number;
}
18
/**
 * Returns a shader-source snippet (presumably WGSL, given `select`/`let`/`uniforms.*`)
 * that applies the configured activation by assigning to a variable named `value`.
 *
 * The visible branches emit:
 * - `max(value, 0)`
 * - `1 / (1 + exp(-value))`
 * - a `clamp` between `uniforms.clip_min` and `uniforms.clip_max`
 * - `max(0, min(1, alpha * value + ...))`, reading `uniforms.alpha`
 * - `select(alpha * value, value, value >= 0)`
 * - a tanh expression built from `e2x = exp(-2 * abs(value))`
 *
 * These uniform names match the ones declared by `appendActivationUniforms`.
 *
 * @throws Error with message `Unsupported activation <name>`, presumably from the
 *   switch's default branch (that label is not visible here).
 *
 * NOTE(review): `valueType` and `baseType` are used below, but their declarations are not
 * present in this copy of the file, and neither are the switch's `case` labels. They are
 * presumably further parameters (the vector/element type and the scalar type used in the
 * emitted casts). Confirm the full signature against the upstream source before editing.
 */
export const getActivationSnippet = (
19
attributes: InternalActivationAttributes,
23
switch (attributes.activation) {
25
return `value = max(value, ${valueType}(0.0));`;
27
return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
29
return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${
31
}(uniforms.clip_max)));`;
33
return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${
37
return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`;
39
return `let e2x = exp(-2.0 * abs(value));
40
value = sign(value) * (1.0 - e2x) / (1.0 + e2x);
44
// TODO: adding other activations that can be fused.
46
throw new Error(`Unsupported activation ${attributes.activation}`);
50
/**
 * Appends the runtime uniform values needed by the fused activation.
 *
 * Values are pushed in the same order in which `appendActivationUniforms` declares the
 * matching uniforms: `clip_max` before `clip_min`, and `alpha` before `beta`. Activations
 * other than `Clip`, `HardSigmoid` and `LeakyRelu` append nothing.
 *
 * @param attributes - Parsed activation attributes. The fields read for the selected
 *   activation are expected to be set; they are non-null asserted here.
 * @param programUniform - Uniform value list; mutated in place.
 */
export const appendActivationUniformsData = (
  attributes: InternalActivationAttributes,
  programUniform: ProgramUniform[],
): void => {
  if (attributes.activation === 'Clip') {
    programUniform.push(
      { type: DataType.float, data: attributes.clipMax! },
      { type: DataType.float, data: attributes.clipMin! },
    );
  } else if (attributes.activation === 'HardSigmoid') {
    programUniform.push(
      { type: DataType.float, data: attributes.alpha! },
      { type: DataType.float, data: attributes.beta! },
    );
  } else if (attributes.activation === 'LeakyRelu') {
    programUniform.push({ type: DataType.float, data: attributes.alpha! });
  }
};
69
/**
 * Declares the uniforms that the fused activation snippet reads
 * (`uniforms.clip_min`, `uniforms.clip_max`, `uniforms.alpha`, ...).
 *
 * Declarations are appended in the same order in which `appendActivationUniformsData`
 * pushes the corresponding values: `clip_max` before `clip_min`, and `alpha` before `beta`.
 * Activations other than `Clip`, `HardSigmoid` and `LeakyRelu` declare nothing.
 *
 * @param attributes - Parsed activation attributes; only `activation` is read.
 * @param uniforms - Uniform declaration list; mutated in place.
 */
export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => {
  if (attributes.activation === 'Clip') {
    uniforms.push({ name: 'clip_max', type: 'f32' }, { name: 'clip_min', type: 'f32' });
  } else if (attributes.activation === 'HardSigmoid') {
    uniforms.push({ name: 'alpha', type: 'f32' }, { name: 'beta', type: 'f32' });
  } else if (attributes.activation === 'LeakyRelu') {
    uniforms.push({ name: 'alpha', type: 'f32' });
  }
};
79
/**
 * Converts raw kernel attributes into {@link InternalActivationAttributes}.
 *
 * `activation` defaults to '' when it is absent or falsy. For activations that take
 * parameters, the positional `activation_params` list is read:
 * - `HardSigmoid`: `[alpha, beta]`, defaults `[0.2, 0.5]`
 * - `Clip`: `[clipMin, clipMax]`, defaults `[MIN_CLIP, MAX_CLIP]`
 * - `LeakyRelu`: `[alpha]`, default `[0.01]`
 *
 * Defaults are applied per element, so a missing, empty or partial `activation_params`
 * still produces fully populated attributes.
 *
 * @param attributes - Raw attribute bag; may be undefined.
 * @returns The activation name plus the parameters relevant to it.
 */
export const parseInternalActivationAttributes = (
  attributes: Record<string, unknown> | undefined,
): InternalActivationAttributes => {
  const activation = (attributes?.activation as string) || '';
  // Arrays are truthy, so `params || defaults` would keep an empty or short list and leave
  // fields undefined. Per-element destructuring defaults fill each missing slot instead.
  const params = (attributes?.activation_params as readonly number[] | undefined) || [];
  if (activation === 'HardSigmoid') {
    const [alpha = 0.2, beta = 0.5] = params;
    return { activation, alpha, beta };
  } else if (activation === 'Clip') {
    const [clipMin = MIN_CLIP, clipMax = MAX_CLIP] = params;
    return { activation, clipMax, clipMin };
  } else if (activation === 'LeakyRelu') {
    const [alpha = 0.01] = params;
    return { activation, alpha };
  }
  return { activation };
};