// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';

import {
  getMaxComponents,
  inputVariable,
  outputVariable,
  ShaderHelper,
  tensorTypeToWsglStorageType,
  tensorTypeToWsglValueType,
  UniformDataElementType,
  UniformsArrayType,
} from './common';

export const enum AttentionQkvFormat {
  unknown, // enum value not set, or depends on qkv projection implementation details
  qkvBNSH, // for non-packed qkv, permuted
  qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
  qkvBSN3H, // for TRT fused attention, qkv are packed
  qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
  qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed
  qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed
  qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed
}

export const enum AttentionMaskType {
  none, // No mask
  mask1dKeySeqLen, // [batch_size], key sequence length
  mask1dEndStart, // [2 * batch_size] with end positions and start positions
  mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
  // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ...,
  // key_start[batch_size - 1], key_end[batch_size - 1]]
  mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has the same effect as no mask.
  mask2dKeyPadding, // [batch_size, total_sequence_length]
  mask3dAttention, // [batch_size, sequence_length, total_sequence_length]
  mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length]
  maskUnknown,
}

export interface AttentionParameters {
  batchSize: number;
  sequenceLength: number;
  pastSequenceLength: number;
  kvSequenceLength: number;
  totalSequenceLength: number;
  maxSequenceLength: number;
  inputHiddenSize: number;
  hiddenSize: number;
  vHiddenSize: number;
  headSize: number;
  vHeadSize: number;
  numHeads: number;
  kvNumHeads?: number;
  nReps?: number;
  isUnidirectional?: boolean;
  pastPresentShareBuffer: boolean;
  maskFilterValue?: number;
  maskType: AttentionMaskType;
  scale: number;
  broadcastResPosBias: boolean;
  passPastInKv: boolean;
  qkvFormat: AttentionQkvFormat;
  isPastkvBSNH?: boolean;
}

export interface AttentionAttrs {
  numHeads: number;
  kvNumHeads?: number;
  isUnidirectional?: number;
  maskFilterValue?: number;
  scale: number;
  doRotary: number;
  qkvHiddenSizes: number[];
  pastPresentShareBuffer: boolean;
}

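/**
 * Checks the inputs of the Attention op (input, weights, bias, and the optional mask_index,
 * past, and attention_bias) and derives the AttentionParameters consumed by the shader
 * programs below. Throws on any shape or attribute combination this kernel does not support.
 */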
const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
  // Abbreviations and meanings:
  //   B:    batch_size
  //   S:    sequence_length (input sequence length of query)
  //   P:    past_sequence_length (past sequence length of key or value)
  //   L:    kv_sequence_length (input sequence length of key or value)
  //   M:    max_sequence_length
  //   T:    total_sequence_length = past_sequence_length + kv_sequence_length
  //   N:    num_heads
  //   H:    head size for Q and K, aka q_head_size or k_head_size or qk_head_size
  //   H_v:  v_head_size
  //   D_i:  input hidden size
  //   D:    hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
  //   D_v:  v_hidden_size = num_heads * v_head_size

  // When past state is used, Q, K and V should have the same hidden size (unless we split it into past_key and past_value).

  // Input shapes:
  //   input        (Q/K/V)    : (B, S, D_i)
  //   weights      (Q/K/V)    : (D_i, D + D + D_v)
  //   bias         (Q/K/V)    : (D + D + D_v)
  //   mask_index              : see below
  //   past         (K/V)      : (2, B, N, P, H) or NULL
  //   attention_bias          : (B, N, S, T) or NULL

  // For mask_index, the following shapes are supported:
  //     NULL, (B, 1), (1, 1)
  //     (B), (2 * B), (3 * B + 2)
  //     (B, T)
  //     (B, S, T)
  //     (B, 1, M, M)
  //
  // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger
  // than the hidden dimension of Q, K and V.

  const input = inputs[0];
  const weights = inputs[1];
  const bias = inputs[2];
  const maskIndex = inputs[3];
  const past = inputs[4];
  const attentionBias = inputs[5];

  if (past && attentionBias) {
    throw new Error('Attention cannot have both past and attention_bias');
  }

  if (input.dims.length !== 3) {
    throw new Error('Input "input" must have 3 dimensions');
  }

  const batchSize = input.dims[0];
  const sequenceLength = input.dims[1];
  const inputHiddenSize = input.dims[2];

  if (bias.dims.length !== 1) {
    throw new Error('Input "bias" is expected to have 1 dimension');
  }

  if (weights.dims.length !== 2) {
    throw new Error('Input "weights" is expected to have 2 dimensions');
  }

  if (weights.dims[0] !== inputHiddenSize) {
    throw new Error('Input 1 dimension 0 should have the same length as dimension 2 of input 0');
  }

  if (bias.dims[0] !== weights.dims[1]) {
    throw new Error('Input "bias" dimension 0 should have the same length as dimension 1 of input "weights"');
  }

  let qHiddenSize = bias.dims[0] / 3;
  let kHiddenSize = qHiddenSize;
  let vHiddenSize = kHiddenSize;
  if (attributes.qkvHiddenSizes.length > 0) {
    if (attributes.qkvHiddenSizes.length !== 3) {
      throw new Error('qkv_hidden_sizes attribute should have 3 elements');
    }
    for (const sz of attributes.qkvHiddenSizes) {
      if (sz % attributes.numHeads !== 0) {
        throw new Error('qkv_hidden_sizes should be divisible by num_heads');
      }
    }

    qHiddenSize = attributes.qkvHiddenSizes[0];
    kHiddenSize = attributes.qkvHiddenSizes[1];
    vHiddenSize = attributes.qkvHiddenSizes[2];
  }

  const kvSequenceLength = sequenceLength;

  if (qHiddenSize !== kHiddenSize) {
    throw new Error('qkv_hidden_sizes first element should be the same as the second');
  }

  if (bias.dims[0] !== qHiddenSize + kHiddenSize + vHiddenSize) {
    throw new Error('Input "bias" dimension 0 should have the same length as the sum of Q/K/V hidden sizes');
  }

  let pastSequenceLength = 0;
  if (past) {
    if (kHiddenSize !== vHiddenSize) {
      throw new Error('Input "past" expects k_hidden_size == v_hidden_size');
    }
    if (past.dims.length !== 5) {
      throw new Error('Input "past" must have 5 dimensions');
    }
    if (past.dims[0] !== 2) {
      throw new Error('Input "past" first dimension must be 2');
    }
    if (past.dims[1] !== batchSize) {
      throw new Error('Input "past" second dimension must be batch_size');
    }
    if (past.dims[2] !== attributes.numHeads) {
      throw new Error('Input "past" third dimension must be num_heads');
    }
    if (past.dims[4] !== kHiddenSize / attributes.numHeads) {
      throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads');
    }

    if (!attributes.pastPresentShareBuffer) {
      pastSequenceLength = past.dims[3];
    }
    // TODO: handle past_seq_len
  }

  const totalSequenceLength = kvSequenceLength + pastSequenceLength;
  const maxSequenceLength = -1;

  const maskType = AttentionMaskType.none;
  if (maskIndex) {
    // maskType = AttentionMaskType.MASK_UNKNOWN;
    // TODO: handle mask
    throw new Error('Mask not supported');
  }

  if (past) {
    throw new Error('past is not supported');
  }

  if (attentionBias) {
    if (attentionBias.dims.length !== 4) {
      throw new Error('Input "attention_bias" must have 4 dimensions');
    }

    // TODO: support broadcasting the first and second dimensions of attention_bias
    if (
      attentionBias.dims[0] !== batchSize ||
      attentionBias.dims[1] !== attributes.numHeads ||
      attentionBias.dims[2] !== sequenceLength ||
      attentionBias.dims[3] !== totalSequenceLength
    ) {
      throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)');
    }
  }

  return {
    batchSize,
    sequenceLength,
    pastSequenceLength,
    kvSequenceLength,
    totalSequenceLength,
    maxSequenceLength,
    inputHiddenSize,
    hiddenSize: qHiddenSize,
    vHiddenSize,
    headSize: Math.floor(qHiddenSize / attributes.numHeads),
    vHeadSize: Math.floor(vHiddenSize / attributes.numHeads),
    numHeads: attributes.numHeads,
    isUnidirectional: false,
    pastPresentShareBuffer: false,
    maskFilterValue: attributes.maskFilterValue,
    maskType,
    scale: attributes.scale,
    broadcastResPosBias: false,
    passPastInKv: false,
    qkvFormat: AttentionQkvFormat.qkvBNSH,
  };
};

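/**
 * Builds the program that applies softmax in place over the last dimension of the attention
 * probabilities. Each row of length d is handled by one workgroup: threads first reduce a
 * per-thread running max, then a per-thread partial sum of exp(x - max), and finally normalize.
 * Subtracting the row max before exponentiation keeps exp() in range (numerically stable softmax).
 */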
const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => {
  const components = getMaxComponents(d);
  let WG = 64;
  const dComp = d / components;
  if (dComp < WG) {
    WG = 32;
  }
  const elementsPerThread = Math.ceil(d / components / WG);
  const programUniforms: ProgramUniform[] = [
    { type: DataType.float, data: 1 / d },
    { type: DataType.uint32, data: dComp },
    { type: DataType.uint32, data: elementsPerThread },
  ];
  const dataType = tensorTypeToWsglStorageType(input.dataType, components);
  const f32Type = tensorTypeToWsglValueType(DataType.float, components);
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const inputHelper = outputVariable('x', input.dataType, input.dims, components);
    const elemValueType = tensorTypeToWsglValueType(input.dataType);
    const uniforms: UniformsArrayType = [
      { name: 'd_inv', type: 'f32' },
      { name: 'd_comp', type: 'u32' },
      { name: 'elements_per_thread', type: 'u32' },
    ];

    return `
  var<workgroup> thread_max: array<f32, ${WG}>;
  var<workgroup> thread_sum: array<f32, ${WG}>;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)}
  ${shaderHelper.mainStart([WG, 1, 1])}
    let local_offset = local_idx * uniforms.elements_per_thread;
    let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset;

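    // -3.402823e+38 is the lowest finite f32, used as the identity element for the max reduction.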
    var thread_max_vector = ${f32Type}(-3.402823e+38f);
    for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
      thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
    }
    thread_max[local_idx] = ${(() => {
      switch (components) {
        case 1:
          return 'thread_max_vector';
        case 2:
          return 'max(thread_max_vector.x, thread_max_vector.y)';
        case 4:
          return 'max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))';
        default:
          throw new Error(`Unsupported components: ${components}`);
      }
    })()};
    workgroupBarrier();

    var max_value = f32(-3.402823e+38f);
    for (var i = 0u; i < ${WG}; i++) {
      max_value = max(thread_max[i], max_value);
    }

    var sum_vector = ${f32Type}(0);
    for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
      sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
    }
    thread_sum[local_idx] = ${(() => {
      switch (components) {
        case 1:
          return 'sum_vector';
        case 2:
          return 'sum_vector.x + sum_vector.y';
        case 4:
          return 'sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w';
        default:
          throw new Error(`Unsupported components: ${components}`);
      }
    })()};
    workgroupBarrier();

    var sum: f32 = 0;
    for (var i = 0u; i < ${WG}; i++) {
      sum += thread_sum[i];
    }

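    // Degenerate row (sum of exponentials is zero): fall back to a uniform distribution 1/d.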
    if (sum == 0) {
      for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
        x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv));
      }
    } else {
      for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
        var f32input = ${f32Type}(x[offset + i]);
        x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
      }
    }
  }`;
  };

  return {
    name: 'AttentionProbsSoftmax',
    shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies },
    getShaderSource,
    getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }),
  };
};

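/**
 * Builds the program that computes the scaled attention scores probs = alpha * Q * K^T
 * (plus the optional attention_bias) as a tiled matrix multiplication over shared-memory
 * tiles. When a present_key output is requested, the shader also concatenates past_key
 * and key into present_key on the fly.
 */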
const createAttentionProbsProgramInfo = (
  outputCount: number,
  q: TensorView,
  key: TensorView,
  pastKey: TensorView | undefined,
  attentionBias: TensorView | undefined,
  parameters: AttentionParameters,
  attributes: AttentionAttrs,
  pastSequenceLength: number,
) => {
  const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
  const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
  const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey;
  const presentKeyShape = presentKey
    ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]
    : undefined;

  // TODO: handle mask

  const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
  const components = getMaxComponents(parameters.headSize);
  const vectorizedHeadSize = parameters.headSize / components;
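  // 12x12 tile per workgroup; presumably chosen as a balance between shared-memory use and occupancy.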
  const TILE_SIZE = 12;
  const dispatch = {
    x: Math.ceil(totalSequenceLength / TILE_SIZE),
    y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
    z: parameters.batchSize * parameters.numHeads,
  };
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: parameters.sequenceLength },
    { type: DataType.uint32, data: vectorizedHeadSize },
    { type: DataType.uint32, data: totalSequenceLength },
    { type: DataType.uint32, data: parameters.numHeads },
    { type: DataType.float, data: alpha },
    { type: DataType.uint32, data: pastSequenceLength },
    { type: DataType.uint32, data: parameters.kvSequenceLength },
  ];
  // Feed pastKey to the shader code only if it is non-empty and presentKey is being produced
  const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
  if (feedPastKey) {
    inputDependencies.push('type');
  }
  if (attentionBias) {
    inputDependencies.push('type');
  }
  const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }];
  if (presentKey) {
    outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default });
  }
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const qInput = inputVariable('q', q.dataType, q.dims, components);
    const kInput = inputVariable('key', key.dataType, key.dims, components);
    const inputVars = [qInput, kInput];
    if (feedPastKey) {
      const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
      inputVars.push(pastKeyInput);
    }
    if (attentionBias) {
      inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims));
    }
    const output = outputVariable('output', q.dataType, probsShape);
    const outputVars = [output];
    if (presentKey) {
      outputVars.push(outputVariable('present_key', q.dataType, presentKeyShape!, components));
    }
    const f32Type = tensorTypeToWsglValueType(DataType.float, components);

    const uniforms: UniformsArrayType = [
      { name: 'M', type: 'u32' },
      { name: 'K', type: 'u32' },
      { name: 'N', type: 'u32' },
      { name: 'num_heads', type: 'u32' },
      { name: 'alpha', type: 'f32' as UniformDataElementType },
      { name: 'past_sequence_length', type: 'u32' },
      { name: 'kv_sequence_length', type: 'u32' },
    ];
    return `
  const TILE_SIZE = ${TILE_SIZE}u;

  var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
  ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
    // x holds the N and y holds the M
    let headIdx = workgroup_id.z;
    let m = workgroup_id.y * TILE_SIZE;
    let n = workgroup_id.x * TILE_SIZE;
    let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
    ${(() => {
      if (feedPastKey && presentKey) {
        return `
    let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
    let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
      } else {
        return `
    let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`;
      }
    })()}
    ${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''}
    var value = ${f32Type}(0);
    for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
      if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
        tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
      }
      if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
        var idx = TILE_SIZE * local_id.y + local_id.x;
      ${(() => {
        if (feedPastKey && presentKey) {
          return `
              if (n + local_id.y < uniforms.past_sequence_length) {
                tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
              } else {
                tileK[idx] =
                         key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];
              }`;
        } else {
          return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];';
        }
      })()}
      ${
        presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : ''
      }
      }
      workgroupBarrier();

      for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
        value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
      }

      workgroupBarrier();
    }

    let headOffset = headIdx * uniforms.M * uniforms.N;
    if (global_id.y < uniforms.M && global_id.x < uniforms.N) {
      let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;
      var sum: f32 = ${(() => {
        switch (components) {
          case 1:
            return 'value';
          case 2:
            return 'value.x + value.y';
          case 4:
            return 'value.x + value.y + value.z + value.w';
          default:
            throw new Error(`Unsupported components: ${components}`);
        }
      })()};
        output[outputIdx] = ${output.type.value} (sum * uniforms.alpha) + ${
          attentionBias ? 'attention_bias[outputIdx]' : '0.0'
        };
    }
  }`;
  };
  return {
    name: 'AttentionProbs',
    shaderCache: {
      hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${outputCount}`,
      inputDependencies,
    },
    getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
    getShaderSource,
  };
};

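/**
 * Builds the program that multiplies the softmaxed probabilities by V, again as a tiled
 * matrix multiplication. The result is written transposed from (B, N, S, H_v) back to
 * (B, S, N * H_v), and past_value/v are concatenated into present_value when requested.
 */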
const createVxAttentionScoreProgramInfo = (
  outputCount: number,
  probs: TensorView,
  v: TensorView,
  pastValue: TensorView | undefined,
  params: AttentionParameters,
  pastSequenceLength: number,
) => {
  const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
  const nReps = params.nReps ? params.nReps : 1;
  const repeatedVHiddenSize = params.vHiddenSize * nReps;
  const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue;
  const presentValueShape = presentValue
    ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize]
    : undefined;
  const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
  const TILE_SIZE = 12;
  const dispatch = {
    x: Math.ceil(params.vHeadSize / TILE_SIZE),
    y: Math.ceil(params.sequenceLength / TILE_SIZE),
    z: params.batchSize * params.numHeads,
  };

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: params.sequenceLength },
    { type: DataType.uint32, data: totalSequenceLength },
    { type: DataType.uint32, data: params.vHeadSize },
    { type: DataType.uint32, data: params.numHeads },
    { type: DataType.uint32, data: repeatedVHiddenSize },
    { type: DataType.uint32, data: pastSequenceLength },
    { type: DataType.uint32, data: params.kvSequenceLength },
  ];
  // Feed pastValue to the shader code only if it is non-empty and presentValue is being produced
  const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0;
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
  if (feedPastValue) {
    inputDependencies.push('type');
  }
  const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }];
  if (presentValue) {
    outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default });
  }
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
    const vHelper = inputVariable('v', v.dataType, v.dims);
    const inputVars = [probsHelper, vHelper];
    if (feedPastValue) {
      inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
    }
    const output = outputVariable('output', probs.dataType, outputShape);
    const outputVars = [output];
    if (presentValue) {
      outputVars.push(outputVariable('present_value', probs.dataType, presentValueShape!));
    }
    const uniforms: UniformsArrayType = [
      { name: 'M', type: 'u32' },
      { name: 'K', type: 'u32' },
      { name: 'N', type: 'u32' },
      { name: 'num_heads', type: 'u32' },
      { name: 'v_hidden_size', type: 'u32' },
      { name: 'past_sequence_length', type: 'u32' },
      { name: 'kv_sequence_length', type: 'u32' },
    ];
    return `
  const TILE_SIZE = ${TILE_SIZE}u;
  var<workgroup> tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
  ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
   let headIdx = workgroup_id.z;
   let m = global_id.y;
   let n = global_id.x;

   let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
   ${(() => {
     if (feedPastValue && presentValue) {
       return `
    let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
    let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
      `;
     } else {
       return `
   let offsetB = headIdx * uniforms.N * uniforms.K + n;
            `;
     }
   })()}
    ${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''}
   var value = ${probsHelper.type.storage}(0);
   for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
      if (m < uniforms.M && w + local_id.x < uniforms.K) {
        tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
      }
      if (n < uniforms.N && w + local_id.y < uniforms.K) {
        var idx = TILE_SIZE * local_id.y + local_id.x;
        ${(() => {
          if (feedPastValue && presentValue) {
            return `
        if (w + local_id.y < uniforms.past_sequence_length) {
          tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
        } else {
          tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];
        }
      `;
          } else {
            return `
        tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];
      `;
          }
        })()}
        ${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''}
      }
     workgroupBarrier();
     for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
       value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
     }
     workgroupBarrier();
   }

   // we need to transpose output from BNSH_v to BSND_v
   let batchIdx = workgroup_id.z / uniforms.num_heads;
   let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
   if (m < uniforms.M && n < uniforms.N) {
     let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size
       + currentBatchHeadNumber * uniforms.N + n;
     output[outputIdx] = value;
   }
  }`;
  };

  return {
    name: 'AttentionScore',
    shaderCache: { hint: `${pastValue !== undefined};${outputCount}`, inputDependencies },
    getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
    getShaderSource,
  };
};

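/**
 * Runs the attention pipeline as three GPU programs:
 *   1. AttentionProbs:        probs = alpha * Q * K^T (+ attention_bias)
 *   2. AttentionProbsSoftmax: in-place softmax over the last dimension of probs
 *   3. AttentionScore:        output = probs * V, laid out as (B, S, D_v)
 * present_key/present_value are produced alongside stages 1 and 3 when requested.
 */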
export const applyAttention = (
  context: ComputeContext,
  q: TensorView,
  k: TensorView,
  v: TensorView,
  _maskIndex: TensorView | undefined,
  _past: TensorView | undefined,
  pastKey: TensorView | undefined,
  pastValue: TensorView | undefined,
  attentionBiasInput: TensorView | undefined,
  parameters: AttentionParameters,
  attributes: AttentionAttrs,
) => {
  // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
  const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
  const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
  const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
  const attentionBias =
    attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;

  const inputsK = [q, k];
  if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
    inputsK.push(pastKey);
  }
  if (attentionBias) {
    inputsK.push(attentionBias);
  }

  // Run AttentionProbs
  const probs = context.compute(
    createAttentionProbsProgramInfo(
      outputCount,
      q,
      k,
      pastKey,
      attentionBias,
      parameters,
      attributes,
      pastSequenceLength,
    ),
    { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] },
  )[0];

  // Run Softmax
  context.compute(
    createInPlaceSoftmaxProgramInfo(
      probs,
      parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
      totalSequenceLength,
    ),
    { inputs: [probs], outputs: [] },
  );

  // Run AttentionScore
  const inputsV = [probs, v];
  if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
    inputsV.push(pastValue);
  }
  context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), {
    inputs: inputsV,
    outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0],
  });
};

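/**
 * Computes the packed QKV projection: a single tiled GEMM multiplies the input by the packed
 * (D_i, D + D + D_v) weight matrix, adds the packed bias, and writes three
 * (batch_size, num_heads, sequence_length, head_size) outputs (Q, K and V already split per
 * head) that feed directly into applyAttention.
 */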
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
  const outputShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, parameters.headSize];
  const M = parameters.sequenceLength;
  const K = parameters.inputHiddenSize;
  const N = parameters.headSize;
  const TILE_SIZE = 12;
  const dispatch = {
    x: Math.ceil(parameters.headSize / TILE_SIZE),
    y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
    z: parameters.batchSize * parameters.numHeads,
  };
  const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: M },
    { type: DataType.uint32, data: K },
    { type: DataType.uint32, data: N },
    { type: DataType.uint32, data: parameters.numHeads },
    { type: DataType.uint32, data: parameters.headSize },
    { type: DataType.uint32, data: parameters.hiddenSize },
    { type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize },
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const outputQ = outputVariable('output_q', inputs[0].dataType, outputShape);
    const outputK = outputVariable('output_k', inputs[0].dataType, outputShape);
    const outputV = outputVariable('output_v', inputs[0].dataType, outputShape);
    const input = inputVariable('input', inputs[0].dataType, inputs[0].dims);
    const weight = inputVariable('weight', inputs[1].dataType, inputs[1].dims);
    const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
    const dataType = input.type.storage;

    const uniforms: UniformsArrayType = [
      { name: 'M', type: 'u32' },
      { name: 'K', type: 'u32' },
      { name: 'N', type: 'u32' },
      { name: 'num_heads', type: 'u32' },
      { name: 'head_size', type: 'u32' },
      { name: 'hidden_size', type: 'u32' },
      { name: 'ldb', type: 'u32' },
    ];
    return `
  const TILE_SIZE = ${TILE_SIZE}u;
  var<workgroup> tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(input, weight, bias, outputQ, outputK, outputV)}
  ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
    let batchIndex = workgroup_id.z / uniforms.num_heads;
    let headNumber = workgroup_id.z % uniforms.num_heads;
    let m = global_id.y;
    let n = global_id.x;

    let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K;
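    // Q, K and V weights (and biases) are packed side by side along the output dimension,
    // so the K and V columns start hidden_size further in (ldb is the full packed row width).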
    let biasOffsetQ = headNumber * uniforms.head_size;
    let biasOffsetK = uniforms.hidden_size + biasOffsetQ;
    let biasOffsetV = uniforms.hidden_size + biasOffsetK;

    var valueQ = ${dataType}(0);
    var valueK = ${dataType}(0);
    var valueV = ${dataType}(0);
    for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
      if (m < uniforms.M && w + local_id.x < uniforms.K) {
        tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];
      }
      if (n < uniforms.N && w + local_id.y < uniforms.K) {
        let offset = n + (w + local_id.y) * uniforms.ldb;
        tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];
        tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];
        tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];
      }
      workgroupBarrier();
      for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) {
        let inputTileOffset = TILE_SIZE * local_id.y + k;
        let weightTileOffset = TILE_SIZE * k + local_id.x;
        valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset];
        valueK += tileInput[inputTileOffset] * tileWeightK[weightTileOffset];
        valueV += tileInput[inputTileOffset] * tileWeightV[weightTileOffset];
      }

      workgroupBarrier();
    }

    let headOffset = (m * uniforms.N + n) % uniforms.head_size;
    valueQ += bias[headOffset + biasOffsetQ];
    valueK += bias[headOffset + biasOffsetK];
    valueV += bias[headOffset + biasOffsetV];

    let offset = workgroup_id.z * uniforms.M * uniforms.N;
    if (m < uniforms.M && n < uniforms.N) {
      let outputIdx = offset + m * uniforms.N + n;
      output_q[outputIdx] = valueQ;
      output_k[outputIdx] = valueK;
      output_v[outputIdx] = valueV;
    }
  }`;
  };

  return context.compute(
    {
      name: 'AttentionPrepare',
      shaderCache: { inputDependencies: ['type', 'type', 'type'] },
      getRunData: () => ({
        outputs: [
          { dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default },
          { dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default },
          { dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default },
        ],
        dispatchGroup: dispatch,
        programUniforms,
      }),
      getShaderSource,
    },
    { inputs, outputs: [-1, -1, -1] },
  );
};

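/**
 * Entry point for the Attention op on WebGPU: validates the inputs, projects them into
 * per-head Q, K and V with prepare(), then runs applyAttention. mask_index and past are
 * rejected by validateAttentionInputs above, so only attention_bias (context.inputs[5])
 * is forwarded and undefined is passed for the past key/value.
 */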
export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => {
  const params = validateAttentionInputs(context.inputs, attributes);

  const [q, k, v] = prepare(context, params);

  return applyAttention(
    context,
    q,
    k,
    v,
    context.inputs[4],
    undefined,
    undefined,
    undefined,
    context.inputs[5],
    params,
    attributes,
  );
};