4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
17
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
18
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
19
import { createTileProgramInfo } from './tile';
20
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
22
// Validates the GroupQueryAttention inputs (query/key/value plus optional
// past_key/past_value) and derives the attention parameters (batch size,
// sequence lengths, head sizes, qkv layout format, ...) used by the kernels.
//
// NOTE(review): this block is garbled by extraction — standalone numbers
// interleaved below are stray original line numbers, and gaps in them show
// that several real lines are missing (closing braces, the
// `const hiddenSize =` declaration that L-66's ternary belongs to, and the
// `return { ... }` opener for the trailing property list). Restore from the
// upstream source before compiling; comments here only annotate what is
// visible.
export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
23
const query = inputs[0];
24
const key = inputs[1];
25
const value = inputs[2];
26
const pastKey = inputs[3];
27
const pastValue = inputs[4];
58
if (query.dims.length !== 3 && query.dims.length !== 5) {
59
throw new Error('Input query is expected to have 3 or 5 dimensions');
62
const dmmhaPacking = false;
63
const batchSize = query.dims[0];
64
const sequenceLength = query.dims[1];
66
// NOTE(review): the `const hiddenSize =` that this ternary initializes is on
// a missing line — presumably hiddenSize; confirm against upstream.
query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4];
67
let kvSequenceLength = sequenceLength;
69
let pastSequenceLength = 0;
70
let maxSequenceLength = 0;
71
const headSize = Math.floor(hiddenSize / attributes.numHeads);
72
// Past KV is considered present only when the tensor exists and is non-empty.
const hasPastKey = pastKey && pastKey.dims.length !== 0;
73
const hasPastValue = pastValue && pastValue.dims.length !== 0;
75
const isPastkvBSNH = true;
76
if (hasPastKey && hasPastValue) {
77
if (pastKey.dims.length !== 4) {
78
throw new Error('Input "past_key" is expected to have 4 dimensions');
80
if (pastValue.dims.length !== 4) {
81
throw new Error('Input "past_value" is expected to have 4 dimensions');
85
// BSNH layout: sequence length is dim 1; BNSH (below) uses dim 2.
pastSequenceLength = pastKey.dims[1];
86
maxSequenceLength = pastKey.dims[1];
89
pastSequenceLength = pastKey.dims[2];
90
maxSequenceLength = pastKey.dims[2];
92
} else if (hasPastKey || hasPastValue) {
93
throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
96
let qkvFormat: AttentionQkvFormat;
98
if (query.dims.length !== 3) {
99
throw new Error('Input "query" is expected to have 3 dimensions when key is given');
101
if (key.dims.length < 3 || key.dims.length > 5) {
102
throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');
104
if (query.dims[0] !== key.dims[0]) {
105
throw new Error('Input "query" and "key" shall have same dim 0 (batch size)');
108
// Rank 3 key: plain BSNH; rank 5: packed KV; rank 4 (fall-through below):
// pre-transposed past_key layout.
if (key.dims.length === 3) {
109
if (query.dims[2] % key.dims[2] !== 0) {
110
throw new Error('Dimension 2 of "query" should be a multiple of "key"');
112
qkvFormat = AttentionQkvFormat.qkvBSNH;
113
kvSequenceLength = key.dims[1];
114
} else if (key.dims.length === 5) {
115
if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
116
throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');
119
throw new Error('Expect "value" be none when "key" has packed kv format.');
121
qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
122
kvSequenceLength = key.dims[1];
125
if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
126
throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
129
qkvFormat = AttentionQkvFormat.unknown;
130
kvSequenceLength = key.dims[2];
134
// Key absent: query may carry packed QKV (rank 5) or plain rank 3.
if (query.dims.length !== 3 && query.dims.length !== 5) {
135
throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty');
137
if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) {
138
throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv');
141
qkvFormat = AttentionQkvFormat.qkvBSN3H;
144
const maskType: AttentionMaskType = AttentionMaskType.none;
145
let passPastInKv = false;
146
let vHiddenSize = hiddenSize;
148
if (value.dims.length !== 3 && value.dims.length !== 4) {
149
throw new Error('Input "value" is expected to have 3 or 4 dimensions');
152
if (query.dims[0] !== value.dims[0]) {
153
throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)');
156
if (value.dims.length === 3) {
157
if (kvSequenceLength !== value.dims[1]) {
158
throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');
160
vHiddenSize = value.dims[2];
162
if (kvSequenceLength !== value.dims[2]) {
163
throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)');
165
vHiddenSize = value.dims[1] * value.dims[3];
169
const totalSequenceLength = pastSequenceLength + kvSequenceLength;
170
const broadcastResPosBias = false;
183
// NOTE(review): the `return {` opener and several properties of the returned
// AttentionParameters object are on missing lines; only the tail survives.
vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!),
184
numHeads: attributes.numHeads,
185
kvNumHeads: attributes.kvNumHeads,
186
nReps: attributes.numHeads / attributes.kvNumHeads!,
187
pastPresentShareBuffer: false,
189
scale: attributes.scale,
197
// Builds the ProgramInfo for a WebGPU kernel that concatenates the past KV
// cache ('past_kv', optional) with the newly computed KV ('new_kv') into a
// single 'present_kv' tensor of shape
// [batchSize, totalSequenceLength, kvNumHeads, headSize].
//
// NOTE(review): garbled by extraction — the standalone numbers are stray
// original line numbers, and gaps in them show missing lines (the `) => {`
// closing the parameter list, array/object closers, and the backtick
// terminators of the WGSL template literals below). Comments are only placed
// at points that are visibly outside those template literals.
const createConcatProgramInfo = (
199
b: TensorView | undefined,
201
params: AttentionParameters,
203
const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize];
205
const outputSize = ShapeUtil.size(outputShape) / component;
206
const presentSequenceLength = params.totalSequenceLength;
207
const output = outputVariable('present_kv', dataType, outputShape.length, component);
208
const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component);
209
const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined;
211
// H = head size in vectorized components; one workgroup per (sequence
// position, batch) pair per the dispatch below.
const H = Math.ceil(params.headSize / component);
212
const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 };
214
const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank'];
216
const programUniforms: ProgramUniform[] = [
217
{ type: DataType.uint32, data: outputSize },
218
{ type: DataType.uint32, data: params.pastSequenceLength },
219
{ type: DataType.uint32, data: params.kvSequenceLength },
220
{ type: DataType.uint32, data: params.totalSequenceLength },
223
const inputs = [inputA];
225
programUniforms.push(
226
...createTensorShapeVariables(a.dims),
227
...createTensorShapeVariables(b!.dims),
228
...createTensorShapeVariables(outputShape),
232
programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape));
234
const uniforms: UniformsArrayType = [
235
{ name: 'output_size', type: 'u32' },
236
{ name: 'past_seqlen', type: 'u32' },
237
{ name: 'new_seqlen', type: 'u32' },
238
{ name: 'present_seqlen', type: 'u32' },
241
// WGSL fragment: copy one element from past_kv into present_kv for
// positions s < past_seqlen.
const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H;
242
var past_head_stride = uniforms.past_seqlen * H;
244
past_head_stride = H;
246
let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
247
present_kv[out_offset] = past_kv[in_offset];`;
248
// WGSL fragment: copy one element from new_kv (offset by past_seqlen) into
// present_kv for positions past_seqlen <= s < past_seqlen + new_seqlen.
const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H;
249
let new_row_stride = num_heads * H;
250
let new_head_stride = H;
251
let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
252
present_kv[out_offset] = new_kv[in_offset];`;
254
? `if (s < past_seqlen) {
256
} else if (s < past_seqlen + uniforms.new_seqlen) {
259
: `if (s < past_seqlen + uniforms.new_seqlen) {
264
const getShaderSource = (shaderHelper: ShaderHelper) => `
266
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)}
267
${shaderHelper.mainStart([H, params.kvNumHeads!, 1])}
268
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
269
var indices = ${output.offsetToIndices('global_idx')};
272
let s = workgroup_id.x;
273
let b = workgroup_id.y;
274
let num_heads = ${params.kvNumHeads!}u;
277
let present_seqlen = uniforms.present_seqlen;
278
let present_batch_stride = present_seqlen * num_heads * H;
280
let is_bsnh = ${params.isPastkvBSNH};
283
row_stride = num_heads * H;
285
var present_head_stride = present_seqlen * H;
287
present_head_stride = H;
290
let past_seqlen = uniforms.past_seqlen;
292
let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
297
name: 'ConcatPastNew',
298
shaderCache: { hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies },
300
outputs: [{ dims: outputShape, dataType }],
301
dispatchGroup: dispatch,
308
/**
 * Parses GroupQueryAttention attributes: shallow-copies the incoming
 * attributes and wraps them via createAttributeWithCacheKey so compiled
 * programs can be cached and reused for identical attribute sets.
 */
export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => {
  const copied = { ...attributes };
  return createAttributeWithCacheKey(copied);
};
311
// Permutation [0, 2, 1, 3] swaps axes 1 and 2 of a rank-4 tensor — used below
// to transpose between the BSNH and BNSH layouts.
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] });
313
// Prepares a K or V input for attention: reshapes rank-3 input to
// [batch, kvSeqLen, kvNumHeads, headSize], concatenates with the past KV
// cache (via createConcatProgramInfo), tiles heads by nReps to expand
// grouped KV heads, and finally transposes to BNSH layout.
//
// NOTE(review): garbled by extraction — stray original line numbers are
// interleaved, and missing lines include the `) => {` after the parameter
// list, the closers of the `context.compute(...)` option objects, and parts
// of the reshape argument list. Restore from upstream before editing.
const maybeExpandAndTransposeToBNSH = (
314
context: ComputeContext,
316
pastKV: TensorView | undefined,
317
params: AttentionParameters,
320
let reshapedInput = input;
321
const numHeads = params.kvNumHeads!;
322
const nReps = params.nReps!;
323
if (input.dims.length === 3 && params.kvSequenceLength !== 0) {
324
reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]);
328
// With a past KV cache: concatenate past and new KV into present_kv.
reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), {
329
inputs: [reshapedInput, pastKV],
330
outputs: [params.isPastkvBSNH ? outputIndex : -1],
333
// Without a past KV cache: same kernel, new KV only.
reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), {
334
inputs: [reshapedInput],
335
outputs: [params.isPastkvBSNH ? outputIndex : -1],
339
// Tile along the head axis by nReps so each query-head group sees its
// shared KV head replicated.
reshapedInput = context.compute(createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), {
340
inputs: [reshapedInput],
343
reshapedInput = reshapedInput.reshape([
345
params.totalSequenceLength,
351
// Transpose (perm [0, 2, 1, 3], see weightTransposeAttribute) to BNSH.
return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
352
inputs: [reshapedInput],
357
// Entry point for the GroupQueryAttention operator: validates inputs,
// rejects the unimplemented packed-QKV / packed-KV layouts, prepares Q
// (transpose to BNSH), K and V (expand + concat past KV + transpose), and
// runs the shared attention kernel.
//
// NOTE(review): garbled by extraction — stray original line numbers are
// interleaved; the argument list of maybeTransposeToBNSHAndAddBias and the
// closing `};` of this function are on missing lines. `applyAttention` is
// defined outside this chunk.
export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
358
const params = validateInputs(context.inputs, attributes);
359
if (context.inputs[0].dims.length === 5) {
360
throw new Error('Packed QKV is not implemented');
363
if (context.inputs[1]?.dims.length === 5) {
364
throw new Error('Packed KV is not implemented');
367
const Q = maybeTransposeToBNSHAndAddBias(
371
params.sequenceLength,
377
// Past KV tensors count as present only when non-empty.
const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
378
const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
379
const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1);
380
const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2);
381
applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes);