// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import {
  castToF32,
  getMaxComponents,
  inputVariable,
  outputVariable,
  ShaderHelper,
  sumVector,
  tensorTypeToWsglStorageType,
  UniformsArrayType,
} from './common';
/**
 * Attributes of the SkipLayerNormalization operator.
 */
export interface SkipLayerNormAttributes {
  // When true, use the simplified (RMSNorm-style) variant: the shader skips the
  // mean subtraction and the beta input is ignored (see createSkipLayerNormProgramInfo).
  simplified: boolean;
  // Small constant added to the variance before inverseSqrt to avoid division by zero.
  epsilon: number;
}
/**
 * Validates the inputs of a SkipLayerNormalization operator.
 *
 * Expected layout (by index):
 *   0: input — 2D or 3D tensor whose last dim is the hidden size
 *   1: skip  — 2D or 3D tensor whose last two dims match input's
 *   2: gamma — 1D scale tensor of length hiddenSize
 *   3: beta  — optional 1D shift tensor of length hiddenSize
 *   4: bias  — optional 1D bias tensor of length hiddenSize
 * All inputs must share the same data type.
 *
 * @throws Error when any of the above constraints is violated.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length < 3) {
    // This op is SkipLayerNorm, not plain LayerNorm (which only needs 2 inputs) —
    // the previous message said 'layerNorm', which was misleading.
    throw new Error('SkipLayerNorm requires at least 3 inputs.');
  }

  const input: TensorView = inputs[0];
  const skip: TensorView = inputs[1];
  const gamma: TensorView = inputs[2];

  if (input.dataType !== skip.dataType || input.dataType !== gamma.dataType) {
    throw new Error('All inputs must have the same data type');
  }

  if (input.dims.length !== 3 && input.dims.length !== 2) {
    throw new Error('Input must be 2D or 3D');
  }

  if (skip.dims.length !== 3 && skip.dims.length !== 2) {
    throw new Error('Skip must be 2D or 3D');
  }

  // Last dim is the hidden size; the dim before it is the sequence length.
  const hiddenSize = input.dims[input.dims.length - 1];
  const sequenceLength = input.dims[input.dims.length - 2];
  if (skip.dims[skip.dims.length - 1] !== hiddenSize) {
    throw new Error('Skip must have the same hidden size as input');
  }
  if (skip.dims[skip.dims.length - 2] !== sequenceLength) {
    throw new Error('Skip must have the same sequence length as input');
  }

  if (gamma.dims.length !== 1) {
    throw new Error('Gamma must be 1D');
  }
  if (gamma.dims[gamma.dims.length - 1] !== hiddenSize) {
    throw new Error('Gamma must have the same hidden size as input');
  }
  // Optional beta (input 3): must be 1D of length hiddenSize when present.
  if (inputs.length > 3) {
    const beta: TensorView = inputs[3];
    if (beta.dims.length !== 1) {
      throw new Error('Beta must be 1D');
    }
    if (beta.dims[beta.dims.length - 1] !== hiddenSize) {
      throw new Error('Beta must have the same hidden size as input');
    }
  }
  // Optional bias (input 4): must be 1D of length hiddenSize when present.
  if (inputs.length > 4) {
    const bias: TensorView = inputs[4];
    if (bias.dims.length !== 1) {
      throw new Error('Bias must be 1D');
    }
    if (bias.dims[bias.dims.length - 1] !== hiddenSize) {
      throw new Error('Bias must have the same hidden size as input');
    }
  }
};
const createSkipLayerNormProgramInfo = (
82
  inputs: readonly TensorView[],
83
  attributes: SkipLayerNormAttributes,
84
  outputCount: number,
85
  isTraining: boolean,
86
): ProgramInfo => {
87
  const simplified = attributes.simplified;
88

89
  const inputShape = inputs[0].dims;
90
  const inputSize = ShapeUtil.size(inputShape);
91
  const outputShape = inputShape;
92
  const outputSize = inputSize;
93
  const hiddenSize = inputShape.slice(-1)[0];
94
  const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : [];
95
  const hasBetaInput = !simplified && inputs.length > 3;
96
  const hasBiasInput = inputs.length > 4;
97
  const hasMeanOutput = isTraining && outputCount > 1;
98
  const hasInvStdDevOutput = isTraining && outputCount > 2;
99
  const hasInputSkipBiasSumOutput = outputCount > 3;
100
  const workgroupSize = 64;
101

102
  const components = getMaxComponents(hiddenSize);
103

104
  const programUniforms: ProgramUniform[] = [
105
    { type: DataType.uint32, data: outputSize },
106
    { type: DataType.uint32, data: components },
107
    { type: DataType.uint32, data: hiddenSize },
108
    { type: DataType.float, data: attributes.epsilon },
109
  ];
110
  const getShaderSource = (shaderHelper: ShaderHelper) => {
111
    const uniformsArray: UniformsArrayType = [
112
      { name: 'output_size', type: 'u32' },
113
      { name: 'components', type: 'u32' },
114
      { name: 'hidden_size', type: 'u32' },
115
      { name: 'epsilon', type: 'f32' },
116
    ];
117
    const variables = [
118
      inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
119
      inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
120
      inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
121
    ];
122
    if (hasBetaInput) {
123
      variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
124
    }
125
    if (hasBiasInput) {
126
      variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
127
    }
128
    variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
129
    if (hasMeanOutput) {
130
      variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim));
131
    }
132
    if (hasInvStdDevOutput) {
133
      variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim));
134
    }
135
    if (hasInputSkipBiasSumOutput) {
136
      variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components));
137
    }
138
    const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
139
    const vecDataType = tensorTypeToWsglStorageType(DataType.float, components);
140
    return `
141

142
      ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)}
143
      var<workgroup> sum_shared : array<${vecDataType}, ${workgroupSize}>;
144
      var<workgroup> sum_squared_shared : array<${vecDataType}, ${workgroupSize}>;
145

146
      ${shaderHelper.mainStart([workgroupSize, 1, 1])}
147
        let ix = local_id.x;
148
        let iy = global_id.x / ${workgroupSize};
149

150
        let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
151
        var stride = hidden_size_vectorized / ${workgroupSize};
152
        let offset = ix * stride + iy * hidden_size_vectorized;
153
        let offset1d = stride * ix;
154
        if (ix == ${workgroupSize - 1}) {
155
          stride = hidden_size_vectorized - stride * ix;
156
        }
157
        for (var i: u32 = 0; i < stride; i++) {
158
          let skip_value = skip[offset + i];
159
          let bias_value = ${hasBiasInput ? 'bias[offset1d + i]' : dataType + '(0.0)'};
160
          let input_value = x[offset + i];
161
          let value = input_value + skip_value + bias_value;
162
          ${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''}
163
          output[offset + i] = value;
164
          let f32_value = ${castToF32(dataType, components, 'value')};
165
          sum_shared[ix] += f32_value;
166
          sum_squared_shared[ix] += f32_value * f32_value;
167
        }
168
        workgroupBarrier();
169

170
        var reduce_size : u32 = ${workgroupSize};
171
        for (var curr_size = reduce_size >> 1;  curr_size > 0; curr_size = reduce_size >> 1) {
172
          reduce_size = curr_size + (reduce_size & 1);
173
          if (ix < curr_size) {
174
            sum_shared[ix] += sum_shared[ix + reduce_size];
175
            sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];
176
          }
177
          workgroupBarrier();
178
        }
179

180
        let sum = sum_shared[0];
181
        let square_sum = sum_squared_shared[0];
182
        let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size);
183
        let inv_std_dev = inverseSqrt(${sumVector('square_sum', components)} / f32(uniforms.hidden_size) ${
184
          simplified ? '' : '- mean * mean'
185
        } + uniforms.epsilon);
186
        ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''}
187
        ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''}
188

189
        for (var i: u32 = 0; i < stride; i++) {
190
          output[offset + i] = (output[offset + i] ${simplified ? '' : `- ${dataType}(mean)`}) *
191
            ${dataType}(inv_std_dev) * gamma[offset1d + i]
192
            ${hasBetaInput ? '+ beta[offset1d + i]' : ''};
193
        }
194
      }`;
195
  };
196
  const outputs = [{ dims: outputShape, dataType: inputs[0].dataType }];
197
  if (outputCount > 1) {
198
    outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float });
199
  }
200
  if (outputCount > 2) {
201
    outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float });
202
  }
203
  if (outputCount > 3) {
204
    outputs.push({ dims: inputShape, dataType: inputs[0].dataType });
205
  }
206
  return {
207
    name: 'SkipLayerNormalization',
208
    shaderCache: {
209
      hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`,
210
      inputDependencies: inputs.map((_input, _index) => 'type'),
211
    },
212
    getShaderSource,
213
    getRunData: () => ({
214
      outputs,
215
      dispatchGroup: {
216
        x: Math.ceil(outputSize / hiddenSize),
217
      },
218
      programUniforms,
219
    }),
220
  };
221
};
export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNormAttributes): void => {
224
  // TODO: initialize isTraining from ComputeContext
225
  const isTraining = false;
226
  validateInputs(context.inputs);
227
  // Mean and InvStdDev are only used in training mode and are not required for inference.
228
  // They are added here for completeness only.
229
  const outputs = [0];
230
  if (context.outputCount > 1) {
231
    outputs.push(isTraining ? 1 : -3);
232
  }
233
  if (context.outputCount > 2) {
234
    outputs.push(isTraining ? 2 : -3);
235
  }
236
  if (context.outputCount > 3) {
237
    outputs.push(3);
238
  }
239
  context.compute(createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {
240
    outputs,
241
  });
242
};