onnxruntime

Форк
0
314 строк · 12.9 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
9

10
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common';
11

12
/**
 * Attributes of the Einsum operator.
 */
export interface EinsumAttributes extends AttributeWithCacheKey {
  /** The Einsum equation string, e.g. "ij,jk->ik". */
  readonly equation: string;
}
15
// The equation attribute value is a string which consists of left hand side (LHS) and optionally right hand side (RHS)
16
// separated by '->'. Ex. "ij,jk -> ik" expresses matrix multiplication
17
//     "ij->ji" expresses matrix transpose
18
//      "ii->i" diagonal elements of a square matrix
19
// LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable.
20
// Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to
21
// 'Z' or '...' to represent arbitrary dimensions.
22

23
// A single symbol: one ASCII letter, or an ellipsis ('...').
const symbolPattern = String.raw`[a-zA-Z]|\.\.\.`;
// A term: one or more symbols.
const termPattern = `(${symbolPattern})+`;
// Matches only when the whole string, start to end, is a single term.
const termPatternOnly = `^${termPattern}$`;
// The left-hand side: one or more terms separated by commas.
const lhsPattern = `(${termPattern},)*${termPattern}`;
// Matches only when the whole string, start to end, is a left-hand side.
const lhsPatternOnly = `^${lhsPattern}$`;
28

29
// Bookkeeping kept for every symbol that appears in the equation.
interface SymbolInfo {
  count: number; // Number of times the symbol has been registered (one per occurrence in a term)
  inputIndices: number[]; // Index of the term of each occurrence (0, 1, 2, ... for inputs, -1 for the output)
  dimValue: number; // Size of the dimension the symbol corresponds to
}
34

35
// One term of an Einsum equation (a single input term, or the output term).
class EinsumTerm {
  // For every symbol of the term, the dimension positions at which it occurs.
  symbolToIndices: Map<string, number[]>;
  // Which operand this term describes: -1 for the output and 0, 1, 2, ... for inputs.
  inputIndex: number;

  constructor(inputIndex = -1) {
    this.symbolToIndices = new Map<string, number[]>();
    this.inputIndex = inputIndex;
  }

  // Record that `symbol` occurs at dimension position `index` of this term.
  addSymbol(symbol: string, index: number) {
    const positions = this.symbolToIndices.get(symbol);
    if (positions !== undefined) {
      // The array stored in the map is mutated in place, so no re-insertion is needed.
      positions.push(index);
      return;
    }
    this.symbolToIndices.set(symbol, [index]);
  }
}
55

