1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
10
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common';
12
export interface EinsumAttributes extends AttributeWithCacheKey {
13
readonly equation: string;
15
// The equation attribute value is a string which consists of left hand side (LHS) and optionally right hand side (RHS)
16
// separated by '->'. Ex. "ij,jk -> ik" expresses matrix multiplication
17
// "ij->ji" expresses matrix transpose
18
// "ii->i" diagonal elements of a square matrix
19
// LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable.
20
// Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to
21
// 'Z' or '...' to represent arbitrary dimensions.
23
// Regex source for a single symbol of a term: one ASCII letter or the ellipsis '...'.
const symbolPattern = '[a-zA-Z]|\\.\\.\\.';
// Regex source for a term: one or more symbols.
const termPattern = `(${symbolPattern})+`;
// termPattern anchored at both ends, so it only matches a whole term from beginning to end.
const termPatternOnly = `^${termPattern}$`;
// Regex source for the LHS: one or more terms separated by commas.
const lhsPattern = `(${termPattern},)*${termPattern}`;
// lhsPattern anchored at both ends, so it only matches a whole LHS from beginning to end.
const lhsPatternOnly = `^${lhsPattern}$`;
30
count: number; // Occurrence counter of the symbol; set to 1 when the symbol is first added to the equation
31
inputIndices: number[]; // Indices of the inputs whose terms contain the symbol
32
dimValue: number; // Size of the dimension the symbol corresponds to
36
constructor(inputIndex = -1) {
37
this.symbolToIndices = new Map<string, number[]>();
38
this.inputIndex = inputIndex;
41
// Add a symbol to the term
42
addSymbol(symbol: string, index: number) {
43
let value = this.symbolToIndices.get(symbol);
44
if (value === undefined) {
49
this.symbolToIndices.set(symbol, value);
52
symbolToIndices: Map<string, number[]>; // Map from symbol to dimensions of the input corresponding to the term
53
inputIndex: number; // -1 for output and 0, 1, 2, ... for inputs
58
inputs: readonly TensorView[],
59
public readonly equation: string,
61
this.hasEllipsis = false;
62
this.symbolToInfo = new Map<string, SymbolInfo>();
63
this.lhs = new Array<EinsumTerm>();
65
// As rhs needs to be updated allow using let instead of const for both lhs and rhs.
66
// eslint-disable-next-line prefer-const
67
let [lhs, rhs] = equation.includes('->') ? equation.split('->', 2) : [equation, ''];
68
if (!lhs.match(RegExp(lhsPatternOnly))) {
69
throw new Error('Invalid LHS term');
71
const inputTerms = lhs.split(',');
72
inputTerms.forEach((inputTerm, index) => {
73
const dims = inputs[index].dims.slice();
74
if (!inputTerm.match(RegExp(termPatternOnly))) {
75
throw new Error('Invalid LHS term');
77
const einsumTerm = this.processTerm(inputTerm, true, dims, index);
78
this.lhs.push(einsumTerm);
81
// Initialize the RHS if not specified
83
// Construct RHS from LHS terms/symbols
84
rhs += [...this.symbolToInfo.entries()]
85
.filter(([sym, info]) => info.count === 1 || sym === '...')
89
if (!rhs.match(RegExp(termPattern))) {
90
throw new Error('Invalid RHS');
94
// Compute output dims
95
const rhsSymbols = rhs.match(RegExp(symbolPattern, 'g'));
96
rhsSymbols?.forEach((symbol) => {
97
if (symbol === '...') {
98
this.outputDims = this.outputDims.concat(this.ellipsisDims);
100
const info = this.symbolToInfo.get(symbol);
101
if (info === undefined) {
102
throw new Error('Invalid RHS symbol');
104
this.outputDims.push(info.dimValue);
107
this.rhs = this.processTerm(rhs, false, this.outputDims);
108
} // End of EinsumEquation constructor
110
// Add a symbol to the equation
111
addSymbol(symbol: string, dimValue: number, inputIndex: number) {
112
let info = this.symbolToInfo.get(symbol);
113
if (info !== undefined) {
114
if (info.dimValue !== dimValue && info.count !== 1) {
115
throw new Error('Dimension mismatch');
118
info.inputIndices.push(inputIndex);
121
info = { count: 1, dimValue, inputIndices: [inputIndex] };
123
this.symbolToInfo.set(symbol, info);
126
// Process one input/output term
127
processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm {
128
const rank = dims.length;
129
let ellipsis = false;
130
let ellipsisDims = [];
132
// For output empty string is allowed because the output may be reduced to a scalar value
133
if (!term.match(RegExp(termPatternOnly)) && !isInput && term !== '') {
134
throw new Error('Invalid LHS term');
136
const indexSymbols = term.match(RegExp(symbolPattern, 'g'));
137
const einsumTerm = new EinsumTerm(index);
138
// symbol can be either a letter, 'a' to 'z' or 'A' to 'Z', or '...'
139
indexSymbols?.forEach((symbol: string, i: number) => {
140
if (symbol === '...') {
142
throw new Error('Only one ellipsis is allowed per input term');
145
const ellipsisDimLength = rank - indexSymbols.length + 1;
146
if (ellipsisDimLength < 0) {
147
throw new Error('Ellipsis out of bounds');
149
ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength);
150
if (this.hasEllipsis) {
152
this.ellipsisDims.length !== ellipsisDims.length ||
153
this.ellipsisDims.toString() !== ellipsisDims.toString()
155
throw new Error('Ellipsis dimensions mismatch');
157
} else if (isInput) {
158
this.hasEllipsis = true;
159
this.ellipsisDims = ellipsisDims;
161
throw new Error('Ellipsis must be specified in the LHS');
163
// Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling
164
for (let j = 0; j < ellipsisDims.length; j++) {
165
const symbol = String.fromCharCode('0'.charCodeAt(0) + j);
166
einsumTerm.addSymbol(symbol, i + j);
167
this.addSymbol(symbol, dims[nextDim++], index);
170
einsumTerm.addSymbol(symbol, i + (this.hasEllipsis ? this.ellipsisDims.length - 1 : 0));
171
this.addSymbol(symbol, dims[nextDim++], index);
177
symbolToInfo: Map<string, SymbolInfo>; // Bookkeeping (SymbolInfo) for every symbol added to the equation, keyed by symbol
178
hasEllipsis: boolean; // Whether the equation has an ellipsis
179
ellipsisDims: number[]; // The dimensions that the equation's ellipsis corresponds to
180
lhs: EinsumTerm[]; // Terms on the left-hand side of the equation, one per input term
181
rhs: EinsumTerm; // Term on the right-hand side of the equation
182
outputDims: number[]; // Output dimensions of the equation
183
} // End of class EinsumEquation
185
// Derives the name of the uniform that carries the loop upper bound for a symbol: `<name>_max`.
const appendMax = (name: string): string => `${name}_max`;
187
const createEinsumProgramInfo = (
188
inputShapes: Array<readonly number[]>,
190
einsumEquation: EinsumEquation,
191
outputShape: readonly number[],
193
const ranks = inputShapes.map((dims) => dims.length);
194
const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
195
const outputSize = ShapeUtil.size(outputShape);
196
const output = outputVariable('output', dataType, outputShape.length);
197
const uniformsSymbols = [...einsumEquation.symbolToInfo.keys()].filter(
198
(symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol),
200
const getShaderSource = (shaderHelper: ShaderHelper) => {
201
const idxCopy: string[] = [];
202
const initProd = 'var prod = 1.0;';
203
const initSum = 'var sum = 0.0;';
204
const updateSum = 'sum += prod;';
205
const reduceOpsSetIndices: string[] = [];
206
const reduceOpsLoopHeaders: string[] = [];
207
const reduceOpsLoopFooters: string[] = [];
208
const reduceOpCompute: string[] = [];
209
const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size;
210
einsumEquation.symbolToInfo.forEach((info, symbol) => {
211
if (einsumEquation.rhs.symbolToIndices.has(symbol)) {
212
const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0];
213
if (outputIndex !== undefined) {
214
einsumEquation.lhs.forEach((term, i) => {
215
if (info.inputIndices.includes(i)) {
216
const indices = term.symbolToIndices.get(symbol);
217
if (indices === undefined) {
218
throw new Error('Invalid symbol error');
220
indices.forEach((index) => {
222
`${inputVars[i].indicesSet(
225
output.indicesGet('outputIndices', outputIndex),
233
einsumEquation.lhs.forEach((term, i) => {
234
if (info.inputIndices.includes(i)) {
235
const indices = term.symbolToIndices.get(symbol);
236
if (indices === undefined) {
237
throw new Error('Invalid symbol error');
239
indices.forEach((index) => {
240
reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`);
242
reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`);
245
reduceOpsLoopHeaders.push(
246
`for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`,
248
reduceOpsLoopFooters.push('}');
251
const reduceOps = isReduceOpsWithoutLoop
254
`let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`,
259
...reduceOpsLoopHeaders,
260
...reduceOpsSetIndices,
264
...reduceOpsLoopFooters,
268
.registerUniforms(uniformsSymbols.map((symbol) => ({ name: `${appendMax(symbol)}`, type: 'u32' })))
269
.registerUniform('outputSize', 'u32')
270
.declareVariables(...inputVars, output)}
272
${shaderHelper.mainStart()}
273
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
274
var outputIndices = ${output.offsetToIndices('global_idx')};
275
${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
276
${reduceOps.join('\n')};
277
${output.setByOffset('global_idx', 'sum')};
282
shaderCache: { hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank') },
284
// The symbols in the uniformsSymbols array are guaranteed to exist in the einsumEquation.symbolToInfo map. The
285
// filter is added to make sure that dimValue is never 0.
286
const programUniformsInit: ProgramUniform[] = uniformsSymbols
287
.filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
288
.map((symbol) => ({ type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0 }));
289
programUniformsInit.push({ type: DataType.uint32, data: outputSize });
290
const programUniforms: ProgramUniform[] = inputShapes
291
.map((dims, _) => [...createTensorShapeVariables(dims)])
292
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
293
programUniforms.push(...createTensorShapeVariables(outputShape));
295
outputs: [{ dims: outputShape, dataType }],
296
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
304
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
305
const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
306
const outputShape = einsumEquation.outputDims;
307
const inputShapes = context.inputs.map((input, _) => input.dims);
308
context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
311
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
312
const equation = (attributes.equation as string).replace(/\s+/g, '');
313
return createAttributeWithCacheKey({ equation });