onnxruntime

Форк
0
/
bias-split-gelu.ts 
73 строки · 2.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { TensorView } from '../../tensor-view';
5
import { ShapeUtil } from '../../util';
6
import { ComputeContext, ProgramInfo } from '../types';
7

8
import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType } from './common';
9
import { erfImpl } from './unary-op';
10

11
const validateInputs = (inputs: readonly TensorView[]): void => {
12
  if (inputs[0].dims.length !== 3) {
13
    throw new Error('input should have 3 dimensions');
14
  }
15

16
  if (![2560, 5120, 10240].includes(inputs[0].dims[2])) {
17
    throw new Error('hidden state should be 2560, 5120 or 10240');
18
  }
19

20
  if (inputs[1].dims.length !== 1) {
21
    throw new Error('bias is expected to have 1 dimensions');
22
  }
23

24
  if (inputs[0].dims[2] !== inputs[1].dims[0]) {
25
    throw new Error('last dimension of input and bias are not the same');
26
  }
27
};
28

29
const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => {
30
  const outputShape = inputs[0].dims.slice();
31
  outputShape[2] = outputShape[2] / 2;
32

33
  const input = inputVariable('input', inputs[0].dataType, inputs[0].dims, 4);
34
  const bias = inputVariable('bias', inputs[0].dataType, [inputs[0].dims[2]], 4);
35
  const output = outputVariable('output', inputs[0].dataType, outputShape, 4);
36

37
  const outputSize = ShapeUtil.size(outputShape) / 4;
38
  const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
39

40
  const getShaderSource = (shaderHelper: ShaderHelper) => `
41
  const M_SQRT2 = sqrt(2.0);
42
  const halfChannels = ${inputs[0].dims[2] / 4 / 2}u;
43

44
  ${shaderHelper.declareVariables(input, bias, output)}
45

46
  ${erfImpl(dataType)}
47

48
  ${shaderHelper.mainStart()}
49
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
50
    let biasIdx = global_idx % halfChannels;
51
    let batchIndex = global_idx / halfChannels;
52
    let inputOffset = biasIdx + batchIndex * halfChannels * 2;
53
    let valueLeft = input[inputOffset] + bias[biasIdx];
54
    let valueRight = input[inputOffset + halfChannels] + bias[biasIdx + halfChannels];
55
    let geluRight = valueRight * 0.5 * (erf_vf32(valueRight / M_SQRT2) + 1);
56

57
    ${output.setByOffset('global_idx', 'valueLeft * geluRight')}
58
  }`;
59

60
  return {
61
    name: 'BiasSplitGelu',
62
    getRunData: () => ({
63
      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
64
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
65
    }),
66
    getShaderSource,
67
  };
68
};
69

70
export const biasSplitGelu = (context: ComputeContext): void => {
71
  validateInputs(context.inputs);
72
  context.compute(createBiasSplitGeluProgramInfo(context.inputs));
73
};
74

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.