1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
14
tensorTypeToWsglStorageType,
15
tensorTypeToWsglValueType,
16
UniformDataElementType,
20
/**
 * Memory layout of the Q/K/V tensors consumed/produced by the attention kernels.
 * B = batch, S = sequence, N = num_heads, H = head size, T = token count (padding removed).
 */
export const enum AttentionQkvFormat {
  unknown, // enum value not set, or depends on qkv projection implementation details
  qkvBNSH, // for non-packed qkv, permuted
  qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
  qkvBSN3H, // for TRT fused attention, qkv are packed
  qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
  qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed
  qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed.
  qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed
}
/**
 * Shape/meaning of the optional attention mask input.
 * B = batch_size, S = sequence_length, T = total_sequence_length, M = max_sequence_length.
 */
export const enum AttentionMaskType {
  none, // no mask input (NOTE(review): member reconstructed — referenced below as AttentionMaskType.none)
  mask1dKeySeqLen, // [batch_size], key sequence length
  mask1dEndStart, // [2 * batch_size] with end positions and start positions
  mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
  // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ...,
  // key_start[batch_size - 1], key_end[batch_size - 1]]
  mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask.
  mask2dKeyPadding, // [batch_size, total_sequence_length]
  mask3dAttention, // [batch_size, sequence_length, total_sequence_length]
  mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length]
  maskUnknown, // NOTE(review): member reconstructed from the commented-out `MASK_UNKNOWN` use below — confirm
}
/**
 * Fully-resolved shape/config parameters shared by the attention kernels,
 * produced by input validation (see `validateAttentionInputs`).
 * NOTE(review): some fields were elided in this view and are reconstructed from
 * their uses later in this file (e.g. `parameters.batchSize`, `params.vHeadSize`) — confirm.
 */
export interface AttentionParameters {
  batchSize: number;
  sequenceLength: number; // S: input sequence length of query
  pastSequenceLength: number; // P: past sequence length of key/value
  kvSequenceLength: number; // L: input sequence length of key/value
  totalSequenceLength: number; // T = P + L
  maxSequenceLength: number;
  inputHiddenSize: number; // D_i
  hiddenSize: number; // D = N * H (Q/K hidden size)
  vHiddenSize: number; // D_v = N * H_v
  headSize: number; // H
  vHeadSize: number; // H_v
  numHeads: number; // N
  kvNumHeads?: number; // set only for grouped-query attention variants
  nReps?: number; // KV repeat factor (GQA); defaults to 1 when absent
  isUnidirectional?: boolean;
  pastPresentShareBuffer: boolean;
  maskFilterValue?: number;
  maskType: AttentionMaskType;
  scale: number;
  broadcastResPosBias: boolean;
  passPastInKv: boolean;
  qkvFormat: AttentionQkvFormat;
  isPastkvBSNH?: boolean;
}
/**
 * Raw operator attributes of the Attention contrib op, as parsed from the model.
 * NOTE(review): `numHeads`, `scale` and `doRotary` were elided in this view and are
 * reconstructed from their uses (`attributes.numHeads`, `attributes.scale`) — confirm.
 */
export interface AttentionAttrs {
  numHeads: number;
  isUnidirectional?: number;
  maskFilterValue?: number;
  scale: number;
  doRotary?: number; // NOTE(review): presumed attribute of the contrib op — verify against op schema
  qkvHiddenSizes: number[]; // [q_hidden, k_hidden, v_hidden]; empty means derive from bias (split in 3)
  pastPresentShareBuffer: boolean;
}
/**
 * Validates the inputs of the Attention op and derives the resolved
 * `AttentionParameters` used by the compute kernels.
 *
 * Throws an `Error` describing the first violated constraint.
 * NOTE(review): closing braces, `if (...)` heads and return-object shorthand fields
 * were elided in this view and are reconstructed from visible surrounding lines — confirm.
 */
