1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { Graph } from '../../../graph';
5
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6
import { Tensor } from '../../../tensor';
7
import { BroadcastUtil, ShapeUtil } from '../../../util';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
10
import { getCoordsDataType, getGlChannels } from '../utils';
12
import { getActivationSnippet, InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils';
13
import { createPackedMatmulProgramInfoLoader } from './matmul-pack';
15
/**
 * WebGL implementation of the MatMul operator.
 *
 * Validates the inputs and then runs a single matmul program: the packed
 * variant when `inferenceHandler.session.pack` is truthy, otherwise the
 * unpacked variant built by `createMatmulProgramInfoLoader`.
 *
 * NOTE(review): this definition was restored from a damaged extraction (stray
 * line-number lines, dropped lines). The `inputs` parameter and the
 * `else` branch structure are inferred from how the visible body uses them —
 * confirm against the upstream file.
 *
 * @param inferenceHandler handler used to execute the WebGL program
 * @param inputs the operand tensors, checked by `validateInputs`
 * @param attributes fused-activation attributes forwarded to the program loader
 * @returns a one-element array holding the result of `inferenceHandler.run`
 * @throws Error when `validateInputs` rejects the inputs
 */
export const matMul: OperatorImplementation<InternalActivationAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: InternalActivationAttributes,
) => {
  validateInputs(inputs);

  if (inferenceHandler.session.pack) {
    return [inferenceHandler.run(createPackedMatmulProgramInfoLoader(inferenceHandler, inputs, attributes), inputs)];
  } else {
    return [inferenceHandler.run(createMatmulProgramInfoLoader(inputs, attributes), inputs)];
  }
};
29
/**
 * Builds the MatMul operator's attributes from a graph node by delegating to
 * `parseInternalActivationAttributes` with the node's `attributes`.
 *
 * NOTE(review): the parameter line was missing from the damaged extraction.
 * `node` is required by the visible body; its `Graph.Node` type is inferred
 * from the otherwise-unused `Graph` import — confirm against the upstream file.
 *
 * @param node graph node whose `attributes` are parsed
 * @returns the parsed internal activation attributes
 */
export const parseMatMulAttributes: OperatorInitialization<InternalActivationAttributes> = (
  node: Graph.Node,
): InternalActivationAttributes => parseInternalActivationAttributes(node.attributes);
33
/**
 * Builds the program metadata for the unpacked matmul shader.
 *
 * With a bias there are three inputs (`A`, `B`, `Bias`), otherwise two
 * (`A`, `B`); every input uses an unpacked texture.
 *
 * NOTE(review): restored from a damaged extraction. The `inputTypes` key and
 * the trailing `cacheHint` member are inferred (the ternary needs a key and
 * `cacheHint` is otherwise unused). The `name` value is NOT visible in the
 * damaged text and is a best guess — confirm against the upstream file.
 *
 * @param hasBias whether a third (bias) input is present
 * @param cacheHint cache key fragment stored on the metadata
 */
const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
  name: 'MatMul',
  inputNames: hasBias ? ['A', 'B', 'Bias'] : ['A', 'B'],
  inputTypes: hasBias
    ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked]
    : [TextureType.unpacked, TextureType.unpacked],
  cacheHint,
});
42
/**
 * Builds the unpacked matmul program: output shape plus GLSL shader source.
 *
 * The output shape comes from `BroadcastUtil.calcShape(aShape, bShape, true)`.
 * The shader walks the shared dimension (the last dimension of A), accumulating
 * `_A(a) * _B(b)`, then optionally adds the bias and applies the fused
 * activation snippet.
 *
 * NOTE(review): restored from a damaged extraction. The visible lines are kept
 * verbatim; the following lines were missing and are inferred from names that
 * are declared but otherwise unused, or used but otherwise undeclared:
 *   - the `inputs` parameter and the `ProgramInfo` return type,
 *   - the `if (!outputShape)` guard around the throw,
 *   - the `''` alternative of the bias snippet,
 *   - in the shader: `${activationFunction}`, the `a`/`b` index arrays, the
 *     accumulator declaration, the two `= k` index assignments,
 *     `${processBias}`, `${applyActivation}` and `return value;`,
 *   - `...metadata` and `shaderSource` in the returned object.
 * Confirm against the upstream file before merging.
 *
 * @param metadata program metadata spread into the returned program info
 * @param inputs operand tensors; a third entry, if present, is the bias
 * @param activationAttributes attributes passed to `getActivationSnippet`
 * @returns the program info for the unpacked matmul shader
 * @throws Error when the two input shapes cannot be combined by `calcShape`
 */
