onnxruntime

Форк
0
/
rotary-embedding.ts 
184 строки · 7.9 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
9

10
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, WORKGROUP_SIZE } from './common';
11

12
/**
 * Attributes of the RotaryEmbedding operator as consumed by this kernel.
 */
export interface RotaryEmbeddingAttributes {
  /**
   * Selects how the two elements of a rotated pair are addressed in the shader:
   * adjacent elements when true, elements `half_rotary_emb_dim` apart when false.
   */
  readonly interleaved: boolean;
  /**
   * Number of heads. Must be non-zero when `rotaryEmbeddingDim` is greater than 0, because the
   * head size is then derived as `hidden_size / numHeads`.
   */
  readonly numHeads: number;
  /**
   * Rotary embedding dimension. When 0, the head size is derived from the cache width
   * (`cos_cache.dims[1] * 2`) instead of from `numHeads`.
   */
  readonly rotaryEmbeddingDim: number;
  /** Passed to the shader as the f32 `scale` uniform. */
  readonly scale: number;
}
18

19
/**
 * Checks that the operator inputs and attributes are mutually consistent before a program is built.
 *
 * @param inputs - `[x, position_ids, cos_cache, sin_cache]`, in that order.
 * @param attributes - Operator attributes; only `numHeads` and `rotaryEmbeddingDim` are read here.
 * @throws Error when fewer than four inputs are supplied, when a tensor has an unsupported rank,
 *   when the caches differ in shape or do not match the head size / rotary dimension, when
 *   `position_ids` does not match `[batch_size, sequence_length]`, or when the sequence is longer
 *   than the caches.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes): void => {
  // Without this guard the destructuring below yields undefined tensors and the first `.dims`
  // access fails with an opaque TypeError instead of a descriptive message.
  if (inputs.length < 4) {
    throw new Error(`RotaryEmbedding expects 4 inputs (x, position_ids, cos_cache, sin_cache), got ${inputs.length}`);
  }

  const [input, positionIds, cosCache, sinCache] = inputs;
  const { numHeads, rotaryEmbeddingDim } = attributes;

  if (input.dims.length !== 3 && input.dims.length !== 4) {
    throw new Error(`Input 'x' is expected to have 3 or 4 dimensions, got ${input.dims.length}`);
  }
  // position_ids may be a scalar, a single-element vector, or any rank-2 tensor (shape checked below).
  if (
    !ShapeUtil.areEqual(positionIds.dims, []) &&
    !ShapeUtil.areEqual(positionIds.dims, [1]) &&
    positionIds.dims.length !== 2
  ) {
    throw new Error(`Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ${positionIds.dims.length}`);
  }
  if (cosCache.dims.length !== 2) {
    throw new Error(`Input 'cos_cache' is expected to have 2 dimensions, got ${cosCache.dims.length}`);
  }
  if (sinCache.dims.length !== 2) {
    throw new Error(`Input 'sin_cache' is expected to have 2 dimensions, got ${sinCache.dims.length}`);
  }
  if (!ShapeUtil.areEqual(cosCache.dims, sinCache.dims)) {
    throw new Error("Inputs 'cos_cache' and 'sin_cache' are expected to have the same shape");
  }

  // The head size is derived from num_heads whenever rotary_embedding_dim is given explicitly.
  if (rotaryEmbeddingDim > 0 && numHeads === 0) {
    throw new Error('num_heads must be provided if rotary_embedding_dim is specified');
  }

  const batchSize = input.dims[0];
  // The sequence axis is the second-to-last one for both rank-3 and rank-4 inputs.
  const sequenceLength = input.dims[input.dims.length - 2];
  const maxSequenceLength = cosCache.dims[0];
  const hiddenSize = ShapeUtil.sizeFromDimension(input.dims, 1) / sequenceLength;
  const headSize = rotaryEmbeddingDim === 0 ? cosCache.dims[1] * 2 : hiddenSize / numHeads;
  if (rotaryEmbeddingDim > headSize) {
    throw new Error('rotary_embedding_dim must be less than or equal to head_size');
  }

  if (positionIds.dims.length === 2) {
    if (batchSize !== positionIds.dims[0]) {
      throw new Error(`Input 'position_ids' dimension 0 should be of size batch_size, got ${positionIds.dims[0]}`);
    }
    if (sequenceLength !== positionIds.dims[1]) {
      throw new Error(`Input 'position_ids' dimension 1 should be of size sequence_length, got ${positionIds.dims[1]}`);
    }
  }

  // The cache width must equal half of either the head size or the rotary embedding dimension.
  if (headSize / 2 !== cosCache.dims[1] && rotaryEmbeddingDim / 2 !== cosCache.dims[1]) {
    throw new Error(
      `Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${
        cosCache.dims[1]
      }`,
    );
  }

  if (sequenceLength > maxSequenceLength) {
    throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported');
  }
};
77

78
/**
 * Builds the program description (uniforms, shader source generator, dispatch size and output
 * shape) for the RotaryEmbedding kernel.
 *
 * @param inputs - `[x, position_ids, cos_cache, sin_cache]`; expected to have been validated by the caller.
 * @param attributes - Operator attributes.
 * @returns The program info whose single output has the same dims and data type as `inputs[0]`.
 */