const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
  // Abbreviation and Meanings:
  //   B:    batch_size
  //   S:    sequence_length (input sequence length of query)
  //   P:    past_sequence_length (past sequence length of key or value)
  //   L:    kv_sequence_length (input sequence length of key or value)
  //   M:    max_sequence_length
  //   T:    total_sequence_length = past_sequence_length + kv_sequence_length
  //   N:    num_heads
  //   H:    head size for Q and K, aka q_head_size or k_head_size or qk_head_size
  //   D_i:  input hidden size
  //   D:    hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
  //   D_v:  v_hidden_size = num_heads * v_head_size

  // When past state is used, Q, K and V should have same hidden size (unless we split it into past_key and past_value).

  // Input shapes:
  //   input           (Q/K/V) : (B, S, D_i)
  //   weights         (Q/K/V) : (D_i, D + D + D_v)
  //   bias            (Q/K/V) : (D + D + D_v)
  //   mask_index              : see below
  //   past            (K/V)   : (2, B, N, P, H) or NULL
  //   attention_bias          : (B, N, S, T) or NULL

  // For mask_index, the following shapes are supported:
  //   NULL, (B, 1), (1, 1)
  //   (B), (2 * B), (3 * B + 2)

  // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger
  // than hidden dimension of Q, K and V.

  const input = inputs[0];
  const weights = inputs[1];
  const bias = inputs[2];
  const maskIndex = inputs[3];
  const past = inputs[4];
  const attentionBias = inputs[5];

  if (past && attentionBias) {
    throw new Error('Attention cannot have both past and attention_bias');
  }

  if (input.dims.length !== 3) {
    throw new Error('Input "input" must have 3 dimensions');
  }

  const batchSize = input.dims[0];
  const sequenceLength = input.dims[1];
  const inputHiddenSize = input.dims[2];

  if (bias.dims.length !== 1) {
    throw new Error('Input "bias" is expected to have 1 dimensions');
  }

  if (weights.dims.length !== 2) {
    throw new Error('Input "weights" is expected to have 2 dimensions');
  }

  if (weights.dims[0] !== inputHiddenSize) {
    throw new Error('Input 1 dimension 0 should have same length as dimension 2 of input 0');
  }

  if (bias.dims[0] !== weights.dims[1]) {
    throw new Error('Input "bias" dimension 0 should have same length as dimension 1 of input "weights"');
  }

  // Default: bias is evenly split into Q/K/V; qkv_hidden_sizes attribute overrides.
  let qHiddenSize = bias.dims[0] / 3;
  let kHiddenSize = qHiddenSize;
  let vHiddenSize = kHiddenSize;
  if (attributes.qkvHiddenSizes.length > 0) {
    if (attributes.qkvHiddenSizes.length !== 3) {
      throw new Error('qkv_hidden_sizes attribute should have 3 elements');
    }
    for (const sz of attributes.qkvHiddenSizes) {
      if (sz % attributes.numHeads !== 0) {
        throw new Error('qkv_hidden_sizes should be divisible by num_heads');
      }
    }

    qHiddenSize = attributes.qkvHiddenSizes[0];
    kHiddenSize = attributes.qkvHiddenSizes[1];
    vHiddenSize = attributes.qkvHiddenSizes[2];
  }

  // Self-attention: key/value sequence length equals query sequence length.
  const kvSequenceLength = sequenceLength;

  if (qHiddenSize !== kHiddenSize) {
    throw new Error('qkv_hidden_sizes first element should be same as the second');
  }

  if (bias.dims[0] !== qHiddenSize + kHiddenSize + vHiddenSize) {
    throw new Error('Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes');
  }

  let pastSequenceLength = 0;
  if (past) {
    if (kHiddenSize !== vHiddenSize) {
      throw new Error('Input "past" expect k_hidden_size == v_hidden_size');
    }
    if (past.dims.length !== 5) {
      throw new Error('Input "past" must have 5 dimensions');
    }
    if (past.dims[0] !== 2) {
      throw new Error('Input "past" first dimension must be 2');
    }
    if (past.dims[1] !== batchSize) {
      throw new Error('Input "past" second dimension must be batch_size');
    }
    if (past.dims[2] !== attributes.numHeads) {
      throw new Error('Input "past" third dimension must be num_heads');
    }
    if (past.dims[4] !== kHiddenSize / attributes.numHeads) {
      throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads');
    }

    if (!attributes.pastPresentShareBuffer) {
      pastSequenceLength = past.dims[3];
    }
    // TODO: handle past_seq_len
  }

  const totalSequenceLength = kvSequenceLength + pastSequenceLength;
  const maxSequenceLength = -1;

  const maskType = AttentionMaskType.none;
  if (maskIndex) {
    // maskType = AttentionMaskType.MASK_UNKNOWN;
    // TODO: handle mask
    throw new Error('Mask not supported');
  }

  // past is shape-validated above, but the kernel itself does not consume it yet.
  if (past) {
    throw new Error('past is not supported');
  }

  if (attentionBias) {
    if (attentionBias.dims.length !== 4) {
      throw new Error('Input "attention_bias" must have 4 dimensions');
    }

    // TODO: support broadcasting the first and second dimensions of attention_bias
    if (
      attentionBias.dims[0] !== batchSize ||
      attentionBias.dims[1] !== attributes.numHeads ||
      attentionBias.dims[2] !== sequenceLength ||
      attentionBias.dims[3] !== totalSequenceLength
    ) {
      throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)');
    }
  }

  return {
    batchSize,
    sequenceLength,
    pastSequenceLength,
    kvSequenceLength,
    totalSequenceLength,
    maxSequenceLength,
    inputHiddenSize,
    hiddenSize: qHiddenSize,
    vHiddenSize,
    headSize: Math.floor(qHiddenSize / attributes.numHeads),
    vHeadSize: Math.floor(vHiddenSize / attributes.numHeads),
    numHeads: attributes.numHeads,
    isUnidirectional: false,
    pastPresentShareBuffer: false,
    maskFilterValue: attributes.maskFilterValue,
    maskType,
    scale: attributes.scale,
    broadcastResPosBias: false,
    passPastInKv: false,
    qkvFormat: AttentionQkvFormat.qkvBNSH,
  };
};
/**
 * Builds the program that applies softmax IN PLACE over the last dimension of the
 * attention-probs tensor: `n` rows of length `d`. One workgroup handles one row;
 * each thread reduces `elements_per_thread` vectorized elements, then the per-thread
 * maxima/sums are combined through workgroup arrays.
 * NOTE(review): the `WG` sizing lines and the `if (sum == 0)` / `else` scaffolding were
 * elided in this view and are reconstructed from the visible fragments — confirm.
 */
