1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
10
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, WORKGROUP_SIZE } from './common';
12
/**
 * Attributes read by the RotaryEmbedding operator implemented in this file.
 *
 * NOTE(review): the bare numeric lines inside this block look like line-number residue from an extraction
 * step, and the closing brace of the interface is not visible here - confirm against the upstream file.
 */
export interface RotaryEmbeddingAttributes {
13
// Selects how the two elements of a rotated pair are addressed in the shader: when true the partner of
// element `i` is `i + 1`, otherwise it is `i + half_rotary_emb_dim`.
readonly interleaved: boolean;
14
// Number of heads. Validation throws when this is 0 while rotaryEmbeddingDim > 0; when rotaryEmbeddingDim
// is non-zero the head size is computed as hiddenSize / numHeads.
readonly numHeads: number;
15
// Rotary dimension. When 0, the head size is derived from the cache width instead (cos_cache.dims[1] * 2).
readonly rotaryEmbeddingDim: number;
16
// Forwarded to the shader as the f32 uniform `scale`.
readonly scale: number;
19
/**
 * Validates the operator inputs against each other and against the attributes, throwing on the first
 * violated constraint.
 *
 * Constraints enforced by the visible code:
 *  - `x` has rank 3 or 4;
 *  - `position_ids` has shape `[]`, shape `[1]`, or rank 2; when rank 2 its dims must be
 *    `[batch_size, sequence_length]`;
 *  - `cos_cache` and `sin_cache` are both rank 2 and have identical shapes;
 *  - `num_heads` is non-zero whenever `rotary_embedding_dim` > 0;
 *  - `rotary_embedding_dim` does not exceed the head size;
 *  - `cos_cache.dims[1]` equals `head_size / 2` or `rotary_embedding_dim / 2`;
 *  - `sequence_length` does not exceed `cos_cache.dims[0]`.
 *
 * NOTE(review): the 'position_ids' error message says "0, 1, or 2 dimensions", but the rank-1 case is only
 * accepted for shape `[1]` - a longer rank-1 tensor is rejected with that same message.
 * NOTE(review): the bare numeric lines look like extraction residue and several structural lines (`if (`,
 * `throw new Error(`, closing braces) are not visible in this block - confirm against the upstream file.
 *
 * @param inputs Tensors, destructured in order as [input, positionIds, cosCache, sinCache].
 * @param attributes Operator attributes; only `numHeads` and `rotaryEmbeddingDim` are read here.
 * @throws Error with a message naming the offending input or attribute.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes): void => {
20
const [input, positionIds, cosCache, sinCache] = inputs;
21
const { numHeads, rotaryEmbeddingDim } = attributes;
23
if (input.dims.length !== 3 && input.dims.length !== 4) {
24
throw new Error(`Input 'x' is expected to have 3 or 4 dimensions, got ${input.dims.length}`);
27
// Accept a scalar, a single-element vector, or any rank-2 tensor (its dims are checked further down).
!ShapeUtil.areEqual(positionIds.dims, []) &&
28
!ShapeUtil.areEqual(positionIds.dims, [1]) &&
29
positionIds.dims.length !== 2
31
throw new Error(`Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ${positionIds.dims.length}`);
33
if (cosCache.dims.length !== 2) {
34
throw new Error(`Input 'cos_cache' is expected to have 2 dimensions, got ${cosCache.dims.length}`);
36
if (sinCache.dims.length !== 2) {
37
throw new Error(`Input 'sin_cache' is expected to have 2 dimensions, got ${sinCache.dims.length}`);
39
if (!ShapeUtil.areEqual(cosCache.dims, sinCache.dims)) {
40
throw new Error("Inputs 'cos_cache' and 'sin_cache' are expected to have the same shape");
43
if (rotaryEmbeddingDim > 0 && numHeads === 0) {
44
throw new Error('num_heads must be provided if rotary_embedding_dim is specified');
47
const batchSize = input.dims[0];
48
// The sequence axis is taken as the second-to-last dimension for both rank-3 and rank-4 inputs.
const sequenceLength = input.dims[input.dims.length - 2];
49
const maxSequenceLength = cosCache.dims[0];
50
// Element count of everything after the batch axis, divided by the sequence length.
const hiddenSize = ShapeUtil.sizeFromDimension(input.dims, 1) / sequenceLength;
51
// With no explicit rotary dimension the head size is inferred as twice the cache width.
const headSize = rotaryEmbeddingDim === 0 ? cosCache.dims[1] * 2 : hiddenSize / numHeads;
52
if (rotaryEmbeddingDim > headSize) {
53
throw new Error('rotary_embedding_dim must be less than or equal to head_size');
56
if (positionIds.dims.length === 2) {
57
if (batchSize !== positionIds.dims[0]) {
58
throw new Error(`Input 'position_ids' dimension 0 should be of size batch_size, got ${positionIds.dims[0]}`);
60
if (sequenceLength !== positionIds.dims[1]) {
61
throw new Error(`Input 'position_ids' dimension 1 should be of size sequence_length, got ${positionIds.dims[1]}`);
65
// The cache width must match half of either the head size or the explicit rotary dimension.
if (headSize / 2 !== cosCache.dims[1] && rotaryEmbeddingDim / 2 !== cosCache.dims[1]) {
67
`Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${
73
if (sequenceLength > maxSequenceLength) {
74
throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported');
78
/**
 * Builds the program description for the 'RotaryEmbedding' kernel: the uniform values, the shader source
 * generator, the output tensor description and the dispatch size.
 *
 * Shader behaviour, as far as the visible lines show: the global index is unfolded into a 4-component
 * index `bsnh` using `global_strides` / `global_shape`. When `bsnh[3] < half_rotary_emb_dim` the shader
 * reads the pair at offsets `i` and `j`, combines it with the `cos_cache` / `sin_cache` entries at
 * `(position_id, bsnh[3])` and writes both results. The other visible path copies the element at offset
 * `k` (shifted by `half_rotary_emb_dim`) from input to output unchanged.
 *
 * NOTE(review): the bare numeric lines look like extraction residue, and several structural lines are not
 * visible in this block (the return-type line, the start of the template literal that holds the shader
 * text, parts of the returned object) - confirm against the upstream file.
 *
 * @param inputs Tensors indexed as [input, position_ids, cos_cache, sin_cache], matching the shader
 *   variable names declared in `getShaderSource`.
 * @param attributes Operator attributes; all four fields are destructured.
 * @returns The program description - presumably a `ProgramInfo`; the annotation line is not visible.
 */
const createRotaryEmbeddingProgramInfo = (
79
inputs: readonly TensorView[],
80
attributes: RotaryEmbeddingAttributes,
82
const { interleaved, numHeads, rotaryEmbeddingDim, scale } = attributes;
83
const batchSize = inputs[0].dims[0];
84
// Number of elements in one batch entry (product of all dims after the batch axis).
const batchStride = ShapeUtil.sizeFromDimension(inputs[0].dims, 1);
85
const sequenceLength = inputs[0].dims[inputs[0].dims.length - 2];
86
const hiddenSize = batchStride / sequenceLength;
87
// The width of cos_cache is used as half of the rotary dimension.
const halfRotaryEmbeddingDim = inputs[2].dims[1];
88
const headSize = rotaryEmbeddingDim === 0 ? halfRotaryEmbeddingDim * 2 : hiddenSize / numHeads;
90
// Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape
91
// [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy]
92
// to unfold the global index in shader.
93
const globalShape = new Array<number>(
96
// Number of heads.
hiddenSize / headSize,
97
// Equals (pairs to rotate) + (elements to copy) when the rotary dimension is 2 * halfRotaryEmbeddingDim.
headSize - halfRotaryEmbeddingDim,
99
const globalStrides = ShapeUtil.computeStrides(globalShape);
101
const programUniforms: ProgramUniform[] = [
102
{ type: DataType.float, data: scale },
103
{ type: DataType.uint32, data: globalShape },
104
{ type: DataType.uint32, data: globalStrides },
106
// strides for addressing the input/output tensor, in permutated order to align with the unfolded global index,
108
// Rank-3 input: presumably laid out as [batch, sequence, hidden], so the head stride is headSize - TODO confirm.
...(inputs[0].dims.length === 3
109
? new Array<ProgramUniform>({ type: DataType.uint32, data: [batchStride, hiddenSize, headSize, 1] })
111
// Rank-4 input: presumably laid out as [batch, heads, sequence, head size]; the sequence and head strides
// are swapped relative to the tensor's own axis order so they line up with the unfolded index - TODO confirm.
...(inputs[0].dims.length === 4
112
? new Array<ProgramUniform>({
113
type: DataType.uint32,
114
data: [batchStride, headSize, sequenceLength * headSize, 1],
118
// Shape uniforms for the four inputs followed by the output, which reuses the input's dims.
...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, inputs[2].dims, inputs[3].dims, inputs[0].dims),
121
const getShaderSource = (shaderHelper: ShaderHelper) => {
122
const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length);
123
const positionIds = inputVariable('position_ids', inputs[1].dataType, inputs[1].dims.length);
124
const cosCache = inputVariable('cos_cache', inputs[2].dataType, inputs[2].dims.length);
125
const sinCache = inputVariable('sin_cache', inputs[3].dataType, inputs[3].dims.length);
126
// The output has the same data type and rank as the input.
const output = outputVariable('output', inputs[0].dataType, inputs[0].dims.length);
128
// Declared in the same order as the leading entries of programUniforms above.
shaderHelper.registerUniforms([
129
{ name: 'scale', type: 'f32' },
130
{ name: 'global_shape', type: 'u32', length: globalShape.length },
131
{ name: 'global_strides', type: 'u32', length: globalStrides.length },
132
{ name: 'input_output_strides', type: 'u32', length: globalStrides.length },
136
${shaderHelper.declareVariables(input, positionIds, cosCache, sinCache, output)}
138
${shaderHelper.mainStart(WORKGROUP_SIZE)}
139
let half_rotary_emb_dim = uniforms.${cosCache.name}_shape[1];
140
let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape;
141
let size = uniforms.global_shape[0] * uniforms.global_strides[0];
142
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('size')}
144
if (bsnh[3] < half_rotary_emb_dim) {
145
let position_ids_idx =
146
${positionIds.broadcastedIndicesToOffset('bsnh.xy', outputVariable('', positionIds.type.tensor, 2))};
148
u32(${positionIds.getByOffset('position_ids_idx')}) + select(0, bsnh[1], position_ids_idx == 0);
149
let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], ${interleaved});
150
let j = i + select(half_rotary_emb_dim, 1, ${interleaved});
151
let re = ${input.getByOffset('i')} * ${cosCache.get('position_id', 'bsnh[3]')} -
152
${input.getByOffset('j')} * ${sinCache.get('position_id', 'bsnh[3]')};
153
${output.setByOffset('i', 're')}
154
let im = ${input.getByOffset('i')} * ${sinCache.get('position_id', 'bsnh[3]')} +
155
${input.getByOffset('j')} * ${cosCache.get('position_id', 'bsnh[3]')};
156
${output.setByOffset('j', 'im')}
158
let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim;
159
${output.setByOffset('k', input.getByOffset('k'))}
165
name: 'RotaryEmbedding',
167
hint: createAttributeWithCacheKey({
170
// One dependency entry per input; presumably makes the cached shader depend on input ranks only - TODO confirm.
inputDependencies: ['rank', 'rank', 'rank', 'rank'],
174
outputs: [{ dims: inputs[0].dims, dataType: inputs[0].dataType }],
175
// Enough workgroups of WORKGROUP_SIZE to cover every element of globalShape.
dispatchGroup: { x: Math.ceil(ShapeUtil.size(globalShape) / WORKGROUP_SIZE) },
181
/**
 * Operator entry point: validates `context.inputs` against the attributes, then asks the context to run
 * the RotaryEmbedding program built from those same inputs.
 *
 * NOTE(review): the bare numeric lines look like extraction residue and the closing of this function is
 * not visible here - confirm against the upstream file.
 *
 * @param context Compute context supplying the input tensors and the `compute` call.
 * @param attributes Operator attributes forwarded to both validation and program creation.
 * @throws Error propagated from `validateInputs` when the inputs are inconsistent.
 */
export const rotaryEmbedding = (context: ComputeContext, attributes: RotaryEmbeddingAttributes): void => {
182
// Validation runs first so malformed inputs fail before any program is created.
validateInputs(context.inputs, attributes);
183
context.compute(createRotaryEmbeddingProgramInfo(context.inputs, attributes));