1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
16
tensorTypeToWsglStorageType,
20
/**
 * Attributes of the SkipLayerNormalization kernel.
 *
 * NOTE(review): the member list is not visible in this excerpt. This file reads
 * `attributes.simplified` (drops mean subtraction and beta) and `attributes.epsilon`
 * (added to the variance term before inverseSqrt) — confirm against the full declaration.
 */
export interface SkipLayerNormAttributes {
25
/**
 * Validates the SkipLayerNormalization inputs before the program is built.
 *
 * Checks performed (each failure throws an `Error`):
 *  - at least 3 inputs (input, skip, gamma);
 *  - input, skip and gamma share one data type;
 *  - input and skip are each 2D or 3D and agree on their last two dims
 *    (sequence length and hidden size);
 *  - gamma, and the optional inputs[3] (beta) and inputs[4] (bias), are 1D with length hiddenSize.
 *    The data types of inputs[3] / inputs[4] are not checked.
 *
 * NOTE(review): the batch dimension of `skip` is never compared with that of `input`, so a 2D
 * skip is accepted for a 3D input. The shader built in this file reads `skip[offset + i]` with an
 * offset that spans every row of `input` — confirm that broadcast skips are handled elsewhere.
 * NOTE(review): this excerpt contains bare line-number lines and is missing lines such as closing
 * braces; it does not compile as shown — restore from the upstream source.
 *
 * @param inputs the kernel's input tensors.
 * @throws Error when any of the checks above fails.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
26
if (!inputs || inputs.length < 3) {
27
throw new Error('layerNorm requires at least 3 inputs.');
30
const input: TensorView = inputs[0];
31
const skip: TensorView = inputs[1];
32
const gamma: TensorView = inputs[2];
34
if (input.dataType !== skip.dataType || input.dataType !== gamma.dataType) {
35
throw new Error('All inputs must have the same data type');
38
if (input.dims.length !== 3 && input.dims.length !== 2) {
39
throw new Error('Input must be 2D or 3D');
42
if (skip.dims.length !== 3 && skip.dims.length !== 2) {
43
throw new Error('Skip must be 2D or 3D');
46
// Reference dims taken from `input`; for both 2D and 3D tensors the last axis is the hidden
// size and the one before it is the sequence length.
const hiddenSize = input.dims[input.dims.length - 1];
47
const sequenceLength = input.dims[input.dims.length - 2];
48
if (skip.dims[skip.dims.length - 1] !== hiddenSize) {
49
throw new Error('Skip must have the same hidden size as input');
51
if (skip.dims[skip.dims.length - 2] !== sequenceLength) {
52
throw new Error('Skip must have the same sequence length as input');
55
if (gamma.dims.length !== 1) {
56
throw new Error('Gamma must be 1D');
58
if (gamma.dims[gamma.dims.length - 1] !== hiddenSize) {
59
throw new Error('Gamma must have the same hidden size as input');
61
// Optional inputs[3] is validated as beta: 1D of length hiddenSize.
if (inputs.length > 3) {
62
const beta: TensorView = inputs[3];
63
if (beta.dims.length !== 1) {
64
throw new Error('Beta must be 1D');
66
if (beta.dims[beta.dims.length - 1] !== hiddenSize) {
67
throw new Error('Beta must have the same hidden size as input');
70
// Optional inputs[4] is validated as bias: 1D of length hiddenSize.
if (inputs.length > 4) {
71
const bias: TensorView = inputs[4];
72
if (bias.dims.length !== 1) {
73
throw new Error('Bias must be 1D');
75
if (bias.dims[bias.dims.length - 1] !== hiddenSize) {
76
throw new Error('Bias must have the same hidden size as input');
81
/**
 * Builds the WebGPU program description for SkipLayerNormalization.
 *
 * Each row of the last (hidden) dimension is processed as follows:
 *  1. Compute value = input + skip (+ bias).
 *  2. Normalize value with the row mean and inverse standard deviation.
 *  3. Scale by gamma and add beta when present.
 * When `attributes.simplified` is set, there is no mean subtraction (the variance term is the
 * mean of squares) and no beta.
 *
 * Outputs, in order (only as many as `outputCount` requests):
 *  0. output — same shape and data type as inputs[0];
 *  1. mean — float, shape inputShape[:-1] + [1] when training, otherwise [];
 *  2. inverse std dev — float, same shape as mean;
 *  3. input + skip + bias sum — same shape and data type as inputs[0].
 *
 * @param inputs kernel inputs: x, skip, gamma, then optional beta (inputs[3]) and bias (inputs[4]).
 * @param attributes kernel attributes; `simplified` and `epsilon` are read here.
 * @param outputCount number of outputs requested by the caller.
 * @param isTraining when true, the shader also writes the mean / inverse-std-dev outputs.
 * @returns the program description that the caller passes to `context.compute`.
 *
 * NOTE(review): this excerpt contains bare line-number lines and is missing lines (tail of the
 * parameter list, closing braces, the shader template's backticks); it does not compile as shown.
 */
const createSkipLayerNormProgramInfo = (
82
inputs: readonly TensorView[],
83
attributes: SkipLayerNormAttributes,
87
const simplified = attributes.simplified;
89
const inputShape = inputs[0].dims;
90
const inputSize = ShapeUtil.size(inputShape);
91
const outputShape = inputShape;
92
const outputSize = inputSize;
93
// Size of the innermost (normalized) dimension.
const hiddenSize = inputShape.slice(-1)[0];
94
// Per-row statistics shape: every dim except the last, followed by a trailing 1; empty when not training.
const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : [];
95
// Beta is used only by the non-simplified variant; bias is always taken from inputs[4].
// NOTE(review): if the simplified variant's operator schema puts bias at inputs[3] (it has no beta),
// that tensor is used neither as beta nor as bias here — confirm the input ordering.
const hasBetaInput = !simplified && inputs.length > 3;
96
const hasBiasInput = inputs.length > 4;
97
// NOTE(review): the shader stores mean / inv_std_dev at index `global_idx` (presumably the flattened
// invocation index) from every invocation, but these outputs hold one value per row. Writing at the
// row index (iy) from a single invocation looks intended. Only reachable when isTraining is true.
const hasMeanOutput = isTraining && outputCount > 1;
98
const hasInvStdDevOutput = isTraining && outputCount > 2;
99
const hasInputSkipBiasSumOutput = outputCount > 3;
100
// One workgroup of this many invocations cooperates on one row (see the dispatch size below).
const workgroupSize = 64;
102
// Vector width used for loads/stores. The shader computes hidden_size / components, so this is
// assumed to divide hiddenSize evenly — confirm getMaxComponents guarantees it.
const components = getMaxComponents(hiddenSize);
104
// Listed in the same order as uniformsArray in getShaderSource (presumably required for binding).
const programUniforms: ProgramUniform[] = [
105
{ type: DataType.uint32, data: outputSize },
106
{ type: DataType.uint32, data: components },
107
{ type: DataType.uint32, data: hiddenSize },
108
{ type: DataType.float, data: attributes.epsilon },
110
// Shader outline (iy = row, ix = invocation within the row's workgroup):
//  - Invocation ix handles `stride` contiguous vectorized elements; the last invocation also
//    takes the remainder.
//  - For each element it writes input + skip (+ bias) to `output` (and optionally
//    input_skip_bias_sum) and accumulates value and value^2 into the workgroup arrays.
//  - The per-invocation sums are tree-reduced into index 0 and turned into mean and inv_std_dev.
//  - A second pass rewrites `output` as (value - mean) * inv_std_dev * gamma (+ beta).
const getShaderSource = (shaderHelper: ShaderHelper) => {
111
const uniformsArray: UniformsArrayType = [
112
{ name: 'output_size', type: 'u32' },
113
{ name: 'components', type: 'u32' },
114
{ name: 'hidden_size', type: 'u32' },
115
{ name: 'epsilon', type: 'f32' },
118
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
119
// NOTE(review): skip is indexed with the same flattened offset as x. A skip with fewer rows than x
// (e.g. a 2D skip for a 3D input, which validateInputs accepts) would be read out of range.
inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
120
inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
123
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
126
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
128
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
130
variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim));
132
if (hasInvStdDevOutput) {
133
variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim));
135
if (hasInputSkipBiasSumOutput) {
136
variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components));
138
// WGSL element type of the input; the row statistics are accumulated in f32 vectors (vecDataType).
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
139
const vecDataType = tensorTypeToWsglStorageType(DataType.float, components);
142
${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)}
143
var<workgroup> sum_shared : array<${vecDataType}, ${workgroupSize}>;
144
var<workgroup> sum_squared_shared : array<${vecDataType}, ${workgroupSize}>;
146
${shaderHelper.mainStart([workgroupSize, 1, 1])}
148
let iy = global_id.x / ${workgroupSize};
150
let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
151
var stride = hidden_size_vectorized / ${workgroupSize};
152
let offset = ix * stride + iy * hidden_size_vectorized;
153
let offset1d = stride * ix;
154
if (ix == ${workgroupSize - 1}) {
155
stride = hidden_size_vectorized - stride * ix;
157
for (var i: u32 = 0; i < stride; i++) {
158
let skip_value = skip[offset + i];
159
let bias_value = ${hasBiasInput ? 'bias[offset1d + i]' : dataType + '(0.0)'};
160
let input_value = x[offset + i];
161
let value = input_value + skip_value + bias_value;
162
${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''}
163
output[offset + i] = value;
164
let f32_value = ${castToF32(dataType, components, 'value')};
165
sum_shared[ix] += f32_value;
166
sum_squared_shared[ix] += f32_value * f32_value;
170
var reduce_size : u32 = ${workgroupSize};
171
for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {
172
reduce_size = curr_size + (reduce_size & 1);
173
if (ix < curr_size) {
174
sum_shared[ix] += sum_shared[ix + reduce_size];
175
sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];
180
let sum = sum_shared[0];
181
let square_sum = sum_squared_shared[0];
182
let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size);
183
let inv_std_dev = inverseSqrt(${sumVector('square_sum', components)} / f32(uniforms.hidden_size) ${
184
simplified ? '' : '- mean * mean'
185
} + uniforms.epsilon);
186
${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''}
187
${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''}
189
for (var i: u32 = 0; i < stride; i++) {
190
output[offset + i] = (output[offset + i] ${simplified ? '' : `- ${dataType}(mean)`}) *
191
${dataType}(inv_std_dev) * gamma[offset1d + i]
192
${hasBetaInput ? '+ beta[offset1d + i]' : ''};
196
// Output list; the optional entries are appended only when the caller requests them.
const outputs = [{ dims: outputShape, dataType: inputs[0].dataType }];
197
if (outputCount > 1) {
198
outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float });
200
if (outputCount > 2) {
201
outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float });
203
if (outputCount > 3) {
204
outputs.push({ dims: inputShape, dataType: inputs[0].dataType });
207
name: 'SkipLayerNormalization',
209
// NOTE(review): `simplified` changes the generated shader (mean subtraction, beta) but is not part
// of this hint — confirm the program cache key still distinguishes the two variants.
hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`,
210
inputDependencies: inputs.map((_input, _index) => 'type'),
216
// One workgroup per row: outputSize / hiddenSize rows.
x: Math.ceil(outputSize / hiddenSize),
223
export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNormAttributes): void => {
224
// TODO: initialize isTraining from ComputeContext
225
const isTraining = false;
226
validateInputs(context.inputs);
227
// Mean and InvStdDev are only used in training mode and are not required for inference.
228
// They are added here for completeness only.
230
if (context.outputCount > 1) {
231
outputs.push(isTraining ? 1 : -3);
233
if (context.outputCount > 2) {
234
outputs.push(isTraining ? 2 : -3);
236
if (context.outputCount > 3) {
239
context.compute(createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {