onnxruntime

Форк
0
81 строка · 3.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo } from '../types';
8

9
import {
10
  inputVariable,
11
  outputVariable,
12
  ShaderHelper,
13
  tensorTypeToWsglValueType,
14
  UniformsArrayType,
15
  WORKGROUP_SIZE,
16
} from './common';
17
import * as unary from './unary-op';
18

19
// GELU is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where X may pre-add a bias.
20

21
const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => {
22
  const dataType = inputTensors[0].dataType;
23
  const outputSize = ShapeUtil.size(inputTensors[0].dims);
24
  const biasLength = ShapeUtil.size(inputTensors[1].dims);
25
  // can only use vec4 when bias length is multiple of 4
26
  const useVec4 = biasLength % 4 === 0;
27
  const getShaderSource = (shaderHelper: ShaderHelper): string => {
28
    const x = inputVariable('x', dataType, [1], 4);
29
    const bias = inputVariable('bias', dataType, [1], 4);
30
    const y = outputVariable('y', dataType, [1], 4);
31

32
    const uniforms: UniformsArrayType = [
33
      { name: 'output_vec_size', type: 'u32' },
34
      { name: 'bias_size', type: 'u32' },
35
    ];
36

37
    const singleElementBias = (i: 0 | 1 | 2 | 3) => `
38
      let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size;
39
      let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`;
40
    const biasGetExpression = useVec4
41
      ? `
42
      let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};`
43
      : `${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)}
44
      let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`;
45

46
    return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)}
47

48
    ${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))}
49

50
    ${shaderHelper.mainStart(WORKGROUP_SIZE)}
51
      ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')}
52

53
      let x = ${x.getByOffset('global_idx')};
54
      ${biasGetExpression}
55
      let x_in = x + bias;
56
      ${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))}
57
    }`;
58
  };
59

60
  return {
61
    name: 'FastGeluWithBias',
62
    shaderCache: { hint: `${useVec4}`, inputDependencies: ['type', 'type'] },
63
    getShaderSource,
64
    getRunData: (inputs) => ({
65
      outputs: [{ dims: inputs[0].dims, dataType: inputs[0].dataType }],
66
      programUniforms: [
67
        { type: DataType.uint32, data: Math.ceil(outputSize / 4) },
68
        { type: DataType.uint32, data: biasLength },
69
      ],
70
      dispatchGroup: { x: Math.ceil(outputSize / WORKGROUP_SIZE / 4) },
71
    }),
72
  };
73
};
74

75
export const fastGelu = (context: ComputeContext): void => {
76
  if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) {
77
    unary.fastGelu(context);
78
  } else {
79
    context.compute(createFastGeluProgramInfo(context.inputs));
80
  }
81
};
82

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.