56
// An Einsum equation parsed against the shapes of the actual inputs.
//
// Construction validates the equation, records every symbol with its dimension size, resolves the output
// term (explicitly given after '->', or derived in implicit mode) and computes the output dimensions.
class EinsumEquation {
  // inputs:   the input tensors; only their `dims` are read. There must be exactly one per LHS term.
  // equation: the equation string. It must not contain whitespace (the patterns reject it).
  // Throws an Error when the equation is malformed or does not fit the input shapes.
  constructor(
    inputs: readonly TensorView[],
    public readonly equation: string,
  ) {
    this.hasEllipsis = false;
    this.ellipsisDims = [];
    this.symbolToInfo = new Map<string, SymbolInfo>();
    this.lhs = [];
    this.outputDims = [];

    // Split on the first '->'. Remembering whether the arrow was present keeps an explicit empty output
    // ("ij->", a full reduction to a scalar) distinct from implicit mode ("ij").
    const arrowPos = equation.indexOf('->');
    const isExplicit = arrowPos !== -1;
    const lhs = isExplicit ? equation.slice(0, arrowPos) : equation;
    let rhs = isExplicit ? equation.slice(arrowPos + 2) : '';

    if (!lhs.match(RegExp(lhsPatternOnly))) {
      throw new Error('Invalid LHS term');
    }
    // Once the whole LHS matched, every comma-separated piece is a valid term by construction.
    const inputTerms = lhs.split(',');
    if (inputTerms.length !== inputs.length) {
      throw new Error(
        `Einsum equation has ${inputTerms.length} input term(s) but ${inputs.length} input tensor(s) were provided`,
      );
    }
    inputTerms.forEach((inputTerm, index) => {
      this.lhs.push(this.processTerm(inputTerm, true, inputs[index].dims.slice(), index));
    });

    if (!isExplicit) {
      // Implicit mode: the output consists of the ellipsis dimensions (if any) followed by the
      // alphabetically sorted letters that appear exactly once in the equation.
      const letters = [...this.symbolToInfo.entries()]
        .filter(([sym, info]) => info.count === 1 && /^[a-zA-Z]$/.test(sym))
        .map(([sym]) => sym)
        .sort();
      rhs = (this.hasEllipsis ? '...' : '') + letters.join('');
    } else if (rhs !== '' && !rhs.match(RegExp(termPatternOnly))) {
      throw new Error('Invalid RHS');
    }

    // Compute output dims
    const rhsSymbols = rhs.match(RegExp(symbolPattern, 'g')) ?? [];
    for (const symbol of rhsSymbols) {
      if (symbol === '...') {
        this.outputDims = this.outputDims.concat(this.ellipsisDims);
      } else {
        const info = this.symbolToInfo.get(symbol);
        if (info === undefined) {
          throw new Error('Invalid RHS symbol');
        }
        this.outputDims.push(info.dimValue);
      }
    }
    this.rhs = this.processTerm(rhs, false, this.outputDims);
  } // End of EinsumEquation constructor

  // Register one occurrence of `symbol` with dimension size `dimValue` in the term `inputIndex`.
  // Throws when the symbol was already seen with a different dimension size.
  addSymbol(symbol: string, dimValue: number, inputIndex: number) {
    const info = this.symbolToInfo.get(symbol);
    if (info === undefined) {
      this.symbolToInfo.set(symbol, { count: 1, dimValue, inputIndices: [inputIndex] });
      return;
    }
    if (info.dimValue !== dimValue) {
      throw new Error('Dimension mismatch');
    }
    info.count++;
    info.inputIndices.push(inputIndex);
  }

  // Process one input/output term.
  //
  // term:    the term text, e.g. "ij" or "...ij"; the empty string is allowed for the output only.
  // isInput: true for an LHS term, false for the output term.
  // dims:    shape of the tensor the term describes.
  // index:   input index of the term, -1 for the output.
  // Returns the EinsumTerm mapping each symbol to the dimension positions it occupies in `dims`.
  processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm {
    const rank = dims.length;
    // For output empty string is allowed because the output may be reduced to a scalar value
    if (!isInput && term !== '' && !term.match(RegExp(termPatternOnly))) {
      throw new Error('Invalid RHS');
    }
    const indexSymbols: string[] = term.match(RegExp(symbolPattern, 'g')) ?? [];
    // Without an ellipsis every dimension needs exactly one subscript.
    if (!indexSymbols.includes('...') && indexSymbols.length !== rank) {
      throw new Error(
        `Einsum term "${term}" has ${indexSymbols.length} subscript(s) but the corresponding tensor has rank ${rank}`,
      );
    }

    const einsumTerm = new EinsumTerm(index);
    let ellipsis = false; // Whether THIS term has already consumed its ellipsis
    let nextDim = 0; // Position in `dims` of the next dimension to be assigned to a symbol
    for (const symbol of indexSymbols) {
      if (symbol !== '...') {
        // A letter always maps to the next unassigned dimension of this term, regardless of whether
        // other terms of the equation contain an ellipsis.
        einsumTerm.addSymbol(symbol, nextDim);
        this.addSymbol(symbol, dims[nextDim], index);
        nextDim++;
        continue;
      }

      if (ellipsis) {
        throw new Error('Only one ellipsis is allowed per input term');
      }
      ellipsis = true;
      // The ellipsis covers every dimension that is not claimed by a letter.
      const ellipsisDimLength = rank - indexSymbols.length + 1;
      if (ellipsisDimLength < 0) {
        throw new Error('Ellipsis out of bounds');
      }
      const ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength);
      if (this.hasEllipsis) {
        if (
          this.ellipsisDims.length !== ellipsisDims.length ||
          this.ellipsisDims.toString() !== ellipsisDims.toString()
        ) {
          throw new Error('Ellipsis dimensions mismatch');
        }
      } else if (isInput) {
        this.hasEllipsis = true;
        this.ellipsisDims = ellipsisDims;
      } else {
        throw new Error('Ellipsis must be specified in the LHS');
      }
      // Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling
      for (let j = 0; j < ellipsisDims.length; j++) {
        const ellipsisSymbol = String.fromCharCode('0'.charCodeAt(0) + j);
        einsumTerm.addSymbol(ellipsisSymbol, nextDim);
        this.addSymbol(ellipsisSymbol, dims[nextDim], index);
        nextDim++;
      }
    }
    return einsumTerm;
  }

  symbolToInfo: Map<string, SymbolInfo>; // All symbols in the equation
  hasEllipsis: boolean; // The equation has ellipsis or not
  ellipsisDims: number[]; // The dimensions of the equation ellipsis corresponds to.
  lhs: EinsumTerm[]; // Terms on the left-hand side of the equation
  rhs: EinsumTerm; // Term on the right-hand side of the equation
  outputDims: number[]; // Output dimensions of the equation
} // End of class EinsumEquation
184

