1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { TensorView } from '../../tensor-view';
5
import { ShapeUtil } from '../../util';
6
import { ComputeContext, ProgramInfo } from '../types';
8
import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType } from './common';
9
import { erfImpl } from './unary-op';
11
/**
 * Validates the inputs of the BiasSplitGelu operator.
 *
 * `inputs[0]` must have exactly 3 dimensions, with a third dimension of 2560, 5120 or 10240.
 * `inputs[1]` (the bias) must have exactly 1 dimension whose length equals that third dimension.
 *
 * @param inputs - Operator inputs; index 0 is the input tensor and index 1 is the bias.
 * @throws Error if fewer than two inputs are supplied or any of the shape constraints is violated.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  // Report a missing bias explicitly instead of failing with a TypeError on `inputs[1].dims`.
  if (inputs.length < 2) {
    throw new Error('BiasSplitGelu requires 2 inputs: input and bias');
  }

  if (inputs[0].dims.length !== 3) {
    throw new Error('input should have 3 dimensions');
  }

  if (![2560, 5120, 10240].includes(inputs[0].dims[2])) {
    throw new Error('hidden state should be 2560, 5120 or 10240');
  }

  if (inputs[1].dims.length !== 1) {
    throw new Error('bias is expected to have 1 dimensions');
  }

  if (inputs[0].dims[2] !== inputs[1].dims[0]) {
    throw new Error('last dimension of input and bias are not the same');
  }
};
29
/**
 * Builds the program description for the BiasSplitGelu kernel.
 *
 * The output shape is a copy of the input shape with its third dimension halved. In the shader
 * text below, each invocation (indexed by `global_idx`) adds the bias to a "left" element and to
 * the element `halfChannels` further along (the "right" element), then stores
 * `valueLeft * geluRight`, where
 * `geluRight = valueRight * 0.5 * (erf_vf32(valueRight / M_SQRT2) + 1)`.
 *
 * NOTE(review): this view of the definition is incomplete. The numeric-only lines look like
 * extraction artifacts, and the end of the shader template, the `return` statement and the
 * closing braces are not visible here. Confirm against the full file before changing any code.
 *
 * @param inputs - Operator inputs; `inputs[0]` supplies the shape and data type used throughout,
 *     and its third dimension also sizes the bias variable.
 * @returns The program info; the fields visible here are `name`, `outputs` and `dispatchGroup`.
 */
const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => {
30
// Copy the dims so that halving the third dimension does not touch the input's own shape array.
const outputShape = inputs[0].dims.slice();
31
outputShape[2] = outputShape[2] / 2;
33
// NOTE(review): the trailing `4` is presumably the number of components per element (vec4);
// confirm against the helpers in './common'.
const input = inputVariable('input', inputs[0].dataType, inputs[0].dims, 4);
34
// The bias variable takes its data type and length from inputs[0], not from inputs[1].
const bias = inputVariable('bias', inputs[0].dataType, [inputs[0].dims[2]], 4);
35
const output = outputVariable('output', inputs[0].dataType, outputShape, 4);
37
// Output element count divided by 4, matching the factor of 4 passed to the variables above.
const outputSize = ShapeUtil.size(outputShape) / 4;
38
// NOTE(review): no use of `dataType` is visible in this view; presumably a line not shown
// consumes it. Verify against the full file.
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
40
const getShaderSource = (shaderHelper: ShaderHelper) => `
41
const M_SQRT2 = sqrt(2.0);
42
const halfChannels = ${inputs[0].dims[2] / 4 / 2}u;
44
${shaderHelper.declareVariables(input, bias, output)}
48
${shaderHelper.mainStart()}
49
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
50
let biasIdx = global_idx % halfChannels;
51
let batchIndex = global_idx / halfChannels;
52
let inputOffset = biasIdx + batchIndex * halfChannels * 2;
53
let valueLeft = input[inputOffset] + bias[biasIdx];
54
let valueRight = input[inputOffset + halfChannels] + bias[biasIdx + halfChannels];
55
let geluRight = valueRight * 0.5 * (erf_vf32(valueRight / M_SQRT2) + 1);
57
${output.setByOffset('global_idx', 'valueLeft * geluRight')}
61
name: 'BiasSplitGelu',
63
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
64
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
70
/**
 * Entry point for the BiasSplitGelu operator: validates `context.inputs`, then submits the
 * program built by `createBiasSplitGeluProgramInfo` through `context.compute`.
 *
 * @param context - Compute context that provides the operator inputs and the `compute` method.
 */
export const biasSplitGelu = (context: ComputeContext): void => {
71
validateInputs(context.inputs);
72
context.compute(createBiasSplitGeluProgramInfo(context.inputs));