onnxruntime

Форк
0
/
group-query-attention.ts 
382 строки · 14.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
9

10
import {
11
  applyAttention,
12
  AttentionAttrs,
13
  AttentionMaskType,
14
  AttentionParameters,
15
  AttentionQkvFormat,
16
} from './attention';
17
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
18
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
19
import { createTileProgramInfo } from './tile';
20
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
21

22
/**
 * Validates the inputs of GroupQueryAttention and derives the attention parameters from their shapes.
 *
 * Inputs are read positionally: 0 = query, 1 = key, 2 = value, 3 = past_key, 4 = past_value.
 *
 * @param inputs the operator's input tensors.
 * @param attributes parsed operator attributes (`numHeads`, `kvNumHeads`, `scale`, ...).
 * @returns the shape-derived parameters consumed by the attention programs.
 * @throws Error when an input has an unexpected rank or shape, when past_key / past_value are not both present or
 *   both absent, or when `kvNumHeads` is missing, non-positive, or does not evenly divide `numHeads`.
 */
export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
  const query = inputs[0];
  const key = inputs[1];
  const value = inputs[2];
  const pastKey = inputs[3];
  const pastValue = inputs[4];

  // Abbreviation and Meanings:
  //   B:    batch_size
  //   S:    sequence_length (input sequence length of query)
  //   P:    past_sequence_length (past sequence length of key or value)
  //   L:    kv_sequence_length (input sequence length of key or value)
  //   M:    max_sequence_length
  //   T:    total_sequence_length = past_sequence_length + kv_sequence_length
  //   N:    num_heads
  //   H:    head size for Q and K, aka q_head_size or k_head_size or qk_head_size
  //   H_v:  v_head_size
  //   D_i:  input hidden size
  //   D:    hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
  //   D_v:  v_hidden_size = num_heads * v_head_size

  //     past_key                   : (B, N, S*, H)
  //     past_value                 : (B, N, S*, H)
  // When no packing for q/k/v:
  //     query            (Q)       : (B, S, D)
  //     key              (K)       : (B, L, D) or (B, N, S*, H)
  //     value            (V)       : (B, L, D_v) or (B, N, S*, H)
  // When packed kv is used:
  //     query            (Q)       : (B, S, D)
  //     key              (K)       : (B, L, N, 2, H)
  //     value            (V)       : None
  // When packed qkv is used:
  //     query            (Q)       : (B, L, N, 3, H) or (B, S, 3*D)
  //     key              (K)       : None
  //     value            (V)       : None

  if (query.dims.length !== 3 && query.dims.length !== 5) {
    throw new Error('Input query is expected to have 3 or 5 dimensions');
  }

  // nReps = num_heads / kv_num_heads (see the returned parameters) is used as a tile repeat count and vHeadSize
  // divides by kv_num_heads, so kv_num_heads must be a positive divisor of num_heads. Without this check a
  // missing or mismatched attribute silently produces a NaN / fractional nReps.
  const kvNumHeads = attributes.kvNumHeads;
  if (kvNumHeads === undefined || kvNumHeads <= 0 || attributes.numHeads % kvNumHeads !== 0) {
    throw new Error('Attribute "kv_num_heads" must be a positive divisor of "num_heads"');
  }

  // Packed-QKV variants that store Q/K/V in the last dim of a 3-D query are not supported here.
  const dmmhaPacking = false;
  const batchSize = query.dims[0];
  const sequenceLength = query.dims[1];
  const hiddenSize =
    query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4];
  let kvSequenceLength = sequenceLength;

  let pastSequenceLength = 0;
  let maxSequenceLength = 0;
  const headSize = Math.floor(hiddenSize / attributes.numHeads);
  const hasPastKey = pastKey && pastKey.dims.length !== 0;
  const hasPastValue = pastValue && pastValue.dims.length !== 0;
  // TODO : this should be from attributes.
  const isPastkvBSNH = true;
  if (hasPastKey && hasPastValue) {
    if (pastKey.dims.length !== 4) {
      throw new Error('Input "past_key" is expected to have 4 dimensions');
    }
    if (pastValue.dims.length !== 4) {
      throw new Error('Input "past_value" is expected to have 4 dimensions');
    }
    if (isPastkvBSNH) {
      // For BSNH
      pastSequenceLength = pastKey.dims[1];
      maxSequenceLength = pastKey.dims[1];
    } else {
      // For BNSH
      pastSequenceLength = pastKey.dims[2];
      maxSequenceLength = pastKey.dims[2];
    }
  } else if (hasPastKey || hasPastValue) {
    throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
  }

  let qkvFormat: AttentionQkvFormat;
  if (key) {
    if (query.dims.length !== 3) {
      throw new Error('Input "query" is expected to have 3 dimensions when key is given');
    }
    if (key.dims.length < 3 || key.dims.length > 5) {
      throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');
    }
    if (query.dims[0] !== key.dims[0]) {
      throw new Error('Input "query" and "key" shall have same dim 0 (batch size)');
    }

    if (key.dims.length === 3) {
      if (query.dims[2] % key.dims[2] !== 0) {
        throw new Error('Dimension 2 of "query" should be a multiple of "key"');
      }
      qkvFormat = AttentionQkvFormat.qkvBSNH;
      kvSequenceLength = key.dims[1];
    } else if (key.dims.length === 5) {
      // NOTE(review): for grouped K/V the packed head count would presumably be kvNumHeads rather than numHeads —
      // confirm before packed KV is implemented.
      if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
        throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');
      }
      if (value) {
        throw new Error('Expect "value" be none when "key" has packed kv format.');
      }
      qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
      kvSequenceLength = key.dims[1];
    } else {
      // key_dims.size() == 4 (cross-attention with past_key)
      if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
        throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
      }

      qkvFormat = AttentionQkvFormat.unknown;
      kvSequenceLength = key.dims[2];
    }
  } else {
    // packed QKV
    if (query.dims.length !== 3 && query.dims.length !== 5) {
      throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty');
    }
    if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) {
      throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv');
    }

    qkvFormat = AttentionQkvFormat.qkvBSN3H;
  }

  const maskType: AttentionMaskType = AttentionMaskType.none;
  let passPastInKv = false;
  let vHiddenSize = hiddenSize;
  if (value) {
    if (value.dims.length !== 3 && value.dims.length !== 4) {
      throw new Error('Input "value" is expected to have 3 or 4 dimensions');
    }

    if (query.dims[0] !== value.dims[0]) {
      throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)');
    }

    if (value.dims.length === 3) {
      if (kvSequenceLength !== value.dims[1]) {
        throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');
      }
      vHiddenSize = value.dims[2];
    } else {
      if (kvSequenceLength !== value.dims[2]) {
        throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)');
      }
      vHiddenSize = value.dims[1] * value.dims[3];
      passPastInKv = true;
    }
  }
  const totalSequenceLength = pastSequenceLength + kvSequenceLength;
  const broadcastResPosBias = false;

  return {
    batchSize,
    sequenceLength,
    pastSequenceLength,
    kvSequenceLength,
    totalSequenceLength,
    maxSequenceLength,
    inputHiddenSize: 0,
    hiddenSize,
    vHiddenSize,
    headSize,
    vHeadSize: Math.floor(vHiddenSize / kvNumHeads),
    numHeads: attributes.numHeads,
    kvNumHeads,
    nReps: attributes.numHeads / kvNumHeads,
    pastPresentShareBuffer: false,
    maskType,
    scale: attributes.scale,
    broadcastResPosBias,
    passPastInKv,
    qkvFormat,
    isPastkvBSNH,
  };
};
196