const createRotaryEmbeddingProgramInfo = (
  inputs: readonly TensorView[],
  attributes: RotaryEmbeddingAttributes,
): ProgramInfo => {
  const { interleaved, numHeads, rotaryEmbeddingDim, scale } = attributes;
  const inputDims = inputs[0].dims;
  const batchSize = inputDims[0];
  const batchStride = ShapeUtil.sizeFromDimension(inputDims, 1);
  // The sequence axis is the second-to-last one for both rank-3 and rank-4 inputs.
  const sequenceLength = inputDims[inputDims.length - 2];
  const hiddenSize = batchStride / sequenceLength;
  const halfRotaryEmbeddingDim = inputs[2].dims[1];
  const headSize = rotaryEmbeddingDim === 0 ? halfRotaryEmbeddingDim * 2 : hiddenSize / numHeads;

  // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape
  // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy]
  // to unfold the global index in shader.
  const globalShape: number[] = [
    batchSize,
    sequenceLength,
    hiddenSize / headSize,
    headSize - halfRotaryEmbeddingDim,
  ];
  const globalStrides = ShapeUtil.computeStrides(globalShape);

  // Strides for addressing the input/output tensor, in permuted order to align with the unfolded
  // global index, i.e. BSNH. No entry is produced for ranks other than 3 and 4.
  const inputOutputStrides: ProgramUniform[] = [];
  if (inputDims.length === 3) {
    inputOutputStrides.push({ type: DataType.uint32, data: [batchStride, hiddenSize, headSize, 1] });
  } else if (inputDims.length === 4) {
    // For rank-4 input the head axis has the larger stride, i.e. heads precede the sequence axis in memory.
    inputOutputStrides.push({
      type: DataType.uint32,
      data: [batchStride, headSize, sequenceLength * headSize, 1],
    });
  }

  // The order here must match the order of `registerUniforms` in the shader source below.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.float, data: scale },
    { type: DataType.uint32, data: globalShape },
    { type: DataType.uint32, data: globalStrides },
    ...inputOutputStrides,
    ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, inputs[2].dims, inputs[3].dims, inputs[0].dims),
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length);
    const positionIds = inputVariable('position_ids', inputs[1].dataType, inputs[1].dims.length);
    const cosCache = inputVariable('cos_cache', inputs[2].dataType, inputs[2].dims.length);
    const sinCache = inputVariable('sin_cache', inputs[3].dataType, inputs[3].dims.length);
    const output = outputVariable('output', inputs[0].dataType, inputs[0].dims.length);

    shaderHelper.registerUniforms([
      { name: 'scale', type: 'f32' },
      { name: 'global_shape', type: 'u32', length: globalShape.length },
      { name: 'global_strides', type: 'u32', length: globalStrides.length },
      { name: 'input_output_strides', type: 'u32', length: globalStrides.length },
    ]);

    // Each invocation handles one element of the unfolded [B, S, N, H'] index space: indices below
    // half_rotary_emb_dim rotate one (i, j) pair, the remaining indices copy one element unchanged.
    return `
        ${shaderHelper.declareVariables(input, positionIds, cosCache, sinCache, output)}

        ${shaderHelper.mainStart(WORKGROUP_SIZE)}
          let half_rotary_emb_dim = uniforms.${cosCache.name}_shape[1];
          let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape;
          let size = uniforms.global_shape[0] * uniforms.global_strides[0];
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('size')}

          if (bsnh[3] < half_rotary_emb_dim) {
            let position_ids_idx =
                ${positionIds.broadcastedIndicesToOffset('bsnh.xy', outputVariable('', positionIds.type.tensor, 2))};
            let position_id =
                u32(${positionIds.getByOffset('position_ids_idx')}) + select(0, bsnh[1], position_ids_idx == 0);
            let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], ${interleaved});
            let j = i + select(half_rotary_emb_dim, 1, ${interleaved});
            let re = ${input.getByOffset('i')} * ${cosCache.get('position_id', 'bsnh[3]')} -
                ${input.getByOffset('j')} * ${sinCache.get('position_id', 'bsnh[3]')};
            ${output.setByOffset('i', 're')}
            let im = ${input.getByOffset('i')} * ${sinCache.get('position_id', 'bsnh[3]')} +
                ${input.getByOffset('j')} * ${cosCache.get('position_id', 'bsnh[3]')};
            ${output.setByOffset('j', 'im')}
          } else {
            let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim;
            ${output.setByOffset('k', input.getByOffset('k'))}
          }
        }`;
  };

  return {
    name: 'RotaryEmbedding',
    shaderCache: {
      // `interleaved` is baked into the shader text, so it must be part of the cache key; the
      // other values are supplied as uniforms.
      hint: createAttributeWithCacheKey({
        interleaved,
      }).cacheKey,
      inputDependencies: ['rank', 'rank', 'rank', 'rank'],
    },
    getShaderSource,
    getRunData: () => ({
      outputs: [{ dims: inputs[0].dims, dataType: inputs[0].dataType }],
      dispatchGroup: { x: Math.ceil(ShapeUtil.size(globalShape) / WORKGROUP_SIZE) },
      programUniforms,
    }),
  };
};
180

181
/**
 * Operator entry point: validates the context's inputs against the attributes, then builds the
 * RotaryEmbedding program and submits it through `context.compute`.
 *
 * @param context - Compute context supplying the input tensors and the `compute` method.
 * @param attributes - Operator attributes.
 * @throws Error propagated from input validation when the inputs are inconsistent.
 */
export const rotaryEmbedding = (context: ComputeContext, attributes: RotaryEmbeddingAttributes): void => {
  const operatorInputs = context.inputs;
  validateInputs(operatorInputs, attributes);

  const programInfo = createRotaryEmbeddingProgramInfo(operatorInputs, attributes);
  context.compute(programInfo);
};
185

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.