onnxruntime

Форк
0
155 строк · 5.5 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
8

9
import {
10
  castToF32,
11
  fillVector,
12
  getMaxComponents,
13
  inputVariable,
14
  outputVariable,
15
  ShaderHelper,
16
  sumVector,
17
  tensorTypeToWsglStorageType,
18
  UniformsArrayType,
19
} from './common';
20

21
/**
 * Attributes that configure the layer-normalization program built in this file.
 */
interface LayerNormAttributes {
  /**
   * When true, the mean is not subtracted (neither in the variance term nor
   * from the input) and no bias input is read.
   */
  simplified: boolean;
  /**
   * Axis at which normalization starts. It is passed through
   * `ShapeUtil.normalizeAxis` together with the rank of the first input before use.
   */
  axis: number;
  /** Value added inside the inverse square root when computing the inverse standard deviation. */
  epsilon: number;
}
26

27
/**
 * Checks that the operator received the minimum number of inputs.
 *
 * @param inputs - Tensors handed to the operator.
 * @throws Error when `inputs` is missing or has fewer than two entries.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length < 2) {
    throw new Error('layerNorm requires at least 2 inputs.');
  }
};
32

33
/**
 * Builds the `ProgramInfo` (shader source, outputs, dispatch size and uniforms)
 * for the 'LayerNormalization' program.
 *
 * @param inputs - `inputs[0]` is the tensor to normalize, `inputs[1]` the scale,
 *   and `inputs[2]` the bias. The bias is only read when `attributes.simplified`
 *   is false and the entry is present.
 * @param attributes - Axis, epsilon and the `simplified` switch.
 * @param outputCount - Number of requested outputs. More than one adds a float
 *   mean output; more than two also adds a float inverse-standard-deviation output.
 * @returns The program description consumed by the compute context.
 * @throws Error when the scale size, or the bias size if a bias is used, differs
 *   from the size of `X.shape()[axis:]`.
 */
const createLayerNormProgramInfo = (
  inputs: readonly TensorView[],
  attributes: LayerNormAttributes,
  outputCount: number,
): ProgramInfo => {
  const simplified = attributes.simplified;

  const xShape = inputs[0].dims;
  const scale = inputs[1];
  // In simplified mode the bias is never used, even if a third input exists.
  const bias = !simplified && inputs[2];

  const outputShape = xShape;
  const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length);
  // normSize is the size of X.shape()[axis:] (see the error message below);
  // normCount is the value the shader uses as its invocation bound.
  const normCount = ShapeUtil.sizeToDimension(xShape, axis);
  const normSize = ShapeUtil.sizeFromDimension(xShape, axis);

  const scaleSize = ShapeUtil.size(scale.dims);
  const biasSize = bias ? ShapeUtil.size(bias.dims) : 0;
  if (scaleSize !== normSize || (bias && biasSize !== normSize)) {
    throw new Error(`Size of X.shape()[axis:] == ${normSize}.
       Size of scale and bias (if provided) must match this.
       Got scale size of ${scaleSize} and bias size of ${biasSize}`);
  }

  // Shape of the optional mean / inverse-std-dev outputs: dimensions before
  // `axis` are kept, every dimension from `axis` onwards becomes 1.
  const meanInvStdDevDim: number[] = [];
  for (let i = 0; i < xShape.length; ++i) {
    if (i < axis) {
      meanInvStdDevDim.push(xShape[i]);
    } else {
      meanInvStdDevDim.push(1);
    }
  }
  // NOTE(review): presumably the widest vector width that divides normSize; confirm in './common'.
  const components = getMaxComponents(normSize);
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
  // Listed in the same order as the uniforms declared inside getShaderSource.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: normCount },
    // Sent as float because the shader only uses it as an f32 divisor.
    { type: DataType.float, data: normSize },
    // Number of `components`-wide vectors each invocation iterates over.
    { type: DataType.uint32, data: Math.floor(normSize / components) },
    { type: DataType.float, data: attributes.epsilon },
  ];
  if (bias) {
    inputDependencies.push('type');
  }
  const hasMeanDataOutput = outputCount > 1;
  const hasInvStdOutput = outputCount > 2;

  // Generates the WGSL source. Each invocation handles the run of
  // `norm_size_vectorized` vectors starting at `global_idx * norm_size_vectorized`:
  // a first loop accumulates the sum and the sum of squares in f32, a second
  // loop writes the normalized, scaled (and optionally biased) values.
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
    const variables = [
      inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
      inputVariable('scale', scale.dataType, scale.dims, components),
    ];
    if (bias) {
      variables.push(inputVariable('bias', bias.dataType, bias.dims, components));
    }
    variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
    if (hasMeanDataOutput) {
      variables.push(outputVariable('mean_data_output', DataType.float, meanInvStdDevDim));
    }
    if (hasInvStdOutput) {
      variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim));
    }

    const uniforms: UniformsArrayType = [
      { name: 'norm_count', type: 'u32' },
      { name: 'norm_size', type: 'f32' },
      { name: 'norm_size_vectorized', type: 'u32' },
      { name: 'epsilon', type: 'f32' },
    ];
    // The scale and bias reads are cast with the storage type of `x` (dataType).
    // When an optional output is absent, its line below collapses to a lone `;`.
    return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')}
    let offset = global_idx * uniforms.norm_size_vectorized;
    var mean_vector = ${fillVector('f32', components)};
    var mean_square_vector = ${fillVector('f32', components)};

    for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {
      let value = ${castToF32(dataType, components, 'x[h + offset]')};
      mean_vector += value;
      mean_square_vector += value * value;
    }
    let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size;
    let inv_std_dev = inverseSqrt(${sumVector('mean_square_vector', components)} / uniforms.norm_size ${
      simplified ? '' : '- mean * mean'
    } + uniforms.epsilon);

    for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {
      let f32input = ${castToF32(dataType, components, 'x[j + offset]')};
      let f32scale = ${castToF32(dataType, components, 'scale[j]')};
      output[j + offset] = ${variables[0].type.value}((f32input ${simplified ? '' : '- mean'}) * inv_std_dev * f32scale
        ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''}
      );
    }

    ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''};
    ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''};
  }`;
  };
  // The first output mirrors the shape and data type of `x`; the optional ones are float.
  const outputs = [{ dims: outputShape, dataType: inputs[0].dataType }];
  if (hasMeanDataOutput) {
    outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float });
  }
  if (hasInvStdOutput) {
    outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float });
  }

  return {
    name: 'LayerNormalization',
    shaderCache: { hint: `${components};${outputCount};${simplified}`, inputDependencies },
    getRunData: () => ({
      outputs,
      dispatchGroup: { x: Math.ceil(normCount / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
151

152
/**
 * Entry point of the layer-normalization operator: validates the inputs of the
 * compute context, then runs the program built from them.
 *
 * @param context - Supplies the input tensors, the requested output count and
 *   the `compute` call that runs the program.
 * @param attributes - Layer-normalization attributes forwarded to the program builder.
 * @throws Error when fewer than two inputs are supplied, or when the program
 *   builder rejects the scale/bias sizes.
 */
export const layerNorm = (context: ComputeContext, attributes: LayerNormAttributes): void => {
  validateInputs(context.inputs);
  context.compute(createLayerNormProgramInfo(context.inputs, attributes, context.outputCount));
};
156

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.