185
// Name of the uniform that holds the extent (loop upper bound) of the given symbol.
const appendMax = (name: string): string => `${name}_max`;
186

187
// Builds the program (shader source generator + run data) that evaluates a parsed Einsum equation.
//
// inputShapes:    shape of every input, in input order.
// dataType:       data type passed to every input/output variable.
// einsumEquation: the parsed equation (symbols, input terms and output term).
// outputShape:    shape of the output tensor.
//
// Every invocation computes one output element: symbols present in the output take their value from the
// output indices, every other symbol is summed over in a nested loop bounded by a `<symbol>_max` uniform.
const createEinsumProgramInfo = (
  inputShapes: Array<readonly number[]>,
  dataType: number,
  einsumEquation: EinsumEquation,
  outputShape: readonly number[],
): ProgramInfo => {
  const ranks = inputShapes.map((dims) => dims.length);
  const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
  const outputSize = ShapeUtil.size(outputShape);
  const output = outputVariable('output', dataType, outputShape.length);
  // Symbols that do not appear in the output are the ones being reduced; each needs a uniform for its extent.
  const uniformsSymbols = [...einsumEquation.symbolToInfo.keys()].filter(
    (symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol),
  );
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const idxCopy: string[] = []; // Copies of output indices into the input indices
    // NOTE(review): the accumulators are initialised from float literals regardless of `dataType` —
    // confirm this is valid for every data type this kernel is registered for.
    const initProd = 'var prod = 1.0;';
    const initSum = 'var sum = 0.0;';
    const updateSum = 'sum += prod;';
    const reduceOpsSetIndices: string[] = []; // Copies of loop variables into the input indices
    const reduceOpsLoopHeaders: string[] = [];
    const reduceOpsLoopFooters: string[] = [];
    const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size;
    einsumEquation.symbolToInfo.forEach((info, symbol) => {
      if (einsumEquation.rhs.symbolToIndices.has(symbol)) {
        // Output symbol: every input dimension labelled with it reads the matching output index.
        const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0];
        if (outputIndex !== undefined) {
          einsumEquation.lhs.forEach((term, i) => {
            if (info.inputIndices.includes(i)) {
              const indices = term.symbolToIndices.get(symbol);
              if (indices === undefined) {
                throw new Error('Invalid symbol error');
              }
              indices.forEach((index) => {
                idxCopy.push(
                  `${inputVars[i].indicesSet(
                    `input${i}Indices`,
                    index,
                    output.indicesGet('outputIndices', outputIndex),
                  )}`,
                );
              });
            }
          });
        }
      } else {
        // Reduced symbol: it becomes a loop variable that drives every input dimension labelled with it.
        einsumEquation.lhs.forEach((term, i) => {
          if (info.inputIndices.includes(i)) {
            const indices = term.symbolToIndices.get(symbol);
            if (indices === undefined) {
              throw new Error('Invalid symbol error');
            }
            indices.forEach((index) => {
              reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`);
            });
          }
        });
        reduceOpsLoopHeaders.push(
          `for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`,
        );
        reduceOpsLoopFooters.push('}');
      }
    });
    // Each input contributes exactly one factor to the product, no matter how many reduced symbols it
    // carries (possibly none). Emitting one factor per (reduced symbol, input) pair would multiply an
    // input several times or not at all.
    const reduceOpCompute = inputVars.map(
      (inputVar, i) => `prod *= ${inputVar.getByIndices(`input${i}Indices`)};`,
    );
    const reduceOps = isReduceOpsWithoutLoop
      ? [
          ...idxCopy,
          `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`,
        ]
      : [
          ...idxCopy,
          initSum,
          ...reduceOpsLoopHeaders,
          ...reduceOpsSetIndices,
          initProd,
          ...reduceOpCompute,
          updateSum,
          ...reduceOpsLoopFooters,
        ];
    return `
            ${shaderHelper
              .registerUniforms(uniformsSymbols.map((symbol) => ({ name: `${appendMax(symbol)}`, type: 'u32' })))
              .registerUniform('outputSize', 'u32')
              .declareVariables(...inputVars, output)}

            ${shaderHelper.mainStart()}
            ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
            var outputIndices = ${output.offsetToIndices('global_idx')};
            ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
            ${reduceOps.join('\n')};
            ${output.setByOffset('global_idx', 'sum')};
          }`;
  };
  return {
    name: 'Einsum',
    shaderCache: { hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank') },
    getRunData: () => {
      // Uniform values must follow the registration order used by the shader: one extent per reduced
      // symbol, then the output size, then the shape variables of every input and of the output.
      // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map;
      // the filter and the `?? 0` fallback only keep the value typed as a number.
      const programUniforms: ProgramUniform[] = uniformsSymbols
        .filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
        .map((symbol) => ({ type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue ?? 0 }));
      programUniforms.push({ type: DataType.uint32, data: outputSize });
      inputShapes.forEach((dims) => {
        programUniforms.push(...createTensorShapeVariables(dims));
      });
      programUniforms.push(...createTensorShapeVariables(outputShape));
      return {
        outputs: [{ dims: outputShape, dataType }],
        dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
        programUniforms,
      };
    },
    getShaderSource,
  };
};
303

304
/**
 * Runs the Einsum operator: parses the equation against the context's inputs, builds the program for it
 * and hands the program to `context.compute`.
 *
 * @param context compute context providing the input tensors and the `compute` entry point.
 * @param attributes operator attributes carrying the equation string.
 */
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
  const { inputs } = context;
  const parsedEquation = new EinsumEquation(inputs, attributes.equation);
  const shapes = inputs.map(({ dims }) => dims);
  // The data type of the first input is used for the whole program.
  const programInfo = createEinsumProgramInfo(shapes, inputs[0].dataType, parsedEquation, parsedEquation.outputDims);
  context.compute(programInfo);
};
310

311
/**
 * Builds the Einsum attributes from the raw attribute record.
 *
 * @param attributes raw attributes; `equation` is read as a string.
 * @returns attributes whose `equation` has all whitespace removed.
 */
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
  const rawEquation = attributes.equation as string;
  // Dropping every whitespace run lets "ij, jk -> ik" and "ij,jk->ik" be treated alike.
  const equation = rawEquation.split(/\s+/).join('');
  return createAttributeWithCacheKey({ equation });
};
315

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.