1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
14
tensorTypeToWsglValueType,
15
UniformDataElementType,
19
/**
 * Name of a function that is applied directly to the input value.
 *
 * When an elementwise function call is given as a plain string, the shader
 * builder in this file emits the expression `${name}(a)`.
 */
type BuiltinFunctionName = string;
20
/**
 * Callback that builds the shader expression for one elementwise operation.
 *
 * It receives the source text of the input expression (the shader builder in
 * this file passes `'a'`) and returns the source text of the result expression.
 */
type ElementwiseCustomExpression = (expression: string) => string;
21
/**
 * How an elementwise operation is described: either the name of a function to
 * call on the input, or a callback that produces a custom expression.
 */
type ElementwiseFunctionCall = BuiltinFunctionName | ElementwiseCustomExpression;
23
const createElementwiseProgramShader = (
24
shaderHelper: ShaderHelper,
26
inputDataType: number,
27
outputDataType: number,
28
funcCall: ElementwiseFunctionCall,
29
additionalImplementation?: string,
30
additionalUniformsType?: UniformsArrayType,
32
const vecSize = Math.ceil(datasize / 4);
35
if (typeof funcCall === 'string') {
36
expression = `${funcCall}(a)`;
38
expression = funcCall('a');
41
const input = inputVariable('inputData', inputDataType, [vecSize], 4);
42
const output = outputVariable('outputData', outputDataType, [vecSize], 4);
43
const uniforms: UniformsArrayType = [{ name: 'vec_size', type: 'u32' }];
44
if (additionalUniformsType) {
45
uniforms.push(...additionalUniformsType);
49
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
51
${additionalImplementation ?? ''}
53
${shaderHelper.mainStart()}
54
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
56
let a = ${input.getByOffset('global_idx')};
57
${output.setByOffset('global_idx', expression)}
61
const createElementwiseProgramInfo = (
64
funcCall: ElementwiseFunctionCall,
65
additionalImplementation?: string,
67
outputDataType: number = input.dataType,
68
additionalUniforms?: ProgramUniform[],
69
additionalUniformsType?: UniformsArrayType,
71
const programUniforms: ProgramUniform[] = [
72
{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) },
74
if (additionalUniforms) {
75
programUniforms.push(...additionalUniforms);
80
shaderCache: { hint: cacheKey, inputDependencies: ['type'] },
81
getShaderSource: (shaderHelper) =>
82
createElementwiseProgramShader(
84
ShapeUtil.size(input.dims),
88
additionalImplementation,
89
additionalUniformsType,
91
getRunData: (inputTensors) => ({
92
outputs: [{ dims: input.dims, dataType: outputDataType }],
94
x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */),
101
export const abs = (context: ComputeContext): void => {
102
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Abs', 'abs'));
105
export const acos = (context: ComputeContext): void => {
106
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Acos', 'acos'));
109
export const acosh = (context: ComputeContext): void => {
110
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Acosh', 'acosh'));
113
export const asin = (context: ComputeContext): void => {
114
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Asin', 'asin'));
117
export const asinh = (context: ComputeContext): void => {
118
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Asinh', 'asinh'));
121
export const atan = (context: ComputeContext): void => {
122
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Atan', 'atan'));
124
export const atanh = (context: ComputeContext): void => {
125
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Atanh', 'atanh'));
128
export interface CastAttributes extends AttributeWithCacheKey {
130
readonly saturate?: boolean;
133
/**
 * Converts the raw attribute record of a `Cast` node into `CastAttributes`.
 *
 * @param attributes Raw attributes; treated as an object carrying a numeric `to`
 *   field (the target data type). No runtime validation is performed here.
 * @returns The result of `createAttributeWithCacheKey` for those attributes.
 */
export const parseCastAttributes = (attributes: Record<string, unknown>): CastAttributes =>
  createAttributeWithCacheKey(attributes as { to: number });
136
export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
137
let func: ElementwiseFunctionCall;
138
switch (attributes.to) {
139
case DataType.float16:
145
case DataType.uint32:
155
throw new RangeError(`not supported type (specified in attribute 'to' from 'Cast' operator): ${attributes.to}`);
158
createElementwiseProgramInfo(context.inputs[0], 'Cast', func, undefined, attributes.cacheKey, attributes.to),
162
export interface ClipAttributes extends AttributeWithCacheKey {
163
readonly min: number;
164
readonly max: number;
167
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
170
const hasMin = inputs.length >= 2 && inputs[1].data !== 0;
171
const hasMax = inputs.length >= 3 && inputs[2].data !== 0;
173
switch (inputs[0].dataType) {
175
min = hasMin ? inputs[1].getFloat32Array()[0] : -3.4028234663852886e38;
176
max = hasMax ? inputs[2].getFloat32Array()[0] : 3.4028234663852886e38;
178
case DataType.float16:
179
min = hasMin ? inputs[1].getUint16Array()[0] : 64511; // uint16(64511) <-> float16(-65504.0)
180
max = hasMax ? inputs[2].getUint16Array()[0] : 31743; // uint16(31743) <-> float16(65504.0)
183
throw new Error('Unsupport data type');
186
return createAttributeWithCacheKey({ min, max });
189
export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
190
const attributes = clipAttributes ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
191
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
193
createElementwiseProgramInfo(
196
(a) => `clamp(${a}, vec4<${dataType}>(uniforms.min), vec4<${dataType}>(uniforms.max))`,
201
{ type: context.inputs[0].dataType, data: attributes.min },
202
{ type: context.inputs[0].dataType, data: attributes.max },
205
{ name: 'min', type: dataType as UniformDataElementType },
206
{ name: 'max', type: dataType as UniformDataElementType },
213
export const ceil = (context: ComputeContext): void => {
214
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil'));
217
export const cos = (context: ComputeContext): void => {
218
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Cos', 'cos'));
221
export const cosh = (context: ComputeContext): void => {
222
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Cosh', 'cosh'));
225
export interface AlphaAttributes extends AttributeWithCacheKey {
226
readonly alpha: number;
229
/**
 * Converts a raw attribute record into `AlphaAttributes`.
 *
 * @param attributes Raw attributes; treated as an object carrying a numeric
 *   `alpha` field. No runtime validation is performed here.
 * @returns The result of `createAttributeWithCacheKey` for those attributes.
 */
export const parseAlphaAttributes = (attributes: Record<string, unknown>): AlphaAttributes =>
  createAttributeWithCacheKey(attributes as { alpha: number });
232
export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => {
233
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
235
createElementwiseProgramInfo(
238
(a) => `elu_vf32(${a})`,
240
const elu_alpha_ = ${dataType}(${attributes.alpha});
242
fn elu_f32(a: ${dataType}) -> ${dataType} {
243
return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0);
246
fn elu_vf32(v: vec4<${dataType}>) -> vec4<${dataType}> {
247
return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
254
export const erfImpl = (varType = 'f32') => `
255
const r0: ${varType} = 0.3275911;
256
const r1: ${varType} = 0.254829592;
257
const r2: ${varType} = -0.284496736;
258
const r3: ${varType} = 1.421413741;
259
const r4: ${varType} = -1.453152027;
260
const r5: ${varType} = 1.061405429;
262
fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> {
264
let x = 1.0 / (1.0 + r0 * absv);
265
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
268
export const erf = (context: ComputeContext): void => {
269
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
270
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', (a) => `erf_vf32(${a})`, erfImpl(dataType)));
273
export const exp = (context: ComputeContext): void => {
274
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Exp', 'exp'));
277
export const floor = (context: ComputeContext): void => {
278
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Floor', 'floor'));
281
export const gelu = (context: ComputeContext): void => {
282
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
284
createElementwiseProgramInfo(
287
(a) => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
293
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
294
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
296
createElementwiseProgramInfo(
299
(a) => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`,
300
`const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`,
306
export const not = (context: ComputeContext): void => {
307
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Not', (a) => `!${a}`));
310
export const neg = (context: ComputeContext): void => {
311
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Neg', (a) => `-${a}`));
314
export const reciprocal = (context: ComputeContext): void => {
315
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Reciprocal', (a) => `1.0/${a}`));
318
export const relu = (context: ComputeContext): void => {
319
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
321
createElementwiseProgramInfo(
324
(a) => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`,
329
export const sigmoid = (context: ComputeContext): void => {
330
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', (a) => `(1.0 / (1.0 + exp(-${a})))`));
333
export interface HardSigmoidAttributes extends AttributeWithCacheKey {
334
readonly alpha: number;
335
readonly beta: number;
338
export const parseHardSigmoidAttributes = (attributes: Record<string, unknown>): HardSigmoidAttributes =>
339
createAttributeWithCacheKey(
346
export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => {
347
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
349
createElementwiseProgramInfo(
353
`max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${attributes.beta})))`,
360
export const sin = (context: ComputeContext): void => {
361
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin'));
364
export const sinh = (context: ComputeContext): void => {
365
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sinh', 'sinh'));
368
export const sqrt = (context: ComputeContext): void => {
369
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sqrt', 'sqrt'));
372
export const tan = (context: ComputeContext): void => {
373
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan'));
376
/**
 * Builds a shader expression that evaluates tanh of `a` using only `sign`,
 * `abs` and `exp`:
 *
 *   sign(a) * (1 - exp(-2 * |a|)) / (1 + exp(-2 * |a|))
 *
 * The `tanh` operator in this file and the `tanh_v` helper emitted by
 * `fastGeluImpl` both use this expression.
 *
 * @param a Source text of the input expression; it is substituted verbatim
 *   (three times) into the result.
 * @returns Source text of the tanh expression.
 */
export const tanhExpression = (a: string) => {
  // exp(-2 * |a|) appears in both the numerator and the denominator.
  const decay = `exp(-2 * abs(${a}))`;
  return `sign(${a}) * (1 - ${decay}) / (1 + ${decay})`;
};
378
export const tanh = (context: ComputeContext): void => {
379
// TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
380
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression));
383
export const fastGeluImpl = (varType = 'f32') => `
384
const fast_gelu_a: ${varType} = 0.5;
385
const fast_gelu_b: ${varType} = 0.7978845608028654;
386
const fast_gelu_c: ${varType} = 0.035677408136300125;
388
fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> {
389
return ${tanhExpression('v')};
393
/**
 * Builds the shader expression for the fast GELU approximation of `x`:
 *
 *   (a + a * tanh_v(x * (c * x * x + b))) * x
 *
 * The identifiers `fast_gelu_a`, `fast_gelu_b`, `fast_gelu_c` and `tanh_v`
 * referenced by the result are declared by the code from `fastGeluImpl`, which
 * therefore has to be part of the same shader.
 *
 * @param x Source text of the input expression; it is substituted verbatim
 *   (four times) into the result.
 * @returns Source text of the fast GELU expression.
 */
export const fastGeluExpression = (x: string) =>
  `(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`;
396
export const fastGelu = (context: ComputeContext): void => {
397
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
399
createElementwiseProgramInfo(
403
fastGeluImpl(dataType),
405
context.inputs[0].dataType,
410
export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
411
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
413
createElementwiseProgramInfo(
416
(a) => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
417
`const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`,
424
export const log = (context: ComputeContext): void => {
425
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Log', 'log'));
428
export const quickGeluImpl = (varType: string, alpha: number) => `
429
const alpha = vec4<${varType}>(${alpha});
430
const one = ${varType}(1.0);
431
const zero = ${varType}(0.0);
433
fn quick_gelu_impl(x: vec4<${varType}>) -> vec4<${varType}> {
435
var x1 : vec4<${varType}>;
436
for (var i = 0; i < 4; i = i + 1) {
438
x1[i] = one / (one + exp(-v[i]));
440
x1[i] = one - one / (one + exp(v[i]));
447
/**
 * Builds the shader expression that applies quick GELU to `x` by calling
 * `quick_gelu_impl`, the function declared by the code from `quickGeluImpl`.
 *
 * @param x Source text of the input expression, substituted verbatim.
 * @returns Source text of the call expression.
 */
export const quickGeluExpression = (x: string) => {
  const callee = 'quick_gelu_impl';
  return `${callee}(${x})`;
};
449
export const quickgelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
450
const dType = tensorTypeToWsglValueType(context.inputs[0].dataType);
452
createElementwiseProgramInfo(
456
quickGeluImpl(dType, attributes.alpha),
458
context.inputs[0].dataType,