197
/**
 * Builds the program that writes the "present" key or value tensor of shape
 * [batchSize, totalSequenceLength, kvNumHeads, headSize].
 *
 * Rows `s < pastSequenceLength` are copied from `past_kv` (only when `b` is given). The following
 * `kvSequenceLength` rows are copied from `new_kv`. One workgroup is dispatched per (present sequence position,
 * batch) pair, and each invocation in it copies one vec4 of one head.
 *
 * @param a the new key/value tensor, bound as `new_kv`.
 * @param b the past key/value tensor, bound as `past_kv`, or undefined when there is no past.
 * @param dataType element type of the produced tensor.
 * @param params attention parameters produced by `validateInputs`.
 * @returns the program description for `context.compute`.
 */
const createConcatProgramInfo = (
  a: TensorView,
  b: TensorView | undefined,
  dataType: DataType,
  params: AttentionParameters,
): ProgramInfo => {
  const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize];
  // NOTE(review): the vec4 width is fixed, which presumably assumes headSize % 4 === 0 — confirm.
  const component = 4;
  const outputSize = ShapeUtil.size(outputShape) / component;
  const presentSequenceLength = params.totalSequenceLength;
  const output = outputVariable('present_kv', dataType, outputShape.length, component);
  const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component);
  const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined;

  // Number of vec4 elements per head; also the workgroup's x size (h = local_id.x in the shader).
  const H = Math.ceil(params.headSize / component);
  // s = workgroup_id.x, b = workgroup_id.y in the shader.
  const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 };

  const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank'];

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: params.pastSequenceLength },
    { type: DataType.uint32, data: params.kvSequenceLength },
    { type: DataType.uint32, data: params.totalSequenceLength },
  ];

  const inputs = [inputA];
  if (inputB) {
    programUniforms.push(
      ...createTensorShapeVariables(a.dims),
      ...createTensorShapeVariables(b!.dims),
      ...createTensorShapeVariables(outputShape),
    );
    inputs.push(inputB);
  } else {
    programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape));
  }
  const uniforms: UniformsArrayType = [
    { name: 'output_size', type: 'u32' },
    { name: 'past_seqlen', type: 'u32' },
    { name: 'new_seqlen', type: 'u32' },
    { name: 'present_seqlen', type: 'u32' },
  ];

  const pastStr = `      let past_batch_stride = uniforms.past_seqlen * num_heads * H;
        var past_head_stride = uniforms.past_seqlen * H;
        if (is_bsnh) {
          past_head_stride = H;
        }
        let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
        present_kv[out_offset] = past_kv[in_offset];`;
  const newStr = `      let new_batch_stride = uniforms.new_seqlen * num_heads * H;
        let new_row_stride = num_heads * H;
        let new_head_stride = H;
        let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
        present_kv[out_offset] = new_kv[in_offset];`;
  const concatStr = b
    ? `if (s < past_seqlen) {
        ${pastStr}
        } else if (s < past_seqlen + uniforms.new_seqlen) {
        ${newStr}
        }`
    : `if (s < past_seqlen + uniforms.new_seqlen) {
          ${newStr}
        }`;

  // TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit.
  const getShaderSource = (shaderHelper: ShaderHelper) => `

  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)}
  ${shaderHelper.mainStart([H, params.kvNumHeads!, 1])}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    var indices = ${output.offsetToIndices('global_idx')};
    let h = local_id.x;
    let n = local_id.y;
    let s = workgroup_id.x;
    let b = workgroup_id.y;
    let num_heads = ${params.kvNumHeads!}u;
    let H = ${H}u;

    let present_seqlen = uniforms.present_seqlen;
    let present_batch_stride = present_seqlen * num_heads * H;
    var row_stride = H;
    let is_bsnh = ${params.isPastkvBSNH};

    if (is_bsnh) {
      row_stride = num_heads * H;
    }
    var present_head_stride = present_seqlen * H;
    if (is_bsnh) {
      present_head_stride = H;
    }

    let past_seqlen = uniforms.past_seqlen;

    let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
    ${concatStr}
  }`;

  return {
    name: 'ConcatPastNew',
    shaderCache: {
      // The shader source bakes in kvNumHeads, H, whether past_kv is bound and isPastkvBSNH, so every one of them
      // must be in the cache key. Separators keep e.g. (kvNumHeads=1, H=12) and (kvNumHeads=11, H=2) distinct;
      // without them both produced the same key.
      hint: `${params.kvNumHeads!};${H};${!!b};${params.isPastkvBSNH}`,
      inputDependencies,
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      dispatchGroup: dispatch,
      programUniforms,
    }),
    getShaderSource,
  };
};
307