const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => {
  const components = getMaxComponents(d);
  let WG = 64;
  const dComp = d / components;
  if (dComp < WG) {
    WG = 32;
  }
  const elementsPerThread = Math.ceil(d / components / WG);
  const programUniforms: ProgramUniform[] = [
    { type: DataType.float, data: 1 / d }, // d_inv: uniform fallback value when a row sums to 0
    { type: DataType.uint32, data: dComp },
    { type: DataType.uint32, data: elementsPerThread },
  ];
  const dataType = tensorTypeToWsglStorageType(input.dataType, components);
  const f32Type = tensorTypeToWsglValueType(DataType.float, components);
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    // Declared as output variable because the shader reads AND writes `x`.
    const inputHelper = outputVariable('x', input.dataType, input.dims, components);
    const elemValueType = tensorTypeToWsglValueType(input.dataType);
    const uniforms: UniformsArrayType = [
      { name: 'd_inv', type: 'f32' },
      { name: 'd_comp', type: 'u32' },
      { name: 'elements_per_thread', type: 'u32' },
    ];

    return `
  var<workgroup> thread_max: array<f32, ${WG}>;
  var<workgroup> thread_sum: array<f32, ${WG}>;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)}
  ${shaderHelper.mainStart([WG, 1, 1])}
    let local_offset = local_idx * uniforms.elements_per_thread;
    let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset;

    var thread_max_vector = ${f32Type}(-3.402823e+38f);
    for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
      thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
    }
    thread_max[local_idx] = ${(() => {
      switch (components) {
        case 1:
          return 'thread_max_vector';
        case 2:
          return 'max(thread_max_vector.x, thread_max_vector.y)';
        case 4:
          return 'max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))';
        default:
          throw new Error(`Unsupported components: ${components}`);
      }
    })()};
    workgroupBarrier();

    var max_value = f32(-3.402823e+38f);
    for (var i = 0u; i < ${WG}; i++) {
      max_value = max(thread_max[i], max_value);
    }

    var sum_vector = ${f32Type}(0);
    for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
      sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
    }
    thread_sum[local_idx] = ${(() => {
      switch (components) {
        case 1:
          return 'sum_vector';
        case 2:
          return 'sum_vector.x + sum_vector.y';
        case 4:
          return 'sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w';
        default:
          throw new Error(`Unsupported components: ${components}`);
      }
    })()};
    workgroupBarrier();

    var sum: f32 = 0;
    for (var i = 0u; i < ${WG}; i++) {
      sum += thread_sum[i];
    }

    if (sum == 0) {
      for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
        x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv));
      }
    } else {
      for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
        var f32input = ${f32Type}(x[offset + i]);
        x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
      }
    }
  }`;
  };

  return {
    name: 'AttentionProbsSoftmax',
    shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies },
    getShaderSource,
    getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }),
  };
};
/**
 * Builds the tiled Q*K^T program producing the attention probabilities
 * `probs` of shape (B, N, S, T), scaled by alpha and optionally adding attention_bias.
 * When more than one output is requested (and not GQA), it also writes `present_key`
 * in BNSH layout, concatenating past_key (if non-empty) with the new key.
 * NOTE(review): the first three parameters and several template-literal scaffolding
 * lines were elided in this view and are reconstructed from visible fragments — confirm.
 */
