// onnxruntime: multihead-attention.ts
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, GpuDataType, ProgramUniform } from '../types';

import {
  applyAttention,
  AttentionAttrs,
  AttentionMaskType,
  AttentionParameters,
  AttentionQkvFormat,
} from './attention';
import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';

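// Returns the i-th input when it is present and non-empty (an input whose
// dims has zero length is treated as absent); otherwise returns undefined.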
const getInput = (inputs: readonly TensorView[], i: number) =>
  inputs.length > i && inputs[i].dims.length > 0 ? inputs[i] : undefined;

const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
  const query = inputs[0];
  const key = getInput(inputs, 1);
  const value = getInput(inputs, 2);
  const bias = getInput(inputs, 3);
  const keyPaddingMask = getInput(inputs, 4);
  const attentionBias = getInput(inputs, 5);
  const pastKey = getInput(inputs, 6);
  const pastValue = getInput(inputs, 7);

  // ---------------------------------------------------------------
  // Notations:
  //    B: batch_size
  //    N: num_heads
  //    H: head_size of Q and K
  //    H_v: head_size of V
  //    D: hidden_size of Q and K, where D = N * H
  //    D_v: hidden_size of V, where D_v = N * H_v
  //    S: q_sequence_length
  //    P: past_sequence_length of kv cache
  //    L: kv_sequence_length
  //    T: total_sequence_length = P + L
  //    M: max_sequence_length of kv cache when past and present share buffer
  // ---------------------------------------------------------------
  // MultiHeadAttention inputs:
  // ---------------------------------------------------------------
  //  Q_K_V_BSNH - no packing:
  //     query            (Q)       : (B, S, D)
  //     key              (K)       : (B, L, D)
  //     value            (V)       : (B, L, D_v)
  //  Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v):
  //     query            (Q)       : (B, S, D)
  //     key              (K)       : (B, N, L, H)
  //     value            (V)       : (B, N, L, H_v)
  //  Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv):
  //     query            (Q)       : (B, S, D)
  //     key              (K/V)     : (B, L, N, 2, H)
  //     value                      : None
  //  QKV_BSN3H - packed qkv (kv cache is not used, S == L, D == D_v):
  //     query            (Q/K/V)   : (B, S, N, 3, H)
  //     key                        : None
  //     value                      : None
  //
  //  Other inputs:
  //     bias             (Q/K/V)   : None or (D + D + D_v)
  //     key_padding_mask (K/V)     : (B) or (3 * B + 2) or (B, T) or (B, S, T)
  //     attention_bias             : None or (B, N, S, T), (1, N, S, T), (B, 1, S, T) or (1, 1, S, T)
  //     past_key                   : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH.
  //     past_value                 : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH.
  //
  //  Not Supported:
  //     key_padding_mask, packed kv, packed qkv, and broadcast for attention_bias.
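  //
  //  For example (hypothetical sizes): with B = 2, S = 8, N = 4 and H = H_v = 16,
  //  D = D_v = N * H = 64, so the unpacked Q_K_V_BSNH layout is
  //  query (2, 8, 64), key (2, L, 64) and value (2, L, 64).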

  if (query.dims.length !== 3 && query.dims.length !== 5) {
    throw new Error('Input "query" is expected to have 3 or 5 dimensions');
  }

  const batchSize = query.dims[0];
  const sequenceLength = query.dims[1];
  const hiddenSize = query.dims.length === 3 ? query.dims[2] : attributes.numHeads * query.dims[4];
  let kvSequenceLength = sequenceLength;

  let pastSequenceLength = 0;
  let maxSequenceLength = 0;
  const headSize = Math.floor(hiddenSize / attributes.numHeads);
  if (pastKey && pastValue && ShapeUtil.size(pastKey.dims) && ShapeUtil.size(pastValue.dims)) {
    if (pastKey.dims.length !== 4) {
      throw new Error('Input "past_key" is expected to have 4 dimensions');
    }
    if (pastValue.dims.length !== 4) {
      throw new Error('Input "past_value" is expected to have 4 dimensions');
    }
    if (pastKey.dims[0] !== batchSize || pastKey.dims[1] !== attributes.numHeads || pastKey.dims[3] !== headSize) {
      throw new Error('Input "past_key" is expected to have shape (batch_size, num_heads, past_sequence_length, head_size)');
    }
    if (
      pastValue.dims[0] !== batchSize ||
      pastValue.dims[1] !== attributes.numHeads ||
      pastValue.dims[3] !== headSize
    ) {
      throw new Error('Input "past_value" is expected to have shape (batch_size, num_heads, past_sequence_length, head_size)');
    }
    if (pastKey.dims[2] !== pastValue.dims[2]) {
      throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (past_sequence_length)');
    }
    pastSequenceLength = pastKey.dims[2];
    maxSequenceLength = pastKey.dims[2];
  } else if ((pastKey && ShapeUtil.size(pastKey.dims)) || (pastValue && ShapeUtil.size(pastValue.dims))) {
    throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
  }
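
  // Determine the QKV packing format from the ranks of "query" and "key" (see the layout table above).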
  let qkvFormat: AttentionQkvFormat;
  if (key && ShapeUtil.size(key.dims) > 0) {
    if (query.dims.length !== 3) {
      throw new Error('Input "query" is expected to have 3 dimensions when key is given');
    }
    if (key.dims.length < 3 || key.dims.length > 5) {
      throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');
    }
    if (query.dims[0] !== key.dims[0]) {
      throw new Error('Input "query" and "key" shall have the same dim 0 (batch_size)');
    }

    if (key.dims.length === 3) {
      if (key.dims[2] !== query.dims[2]) {
        throw new Error('Input "query" and "key" shall have the same dim 2 (hidden_size)');
      }
      qkvFormat = AttentionQkvFormat.qkvBSNH;
      kvSequenceLength = key.dims[1];
    } else if (key.dims.length === 5) {
      if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
        throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');
      }
      if (value) {
        throw new Error('Expect "value" to be none when "key" has the packed kv format');
      }
      qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
      kvSequenceLength = key.dims[1];
    } else {
      // key.dims.length === 4 (cross-attention with past_key)
      if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
        throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
      }

      qkvFormat = AttentionQkvFormat.unknown; // Q_K_V_BSNH_BNSH_BNSH
      kvSequenceLength = key.dims[2];
    }
  } else {
    // packed QKV
    if (query.dims.length !== 5) {
      throw new Error('Input "query" is expected to have 5 dimensions when key is empty');
    }
    if (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3) {
      throw new Error('Expect "query" shape (batch_size, sequence_length, num_heads, 3, head_size) for packed qkv');
    }

    qkvFormat = AttentionQkvFormat.qkvBSN3H;
  }

  if (bias && ShapeUtil.size(bias.dims) > 0) {
    if (bias.dims.length !== 1) {
      throw new Error('Input "bias" is expected to have 1 dimension');
    }

    if (key && key.dims.length === 5 && key.dims[3] === 2) {
      throw new Error('bias is not allowed for packed kv.');
    }
  }

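  // total_sequence_length: T = P + L (see the notation above).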
  const totalSequenceLength = pastSequenceLength + kvSequenceLength;

  let maskType: AttentionMaskType = AttentionMaskType.none;
  if (keyPaddingMask && ShapeUtil.size(keyPaddingMask.dims) > 0) {
    maskType = AttentionMaskType.maskUnknown;
    const maskDims = keyPaddingMask.dims;
    if (maskDims.length === 1) {
      if (maskDims[0] === batchSize) {
        maskType = AttentionMaskType.mask1dKeySeqLen;
      } else if (maskDims[0] === 3 * batchSize + 2) {
        maskType = AttentionMaskType.mask1DKeySeqLenStart;
      }
    } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === totalSequenceLength) {
      maskType = AttentionMaskType.mask2dKeyPadding;
    }
    if (maskType === AttentionMaskType.maskUnknown) {
      throw new Error(
        'Input "key_padding_mask" shape shall be (batch_size), (3 * batch_size + 2) or (batch_size, total_sequence_length)',
      );
    }
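    // The mask type is classified above for completeness, but key_padding_mask
    // is not yet supported by this kernel (see "Not Supported" in the header comment).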
    throw new Error('Mask not supported');
  }

  let passPastInKv = false;
  let vHiddenSize = hiddenSize;
  if (value && ShapeUtil.size(value.dims) > 0) {
    if (value.dims.length !== 3 && value.dims.length !== 4) {
      throw new Error('Input "value" is expected to have 3 or 4 dimensions');
    }

    if (query.dims[0] !== value.dims[0]) {
      throw new Error('Input "query" and "value" shall have the same dim 0 (batch_size)');
    }

    if (value.dims.length === 3) {
      if (kvSequenceLength !== value.dims[1]) {
        throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');
      }
      vHiddenSize = value.dims[2];
    } else {
      // Q_K_V_BSNH_BNSH_BNSH
      if (kvSequenceLength !== value.dims[2]) {
        throw new Error('Input "key" and "value" shall have the same dim 2 (kv_sequence_length)');
      }
      vHiddenSize = value.dims[1] * value.dims[3];
      passPastInKv = true;
    }
  }

  const broadcastResPosBias = false;

  if (keyPaddingMask && ShapeUtil.size(keyPaddingMask.dims) > 0) {
    throw new Error('Key padding mask is not supported');
  }

  if (attentionBias && ShapeUtil.size(attentionBias.dims) > 0) {
    if (attentionBias.dims.length !== 4) {
      throw new Error('Input "attention_bias" is expected to have 4 dimensions');
    }

    // TODO: support broadcasting the first and second dimensions of attention_bias.
    if (
      attentionBias.dims[0] !== batchSize ||
      attentionBias.dims[1] !== attributes.numHeads ||
      attentionBias.dims[2] !== sequenceLength ||
      attentionBias.dims[3] !== totalSequenceLength
    ) {
      throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)');
    }
  }

  return {
    batchSize,
    sequenceLength,
    pastSequenceLength,
    kvSequenceLength,
    totalSequenceLength,
    maxSequenceLength,
    inputHiddenSize: 0,
    hiddenSize,
    vHiddenSize,
    headSize,
    vHeadSize: Math.floor(vHiddenSize / attributes.numHeads),
    numHeads: attributes.numHeads,
    isUnidirectional: false,
    pastPresentShareBuffer: false,
    maskFilterValue: attributes.maskFilterValue,
    maskType,
    scale: attributes.scale,
    broadcastResPosBias,
    passPastInKv,
    qkvFormat,
  };
};

export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
  createAttributeWithCacheKey({ ...attributes });

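// Permutation that transposes a (B, S, N, H) tensor to (B, N, S, H).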
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] });

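// Runs a WebGPU program that adds a slice of the combined QKV bias to `qkv`:
// qkv_with_bias[i] = qkv[i] + bias[(i % hidden_size) + bias_offset].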
const addBiasTranspose = (
  context: ComputeContext,
  qkv: TensorView,
  bias: TensorView,
  batchSize: number,
  sequenceLength: number,
  hiddenSize: number,
  biasOffset: number,
) => {
  const outputShape = [batchSize, sequenceLength, hiddenSize];
  const outputSize = ShapeUtil.size(outputShape);
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: biasOffset },
    { type: DataType.uint32, data: hiddenSize },
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape);
    const qkvInput = inputVariable('qkv', qkv.dataType, outputShape);
    const biasInput = inputVariable('bias', bias.dataType, outputShape);

    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'bias_offset', type: 'u32' },
      { name: 'hidden_size', type: 'u32' },
    ];
    return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(qkvInput, biasInput, output)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let bias_offset_idx = (global_idx % uniforms.hidden_size) + uniforms.bias_offset;

    qkv_with_bias[global_idx] = qkv[global_idx] + bias[bias_offset_idx];
  }`;
  };

  return context.compute(
    {
      name: 'MultiHeadAttentionAddBias',
      shaderCache: { inputDependencies: ['type', 'type'] },
      getRunData: () => ({
        outputs: [{ dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default }],
        dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
        programUniforms,
      }),
      getShaderSource,
    },
    { inputs: [qkv, bias], outputs: [-1] },
  )[0];
};

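// Reshapes a (B, S, D) input to (B, S, N, H) and transposes it to (B, N, S, H);
// when a non-empty bias is given, the matching bias slice is added first.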
export const maybeTransposeToBNSHAndAddBias = (
  context: ComputeContext,
  batchSize: number,
  numHeads: number,
  sequenceLength: number,
  headSize: number,
  input: TensorView,
  bias?: TensorView,
  biasOffset?: number,
) => {
  let reshapedInput = input;
  if (!(bias && ShapeUtil.size(bias.dims) > 0)) {
    if (input.dims.length === 3) {
      reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
    }
    return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
      inputs: [reshapedInput],
      outputs: [-1],
    })[0];
  } else {
    if (sequenceLength === 1) {
      throw new Error('AddBiasReshape is not implemented. Please export your model with packed QKV or KV');
    } else {
      reshapedInput = addBiasTranspose(
        context,
        input,
        bias,
        batchSize,
        sequenceLength,
        numHeads * headSize,
        biasOffset!,
      );
      reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
      return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
        inputs: [reshapedInput],
        outputs: [-1],
      })[0];
    }
  }
};

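// Entry point for the MultiHeadAttention operator: validates the inputs,
// brings Q (and, unless they are already BNSH, K and V) into BNSH layout,
// then runs the shared attention kernel.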
export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
  const params = validateInputs(context.inputs, attributes);
  const query = context.inputs[0];
  const key = getInput(context.inputs, 1);
  const value = getInput(context.inputs, 2);
  const bias = getInput(context.inputs, 3);
  const keyPaddingMask = getInput(context.inputs, 4);
  const attentionBias = getInput(context.inputs, 5);
  const pastKey = getInput(context.inputs, 6);
  const pastValue = getInput(context.inputs, 7);

  if (query.dims.length === 5) {
    throw new Error('Packed QKV is not implemented');
  }

  if (key?.dims.length === 5) {
    throw new Error('Packed KV is not implemented');
  }

  // applyAttention expects BNSH inputs
  const kvBNSH = key && value && key.dims.length === 4 && value.dims.length === 4;

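  // The combined bias packs the Q, K and V biases as [D, D, D_v] (see the
  // header comment), hence the offsets 0, hiddenSize and 2 * hiddenSize below.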
  const Q = maybeTransposeToBNSHAndAddBias(
    context,
    params.batchSize,
    params.numHeads,
    params.sequenceLength,
    params.headSize,
    query,
    bias,
    0,
  );

  if (kvBNSH) {
    return applyAttention(
      context,
      Q,
      key,
      value,
      keyPaddingMask,
      undefined,
      pastKey,
      pastValue,
      attentionBias,
      params,
      attributes,
    );
  }
  if (!key || !value) {
    throw new Error('key and value must be provided');
  }
  const K = maybeTransposeToBNSHAndAddBias(
    context,
    params.batchSize,
    params.numHeads,
    params.kvSequenceLength,
    params.headSize,
    key,
    bias,
    params.hiddenSize,
  );

  const V = maybeTransposeToBNSHAndAddBias(
    context,
    params.batchSize,
    params.numHeads,
    params.kvSequenceLength,
    params.vHeadSize,
    value,
    bias,
    2 * params.hiddenSize,
  );

  applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes);
};