308
/**
 * Wraps a shallow copy of the GroupQueryAttention attributes with a cache key.
 *
 * @param attributes the raw operator attributes.
 * @returns the copied attributes, as produced by `createAttributeWithCacheKey`.
 */
export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => {
  const attributesCopy = { ...attributes };
  return createAttributeWithCacheKey(attributesCopy);
};
310

311
// Permutation [0, 2, 1, 3] swaps axes 1 and 2. maybeExpandAndTransposeToBNSH applies it to its
// [batch, sequence, heads, headSize] result to obtain [batch, heads, sequence, headSize].
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({
  perm: [0, 2, 1, 3],
});
312

313
/**
 * Prepares a key or value input for the attention computation.
 *
 * Steps:
 *  1. A rank-3 input (with a non-zero kvSequenceLength) is reshaped to
 *     [batchSize, kvSequenceLength, kvNumHeads, headSize].
 *  2. It is concatenated after `pastKV` (when given) by the ConcatPastNew program. The concatenated tensor is
 *     written to output `outputIndex` when `params.isPastkvBSNH` is set, and kept temporary otherwise.
 *  3. When `params.nReps !== 1`, the last axis is tiled `nReps` times and reshaped so that every K/V head
 *     appears `nReps` times in a row (kvNumHeads * nReps heads in total).
 *  4. Axes 1 and 2 are swapped, giving [batchSize, heads, totalSequenceLength, headSize].
 *
 * @param context the compute context used to run the programs.
 * @param input the new key or value tensor.
 * @param pastKV the matching past key/value tensor, or undefined when there is no past.
 * @param params attention parameters produced by `validateInputs`.
 * @param outputIndex operator output that receives the present key/value.
 * @returns the transposed, head-expanded key/value tensor.
 */