function createMatmulProgramInfo(
  metadata: ProgramMetadata,
  inputs: Tensor[],
  activationAttributes: InternalActivationAttributes,
): ProgramInfo {
  const aShape = inputs[0].dims;
  const bShape = inputs[1].dims;
  const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
  if (!outputShape) {
    throw new Error("Can't use matmul on the given tensors");
  }
  const coordsDataType = getCoordsDataType(outputShape.length);
  const allGlChannels = getGlChannels();
  const { activationFunction, applyActivation } = getActivationSnippet(activationAttributes);

  const hasBias = inputs.length > 2;
  const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
  const getBiasForMatmulSnippet = hasBias
    ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, false)}`
    : '';

  const rank = outputShape.length;
  const arank = aShape.length;
  const brank = bShape.length;
  const sharedDim = aShape[aShape.length - 1];
  // The accumulator is explicitly zero-initialised: an uninitialised GLSL local
  // has an undefined value, which would make the sum driver-dependent.
  const shaderSource = `
    ${activationFunction}
    ${getBiasForMatmulSnippet}
    float process(int indices[${rank}]) {
        int a[${arank}];
        int b[${brank}];
        bcastMatmulIndices_A(indices, a);
        bcastMatmulIndices_B(indices, b);

        float value = 0.0;
        for (int k=0; k<${sharedDim}; ++k) {
            a[${arank - 1}] = k;
            b[${brank - 2}] = k;
            value += _A(a) * _B(b);
        }
        ${processBias}
        ${applyActivation}
        return value;
    }`;
  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
}
93
/**
 * Creates the lazy loader for the unpacked matmul program.
 *
 * The metadata is built eagerly (a bias is assumed when more than two inputs
 * are supplied; the cache hint is `activationAttributes.activationCacheKey`).
 * The program info itself is only built when the loader's `get` is invoked.
 *
 * NOTE(review): restored from a damaged extraction. The `inputs` parameter is
 * required by the visible body; the `ProgramInfoLoader` return type is inferred
 * from the otherwise-unused import — confirm against the upstream file.
 *
 * @param inputs operand tensors; a third entry, if present, is the bias
 * @param activationAttributes fused-activation attributes
 * @returns the metadata extended with a `get` callback producing the program info
 */
export function createMatmulProgramInfoLoader(
  inputs: Tensor[],
  activationAttributes: InternalActivationAttributes,
): ProgramInfoLoader {
  const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
  return { ...metadata, get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes) };
}
101
/**
 * Checks that the inputs are acceptable for MatMul and throws otherwise.
 *
 * Checks, in order:
 *   1. exactly two inputs are supplied;
 *   2. the last dimension of the first input equals the second-to-last
 *      dimension of the second input;
 *   3. both inputs are `float32` or `float64`;
 *   4. both inputs have the same type.
 *
 * NOTE(review): restored from a damaged extraction — the `if (` opener of the
 * float-type check and the closing braces were missing. All conditions and
 * error messages are the visible ones, unchanged.
 *
 * @param inputs the operand tensors to validate
 * @throws Error with a message naming the first failed check
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('MatMul requires 2 inputs.');
  }

  if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) {
    throw new Error('shared dimension does not match.');
  }

  if (
    (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') ||
    (inputs[1].type !== 'float32' && inputs[1].type !== 'float64')
  ) {
    throw new Error('inputs should be float type');
  }

  if (inputs[0].type !== inputs[1].type) {
    throw new Error('inputs types should match');
  }
};
122
/**
 * Generates the GLSL source of `getBiasForMatmul()`, which reads the bias value
 * for the current output coordinates.
 *
 * Two variants are produced:
 *   - packed: returns a `vec4` built from `getBias(...)`; a single-element bias
 *     is splatted as `vec4(outputValue.x)`, otherwise as
 *     `vec4(outputValue.xx, outputValue.yy)`;
 *   - unpacked: returns the `float` from `getBias(coords.x)`.
 * Coordinates along the dimensions reported by
 * `BroadcastUtil.getBroadcastDims` are reset to 0 before the lookup.
 *
 * NOTE(review): restored from a damaged extraction. Missing lines inferred from
 * names that are declared but otherwise unused, or used but undeclared: the
 * `isPacked` parameter, the `string` return type, the `else` and
 * `if (isInputScalar)` branches, the template delimiters, the
 * `${coordsSnippet}` line in both variants and `return ${output};` in the
 * packed variant. Confirm against the upstream file.
 *
 * @param coordsDataType GLSL type used to declare the output coordinates
 * @param allGlChannels channel names used to address coordinate components
 * @param inShape shape of the bias tensor
 * @param outShape shape of the matmul output
 * @param isPacked whether to emit the packed (`vec4`) variant
 * @returns GLSL source defining `getBiasForMatmul()`
 */
export function getBiasForMatmul(
  coordsDataType: string,
  allGlChannels: readonly string[],
  inShape: readonly number[],
  outShape: readonly number[],
  isPacked: boolean,
): string {
  const inRank = inShape.length;
  const outRank = outShape.length;
  // Offset that right-aligns the bias dimensions against the output dimensions.
  const rankDiff = outRank - inRank;

  const unpackedCoordsSnippet =
    outRank < 2 && inRank > 0
      ? 'coords'
      : inShape.map((_s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', ');

  const broadcastDims = BroadcastUtil.getBroadcastDims(inShape, outShape);
  const coordsSnippet = broadcastDims.map((d) => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n');

  const isInputScalar = ShapeUtil.size(inShape) === 1;
  const output = isInputScalar ? 'vec4(outputValue.x)' : 'vec4(outputValue.xx, outputValue.yy)';

  return isPacked
    ? `
vec4 getBiasForMatmul() {
  ${coordsDataType} coords = getOutputCoords();
  ${coordsSnippet}
  vec4 outputValue = getBias(${unpackedCoordsSnippet});
  return ${output};
}`
    : `
float getBiasForMatmul() {
  ${coordsDataType} coords = getOutputCoords();
  ${coordsSnippet}
  return getBias(coords.x);
}`;
}