const createAttentionProbsProgramInfo = (
  outputCount: number,
  q: TensorView,
  key: TensorView,
  pastKey: TensorView | undefined,
  attentionBias: TensorView | undefined,
  parameters: AttentionParameters,
  attributes: AttentionAttrs,
  pastSequenceLength: number,
) => {
  const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
  const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
  const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey;
  const presentKeyShape = presentKey
    ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]
    : undefined;

  // Default scale is 1/sqrt(head_size) when the attribute is 0.
  const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
  const components = getMaxComponents(parameters.headSize);
  const vectorizedHeadSize = parameters.headSize / components;
  const TILE_SIZE = 12;
  const dispatch = {
    x: Math.ceil(totalSequenceLength / TILE_SIZE),
    y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
    z: parameters.batchSize * parameters.numHeads,
  };
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: parameters.sequenceLength }, // M
    { type: DataType.uint32, data: vectorizedHeadSize }, // K
    { type: DataType.uint32, data: totalSequenceLength }, // N
    { type: DataType.uint32, data: parameters.numHeads },
    { type: DataType.float, data: alpha },
    { type: DataType.uint32, data: pastSequenceLength },
    { type: DataType.uint32, data: parameters.kvSequenceLength },
  ];
  // Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
  const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
  if (feedPastKey) {
    inputDependencies.push('type');
  }
  if (attentionBias) {
    inputDependencies.push('type');
  }
  const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }];
  if (presentKey) {
    outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default });
  }
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const qInput = inputVariable('q', q.dataType, q.dims, components);
    const kInput = inputVariable('key', key.dataType, key.dims, components);
    const inputVars = [qInput, kInput];
    if (feedPastKey) {
      const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
      inputVars.push(pastKeyInput);
    }
    if (attentionBias) {
      inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims));
    }
    const output = outputVariable('output', q.dataType, probsShape);
    const outputVars = [output];
    if (presentKey) {
      outputVars.push(outputVariable('present_key', q.dataType, presentKeyShape!, components));
    }
    const f32Type = tensorTypeToWsglValueType(DataType.float, components);

    const uniforms: UniformsArrayType = [
      { name: 'M', type: 'u32' },
      { name: 'K', type: 'u32' },
      { name: 'N', type: 'u32' },
      { name: 'num_heads', type: 'u32' },
      { name: 'alpha', type: 'f32' as UniformDataElementType },
      { name: 'past_sequence_length', type: 'u32' },
      { name: 'kv_sequence_length', type: 'u32' },
    ];
    return `
  const TILE_SIZE = ${TILE_SIZE}u;

  var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
  ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
    // x holds the N and y holds the M
    let headIdx = workgroup_id.z;
    let m = workgroup_id.y * TILE_SIZE;
    let n = workgroup_id.x * TILE_SIZE;
    let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
    ${(() => {
      if (feedPastKey && presentKey) {
        return `
    let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
    let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
      } else {
        return `
    let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`;
      }
    })()}
    ${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''}
    var value = ${f32Type}(0);
    for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
      if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
        tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
      }
      if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
        var idx = TILE_SIZE * local_id.y + local_id.x;
        ${(() => {
          if (feedPastKey && presentKey) {
            return `
        if (n + local_id.y < uniforms.past_sequence_length) {
          tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
        } else {
          tileK[idx] =
          key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];
        }`;
          } else {
            return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];';
          }
        })()}
        ${
          presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : ''
        }
      }
      workgroupBarrier();

      for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
        value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
      }

      workgroupBarrier();
    }

    let headOffset = headIdx * uniforms.M * uniforms.N;
    if (global_id.y < uniforms.M && global_id.x < uniforms.N) {
      let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;
      var sum: f32 = ${(() => {
        switch (components) {
          case 1:
            return 'value';
          case 2:
            return 'value.x + value.y';
          case 4:
            return 'value.x + value.y + value.z + value.w';
          default:
            throw new Error(`Unsupported components: ${components}`);
        }
      })()};
      output[outputIdx] = ${output.type.value} (sum * uniforms.alpha) + ${
        attentionBias ? 'attention_bias[outputIdx]' : '0.0'
      };
    }
  }`;
  };
  return {
    name: 'AttentionProbs',
    shaderCache: {
      hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${outputCount}`,
      inputDependencies,
    },
    getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
    getShaderSource,
  };
};
/**
 * Builds the tiled probs*V program. Output is (B, S, D_v) — i.e. the BNSH_v result
 * transposed to BSND_v at store time. When more than one output is requested
 * (and not GQA), also writes `present_value`, concatenating past_value (if non-empty)
 * with the new value.
 * NOTE(review): the first three parameters and several template-literal scaffolding
 * lines were elided in this view and are reconstructed from visible fragments — confirm.
 */