const maybeExpandAndTransposeToBNSH = (
  context: ComputeContext,
  input: TensorView,
  pastKV: TensorView | undefined,
  params: AttentionParameters,
  outputIndex: number,
) => {
  const kvHeads = params.kvNumHeads!;
  const repeats = params.nReps!;

  const needsReshape = input.dims.length === 3 && params.kvSequenceLength !== 0;
  const newKV = needsReshape
    ? input.reshape([params.batchSize, params.kvSequenceLength, kvHeads, params.headSize])
    : input;

  // past_kv is optional; the concat program only binds it when it is present.
  const concatInputs = pastKV ? [newKV, pastKV] : [newKV];
  let present = context.compute(createConcatProgramInfo(newKV, pastKV, newKV.dataType, params), {
    inputs: concatInputs,
    outputs: [params.isPastkvBSNH ? outputIndex : -1],
  })[0];

  if (repeats !== 1) {
    const tiled = context.compute(createTileProgramInfo([present], [1, 1, 1, repeats]), {
      inputs: [present],
      outputs: [-1],
    })[0];
    present = tiled.reshape([params.batchSize, params.totalSequenceLength, kvHeads * repeats, params.headSize]);
  }

  return context.compute(createTransposeProgramInfo(present, weightTransposeAttribute.perm), {
    inputs: [present],
    outputs: [-1],
  })[0];
};
356

357
/**
 * Entry point of the GroupQueryAttention WebGPU kernel.
 *
 * Steps:
 *  1. Validates the inputs.
 *  2. Passes the query through `maybeTransposeToBNSHAndAddBias` with no bias.
 *  3. Builds the expanded K and V from the new and past key/value tensors. The present key/value go to outputs
 *     1 and 2 when `params.isPastkvBSNH` is set.
 *  4. Hands everything to `applyAttention`.
 *
 * @param context the compute context holding the operator inputs.
 * @param attributes parsed operator attributes.
 * @throws Error for packed QKV (rank-5 query, or missing key/value) and packed KV (rank-5 key), which are not
 *   implemented, and for anything `validateInputs` rejects.
 */
export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
  const params = validateInputs(context.inputs, attributes);
  if (context.inputs[0].dims.length === 5) {
    throw new Error('Packed QKV is not implemented');
  }

  if (context.inputs[1]?.dims.length === 5) {
    throw new Error('Packed KV is not implemented');
  }

  const key = context.inputs[1];
  const value = context.inputs[2];
  // validateInputs accepts a 3-D query without key/value (packed QKV), and a key without value. Neither is
  // supported below, so fail with a clear message instead of dereferencing an undefined input.
  if (!key || !value) {
    throw new Error('Packed QKV is not implemented: "key" and "value" inputs are required');
  }

  const Q = maybeTransposeToBNSHAndAddBias(
    context,
    params.batchSize,
    params.numHeads,
    params.sequenceLength,
    params.headSize,
    context.inputs[0],
    undefined,
    0,
  );
  // past_key / past_value are optional; an absent tensor or one with empty dims means "no past".
  const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
  const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
  const K = maybeExpandAndTransposeToBNSH(context, key, pastKey, params, 1);
  const V = maybeExpandAndTransposeToBNSH(context, value, pastValue, params, 2);
  applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes);
};
383

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.