1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { GemmUtil, ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
11
createTensorShapeVariables,
19
/**
 * Checks the operator inputs and throws an Error as soon as one check fails.
 *
 * Checks shown below: the input count must be 2 or 3, an optional third input must have at most
 * 2 dimensions, and every input must have the same data type as the first input.
 *
 * @param inputs - Input tensors; the entry at index 2 is optional.
 * @throws Error with a message naming the failed check.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
21
throw new Error('Input is missing');
23
// Two required inputs plus one optional third input.
if (inputs.length < 2 || inputs.length > 3) {
24
// NOTE(review): "Invaid" is a typo for "Invalid"; left as is because the message is runtime text.
throw new Error('Invaid input number.');
27
// 'C' can be of dimensionality 0, 1 or 2 only
28
if (inputs.length === 3 && inputs[2].dims.length > 2) {
29
throw new Error('Invalid input shape of C');
32
// All inputs are compared against the data type of the first input.
if (inputs[0].dataType !== inputs[1].dataType || (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType)) {
33
throw new Error('Input types are mismatched');
37
/**
 * Attributes of the Gemm operator, extended with the cache key from AttributeWithCacheKey.
 * Code in this file reads `transA`, `transB`, `alpha` and `beta` from it.
 */
export interface GemmAttributes extends AttributeWithCacheKey {
44
/**
 * Builds the ProgramInfo (uniform values, shader source generator, output description and dispatch
 * size) for a Gemm over inputs[0] (a), inputs[1] (b) and an optional inputs[2] (c).
 *
 * The generated shader text derives a row `m` and a column `n` from `global_idx`, accumulates
 * products of `a` and `b` over `k` in [0, K), and writes the result to `output[global_idx]`.
 * The index expressions used for `a` and `b` depend on `attributes.transA` / `attributes.transB`.
 * A term read from `c` through a broadcast offset computed from (m, n) and multiplied by
 * `uniforms.beta` can be added to the value.
 *
 * @param inputs - Input tensors: a, b and optionally c.
 * @param attributes - Gemm attributes; `alpha` and `beta` are passed as uniforms, while `transA`,
 *   `transB` and `alpha === 1` change the generated shader text.
 * @returns The program description handed to `context.compute`.
 * @throws Error "Can't use gemm on the given tensors" (the guarding condition is not shown here).
 */
const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfo => {
45
const aShape = inputs[0].dims.slice();
46
const bShape = inputs[1].dims.slice();
47
// M, N and K come from GemmUtil.getShapeOfGemmResult; the output shape is [M, N].
const [M, N, K] = GemmUtil.getShapeOfGemmResult(
52
inputs.length === 3 ? inputs[2].dims : undefined,
54
const outputShape = [M, N];
56
throw new Error("Can't use gemm on the given tensors");
58
const outputSize = ShapeUtil.size(outputShape);
59
// Uniform values, listed in the same order as the uniform declarations in the shader source:
// output_size, M, N, K, alpha, beta.
const programUniforms: ProgramUniform[] = [
60
{ type: DataType.uint32, data: outputSize },
61
{ type: DataType.uint32, data: M },
62
{ type: DataType.uint32, data: N },
63
{ type: DataType.uint32, data: K },
64
{ type: DataType.float, data: attributes.alpha },
65
{ type: DataType.float, data: attributes.beta },
67
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
68
// The optional third input adds its shape variables to the uniforms and a 'rank' dependency.
if (inputs.length === 3) {
69
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
70
inputDependencies.push('rank');
72
programUniforms.push(...createTensorShapeVariables(outputShape));
74
const getShaderSource = (shaderHelper: ShaderHelper) => {
76
// Pick the multiply-accumulate statement whose a/b index expressions match the transpose flags.
if (attributes.transA && attributes.transB) {
77
line = 'value += a[k * uniforms.M + m] * b[n * uniforms.K + k];';
78
} else if (attributes.transA && !attributes.transB) {
79
line = 'value += a[k * uniforms.M + m] * b[k * uniforms.N + n];';
80
} else if (!attributes.transA && attributes.transB) {
81
line = 'value += a[m * uniforms.K + k] * b[n * uniforms.K + k];';
82
} else if (!attributes.transA && !attributes.transB) {
83
line = 'value += a[m * uniforms.K + k] * b[k * uniforms.N + n];';
86
// The alpha scaling statement is left out of the shader text when alpha is exactly 1.
const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= uniforms.alpha;';
87
const a = inputVariable('a', inputs[0].dataType, inputs[0].dims);
88
const b = inputVariable('b', inputs[1].dataType, inputs[1].dims);
89
const dataType = a.type.value;
90
let c: IndicesHelper | null = null;
91
const variables = [a, b];
92
// c is declared from its rank (dims.length), whereas a and b are declared from their dims.
if (inputs.length === 3) {
93
c = inputVariable('c', inputs[2].dataType, inputs[2].dims.length);
96
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
97
variables.push(output);
98
// Uniform declarations, in the same order as the programUniforms values.
const uniforms: UniformsArrayType = [
99
{ name: 'output_size', type: 'u32' },
100
{ name: 'M', type: 'u32' },
101
{ name: 'N', type: 'u32' },
102
{ name: 'K', type: 'u32' },
103
{ name: 'alpha', type: 'f32' },
104
{ name: 'beta', type: 'f32' },
107
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
109
${shaderHelper.mainStart()}
110
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
112
let m = global_idx / uniforms.N;
113
let n = global_idx % uniforms.N;
115
var value = ${dataType}(0);
116
for (var k: u32 = 0u; k < uniforms.K; k++) {
123
return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${
125
}(uniforms.beta) * ${c.getByOffset('cOffset')};`;
129
output[global_idx] = value;
135
shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies },
137
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
138
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
145
/**
 * Converts a raw attribute record into GemmAttributes.
 *
 * The four values are read with type assertions only; no runtime validation is done here.
 * The cache key is built from `transA`, `transB` and whether `alpha` equals 1. The numeric
 * values of `alpha` and `beta` are not part of the key.
 *
 * @param attributes - Raw attribute record holding `transA`, `transB`, `alpha` and `beta`.
 * @returns The attributes together with their cache key.
 */
export const parseGemmAttributes = (attributes: Record<string, unknown>): GemmAttributes => {
146
const transA = attributes.transA as boolean;
147
const transB = attributes.transB as boolean;
148
const alpha = attributes.alpha as number;
149
const beta = attributes.beta as number;
155
cacheKey: `${attributes.transA};${attributes.transB};${attributes.alpha === 1}`,
159
/**
 * Operator entry point: validates `context.inputs`, then passes the program built by
 * `createGemmProgramInfo` to `context.compute`.
 *
 * @param context - Compute context supplying the inputs and the `compute` call.
 * @param attributes - Gemm attributes used to build the program.
 * @throws Error when `validateInputs` or `createGemmProgramInfo` rejects the inputs.
 */
export const gemm = (context: ComputeContext, attributes: GemmAttributes): void => {
160
validateInputs(context.inputs);
161
context.compute(createGemmProgramInfo(context.inputs, attributes));