const createVxAttentionScoreProgramInfo = (
  outputCount: number,
  probs: TensorView,
  v: TensorView,
  pastValue: TensorView | undefined,
  params: AttentionParameters,
  pastSequenceLength: number,
) => {
  const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
  const nReps = params.nReps ? params.nReps : 1;
  const repeatedVHiddenSize = params.vHiddenSize * nReps;
  const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue;
  const presentValueShape = presentValue
    ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize]
    : undefined;
  const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
  const TILE_SIZE = 12;
  const dispatch = {
    x: Math.ceil(params.vHeadSize / TILE_SIZE),
    y: Math.ceil(params.sequenceLength / TILE_SIZE),
    z: params.batchSize * params.numHeads,
  };

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: params.sequenceLength }, // M
    { type: DataType.uint32, data: totalSequenceLength }, // K
    { type: DataType.uint32, data: params.vHeadSize }, // N
    { type: DataType.uint32, data: params.numHeads },
    { type: DataType.uint32, data: repeatedVHiddenSize },
    { type: DataType.uint32, data: pastSequenceLength },
    { type: DataType.uint32, data: params.kvSequenceLength },
  ];
  // Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced
  const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0;
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
  if (feedPastValue) {
    inputDependencies.push('type');
  }
  const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }];
  if (presentValue) {
    outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default });
  }
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
    const vHelper = inputVariable('v', v.dataType, v.dims);
    const inputVars = [probsHelper, vHelper];
    if (feedPastValue) {
      inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
    }
    const output = outputVariable('output', probs.dataType, outputShape);
    const outputVars = [output];
    if (presentValue) {
      outputVars.push(outputVariable('present_value', probs.dataType, presentValueShape!));
    }
    const uniforms: UniformsArrayType = [
      { name: 'M', type: 'u32' },
      { name: 'K', type: 'u32' },
      { name: 'N', type: 'u32' },
      { name: 'num_heads', type: 'u32' },
      { name: 'v_hidden_size', type: 'u32' },
      { name: 'past_sequence_length', type: 'u32' },
      { name: 'kv_sequence_length', type: 'u32' },
    ];
    return `
  const TILE_SIZE = ${TILE_SIZE}u;
  var<workgroup> tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
  ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
    let headIdx = workgroup_id.z;
    let m = global_id.y;
    let n = global_id.x;

    let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
    ${(() => {
      if (feedPastValue && presentValue) {
        return `
    let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
    let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;`;
      } else {
        return `
    let offsetB = headIdx * uniforms.N * uniforms.K + n;`;
      }
    })()}
    ${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''}
    var value = ${probsHelper.type.storage}(0);
    for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
      if (m < uniforms.M && w + local_id.x < uniforms.K) {
        tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
      }
      if (n < uniforms.N && w + local_id.y < uniforms.K) {
        var idx = TILE_SIZE * local_id.y + local_id.x;
        ${(() => {
          if (feedPastValue && presentValue) {
            return `
        if (w + local_id.y < uniforms.past_sequence_length) {
          tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
        } else {
          tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];
        }`;
          } else {
            return `
        tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];`;
          }
        })()}
        ${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''}
      }
      workgroupBarrier();
      for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
        value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
      }
      workgroupBarrier();
    }

    // we need to transpose output from BNSH_v to BSND_v
    let batchIdx = workgroup_id.z / uniforms.num_heads;
    let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
    if (m < uniforms.M && n < uniforms.N) {
      let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size
        + currentBatchHeadNumber * uniforms.N + n;
      output[outputIdx] = value;
    }
  }`;
  };

  return {
    name: 'AttentionScore',
    shaderCache: { hint: `${pastValue !== undefined};${outputCount}`, inputDependencies },
    getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
    getShaderSource,
  };
};
/**
 * Runs the full attention pipeline on pre-projected Q/K/V:
 *   1. AttentionProbs (Q*K^T, scaled, + optional attention_bias; may emit present_key),
 *   2. in-place softmax over the probs,
 *   3. probs*V (may emit present_value).
 * `_maskIndex` and `_past` are accepted for signature compatibility but unused here.
 * NOTE(review): the `q`/`k`/`v` parameters and parts of the `context.compute` call
 * scaffolding were elided in this view and are reconstructed from visible fragments — confirm.
 */
export const applyAttention = (
  context: ComputeContext,
  q: TensorView,
  k: TensorView,
  v: TensorView,
  _maskIndex: TensorView | undefined,
  _past: TensorView | undefined,
  pastKey: TensorView | undefined,
  pastValue: TensorView | undefined,
  attentionBiasInput: TensorView | undefined,
  parameters: AttentionParameters,
  attributes: AttentionAttrs,
) => {
  // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
  const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
  const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
  const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
  // Treat an empty attention_bias tensor the same as no attention_bias.
  const attentionBias =
    attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;

  const inputsK = [q, k];
  if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
    inputsK.push(pastKey);
  }
  if (attentionBias) {
    inputsK.push(attentionBias);
  }

  // Run AttentionProbs
  const probs = context.compute(
    createAttentionProbsProgramInfo(
      outputCount,
      q,
      k,
      pastKey,
      attentionBias,
      parameters,
      attributes,
      pastSequenceLength,
    ),
    { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] },
  )[0];

  // Run softmax in place over the probs rows.
  context.compute(
    createInPlaceSoftmaxProgramInfo(
      probs,
      parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
      totalSequenceLength,
    ),
    { inputs: [probs], outputs: [] },
  );

  // Run probs * V.
  const inputsV = [probs, v];
  if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
    inputsV.push(pastValue);
  }
  context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), {
    inputs: inputsV,
    outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0],
  });
};
/**
 * Computes the packed Q/K/V projections for the Attention op in one tiled GEMM pass:
 * input (B, S, D_i) x weight (D_i, D + D + D_v) + bias, producing three tensors
 * of shape (B, N, S, H). Returns the three outputs from `context.compute`.
 * NOTE(review): the `m`/`n` declarations and the `getRunData` scaffolding were
 * elided in this view and are reconstructed from visible fragments — confirm.
 */
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
  const outputShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, parameters.headSize];
  const M = parameters.sequenceLength;
  const K = parameters.inputHiddenSize;
  const N = parameters.headSize;
  const TILE_SIZE = 12;
  const dispatch = {
    x: Math.ceil(parameters.headSize / TILE_SIZE),
    y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
    z: parameters.batchSize * parameters.numHeads,
  };
  const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: M },
    { type: DataType.uint32, data: K },
    { type: DataType.uint32, data: N },
    { type: DataType.uint32, data: parameters.numHeads },
    { type: DataType.uint32, data: parameters.headSize },
    { type: DataType.uint32, data: parameters.hiddenSize },
    // ldb: leading dimension of the packed weight matrix (D + D + D_v).
    { type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize },
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const outputQ = outputVariable('output_q', inputs[0].dataType, outputShape);
    const outputK = outputVariable('output_k', inputs[0].dataType, outputShape);
    const outputV = outputVariable('output_v', inputs[0].dataType, outputShape);
    const input = inputVariable('input', inputs[0].dataType, inputs[0].dims);
    const weight = inputVariable('weight', inputs[1].dataType, inputs[1].dims);
    const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
    const dataType = input.type.storage;

    const uniforms: UniformsArrayType = [
      { name: 'M', type: 'u32' },
      { name: 'K', type: 'u32' },
      { name: 'N', type: 'u32' },
      { name: 'num_heads', type: 'u32' },
      { name: 'head_size', type: 'u32' },
      { name: 'hidden_size', type: 'u32' },
      { name: 'ldb', type: 'u32' },
    ];
    return `
  const TILE_SIZE = ${TILE_SIZE}u;
  var<workgroup> tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
  var<workgroup> tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(input, weight, bias, outputQ, outputK, outputV)}
  ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
    let batchIndex = workgroup_id.z / uniforms.num_heads;
    let headNumber = workgroup_id.z % uniforms.num_heads;
    let m = global_id.y;
    let n = global_id.x;

    let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K;
    let biasOffsetQ = headNumber * uniforms.head_size;
    let biasOffsetK = uniforms.hidden_size + biasOffsetQ;
    let biasOffsetV = uniforms.hidden_size + biasOffsetK;

    var valueQ = ${dataType}(0);
    var valueK = ${dataType}(0);
    var valueV = ${dataType}(0);
    for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
      if (m < uniforms.M && w + local_id.x < uniforms.K) {
        tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];
      }
      if (n < uniforms.N && w + local_id.y < uniforms.K) {
        let offset = n + (w + local_id.y) * uniforms.ldb;
        tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];
        tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];
        tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];
      }
      workgroupBarrier();
      for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) {
        let inputTileOffset = TILE_SIZE * local_id.y + k;
        let weightTileOffset = TILE_SIZE * k + local_id.x;
        valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset];
        valueK += tileInput[inputTileOffset] * tileWeightK[weightTileOffset];
        valueV += tileInput[inputTileOffset] * tileWeightV[weightTileOffset];
      }

      workgroupBarrier();
    }

    let headOffset = (m * uniforms.N + n) % uniforms.head_size;
    valueQ += bias[headOffset + biasOffsetQ];
    valueK += bias[headOffset + biasOffsetK];
    valueV += bias[headOffset + biasOffsetV];

    let offset = workgroup_id.z * uniforms.M * uniforms.N;
    if (m < uniforms.M && n < uniforms.N) {
      let outputIdx = offset + m * uniforms.N + n;
      output_q[outputIdx] = valueQ;
      output_k[outputIdx] = valueK;
      output_v[outputIdx] = valueV;
    }
  }`;
  };

  return context.compute(
    {
      name: 'AttentionPrepare',
      shaderCache: { inputDependencies: ['type', 'type', 'type'] },
      getRunData: () => ({
        outputs: [
          { dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default },
          { dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default },
          { dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default },
        ],
        dispatchGroup: dispatch,
        programUniforms,
      }),
      getShaderSource,
    },
    { inputs, outputs: [-1, -1, -1] },
  );
};
844
export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => {
845
const params = validateAttentionInputs(context.inputs, attributes);
847
const [q, k, v] = prepare(context, params);
849
return applyAttention(