1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { Tensor } from '../../../tensor';
5
import { BroadcastUtil, ShapeUtil } from '../../../util';
6
import { getGlsl } from '../glsl-source';
7
import { WebGLInferenceHandler } from '../inference-handler';
8
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
9
import { getCoordsDataType, getGlChannels } from '../utils';
11
import { getActivationSnippet, InternalActivationAttributes } from './fuse-utils';
12
import { getBiasForMatmul } from './matmul';
14
/**
 * Builds the static program metadata for the packed MatMul shader.
 *
 * @param hasBias - whether a third `Bias` input is bound in addition to `A` and `B`.
 * @param cacheHint - value stored verbatim on the metadata as `cacheHint`.
 * @returns metadata listing the input names and declaring every input as a packed texture.
 */
const createPackedMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
  name: 'MatMul (packed)',
  inputNames: hasBias ? ['A', 'B', 'Bias'] : ['A', 'B'],
  // One texture type per entry in `inputNames`; every input is consumed in packed layout.
  inputTypes: hasBias
    ? [TextureType.packed, TextureType.packed, TextureType.packed]
    : [TextureType.packed, TextureType.packed],
  cacheHint,
});
23
/**
 * Generates the program info (GLSL source plus output description) for a packed MatMul.
 *
 * The generated shader walks `ceil(K / 2)` texels along the shared dimension `K`
 * (the last dimension of `A`) and accumulates swizzled products of the `A` and `B`
 * texels into a single `vec4`. When the two input shapes differ, the broadcast-aware
 * samplers produced by `getBcastSamplerForMatmul` are used; otherwise `getA` / `getB`
 * are sampled directly from the (swapped) output coordinates.
 *
 * @param inferenceHandler - supplies the GL context version used to pick the GLSL dialect.
 * @param metadata - program metadata spread into the returned program info.
 * @param inputs - `[A, B]` or `[A, B, Bias]`.
 * @param activationAttributes - fused activation to apply to the accumulated value.
 * @returns the program info with a packed output of the broadcast MatMul shape.
 * @throws Error when `BroadcastUtil.calcShape` yields no output shape for `A` and `B`.
 */
const createPackedMatmulProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  inputs: Tensor[],
  activationAttributes: InternalActivationAttributes,
): ProgramInfo => {
  const hasBias = inputs.length > 2;
  const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
  const aShape = inputs[0].dims;
  const bShape = inputs[1].dims;

  // Validate before anything reads `outputShape.length`.
  const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
  if (!outputShape) {
    throw new Error("Can't use matmul on the given tensors");
  }
  const isBroadcast = !ShapeUtil.areEqual(aShape, bShape);

  const sharedDim = aShape[aShape.length - 1];
  // The loop advances two elements of the shared dimension per iteration (see `i*2` in the samplers).
  const sharedDimIndex = Math.ceil(sharedDim / 2);
  const aRank = aShape.length;
  const bRank = bShape.length;

  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
  const coordsDataType = getCoordsDataType(outputShape.length);
  const outRank = outputShape.length;
  const allGlChannels = getGlChannels();
  const { activationFunction, applyActivation } = getActivationSnippet(activationAttributes);

  const getBiasForMatmulSnippet = hasBias
    ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, true)}`
    : '';

  const getBcastedSamplerForMatmulSnippet = isBroadcast
    ? `${getBcastSamplerForMatmul(coordsDataType, allGlChannels, inputs, outputShape)}`
    : '';

  const getSamplerAInLoopSnippet = isBroadcast ? 'getAAtOutCoordsMatmul(i)' : `getA(${getA(allGlChannels, aRank)})`;
  const getSamplerBInLoopSnippet = isBroadcast ? 'getBAtOutCoordsMatmul(i)' : `getB(${getB(allGlChannels, bRank)})`;

  // In the non-broadcast path `rc` is declared here with its last two components swapped;
  // the broadcast samplers fetch and swap the output coordinates themselves.
  const getOutputCoordsSnippet = isBroadcast
    ? ''
    : `${coordsDataType} rc =
          getOutputCoords(); int lastDim = rc.${allGlChannels[outRank - 1]}; rc.${allGlChannels[outRank - 1]} =
          rc.${allGlChannels[outRank - 2]}; rc.${allGlChannels[outRank - 2]} = lastDim;
      `;

  // NOTE(review): the `void main()` wrapper, the `vec4 value` declaration and the placement of
  // `processBias` / `applyActivation` were restored around the surviving loop body — confirm
  // against the upstream shader.
  const shaderSource = `
            ${getBcastedSamplerForMatmulSnippet}
            ${getBiasForMatmulSnippet}
            ${activationFunction}
            void main() {
              ${getOutputCoordsSnippet}

              vec4 value = vec4(0);
              for (int i = 0; i < ${sharedDimIndex}; i++) {
                vec4 a = ${getSamplerAInLoopSnippet};
                vec4 b = ${getSamplerBInLoopSnippet};

                value += (a.rrbb * b.rgrg);
                value += (a.ggaa * b.baba);
              }
              ${processBias}
              ${applyActivation}
              ${glsl.output} = value;
            }`;

  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed },
    shaderSource,
    // NOTE(review): presumably required because the shader defines its own main() —
    // confirm against the ProgramInfo declaration in '../types'.
    hasMain: true,
  };
};
93
/**
 * Creates the lazy program-info loader for the packed MatMul kernel.
 *
 * The metadata is computed eagerly (a bias is assumed when more than two inputs are given,
 * and the activation cache key is used as the cache hint); the shader itself is only
 * generated when `get()` is invoked.
 *
 * @param inferenceHandler - handler forwarded to the program-info factory.
 * @param inputs - `[A, B]` or `[A, B, Bias]`.
 * @param activationAttributes - fused activation attributes; `activationCacheKey` feeds the cache hint.
 * @returns the metadata extended with a `get` factory for the full program info.
 */
export const createPackedMatmulProgramInfoLoader = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  activationAttributes: InternalActivationAttributes,
): ProgramInfoLoader => {
  const metadata = createPackedMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
  return {
    ...metadata,
    get: () => createPackedMatmulProgramInfo(inferenceHandler, metadata, inputs, activationAttributes),
  };
};
105
/**
 * Emits the GLSL helpers `getAAtOutCoordsMatmul(int i)` and `getBAtOutCoordsMatmul(int i)`
 * used when the two MatMul inputs have different shapes.
 *
 * Each helper starts from the output coordinates, swaps the last two components, zeroes the
 * components that `BroadcastUtil.getBroadcastDims` reports as broadcast for that input, and
 * samples the input with the shared-dimension coordinate replaced by `i*2`.
 *
 * @param coordsDataType - GLSL type used to hold the output coordinates.
 * @param allGlChannels - component names used to index the coordinate vector.
 * @param inputs - the MatMul inputs; only `inputs[0].dims` and `inputs[1].dims` are read.
 * @param outShape - the broadcast output shape.
 * @returns GLSL source defining both sampler helpers.
 */
function getBcastSamplerForMatmul(
  coordsDataType: string,
  allGlChannels: readonly string[],
  inputs: Tensor[],
  outShape: readonly number[],
): string {
  const inAShape = inputs[0].dims;
  const inBShape = inputs[1].dims;

  const inARank = inAShape.length;
  const inBRank = inBShape.length;

  const outRank = outShape.length;
  // Inputs are right-aligned against the output, so lower-rank inputs skip leading components.
  const rankADiff = outRank - inARank;
  const rankBDiff = outRank - inBRank;

  // A is sampled along its last dimension, B along its second-to-last one.
  // The joined string is interpolated below (previously the join() result was discarded and
  // the array was stringified implicitly).
  const unpackedACoords = inAShape.map((_s, i) => `coords.${allGlChannels[i + rankADiff]}`);
  unpackedACoords[inARank - 1] = 'i*2';
  const unpackedACoordsSnippet = unpackedACoords.join(', ');

  const unpackedBCoords = inBShape.map((_s, i) => `coords.${allGlChannels[i + rankBDiff]}`);
  unpackedBCoords[inBRank - 2] = 'i*2';
  const unpackedBCoordsSnippet = unpackedBCoords.join(', ');

  const broadcastADims = BroadcastUtil.getBroadcastDims(inAShape, outShape);
  const broadcastBDims = BroadcastUtil.getBroadcastDims(inBShape, outShape);

  const coordsASnippet = broadcastADims.map((d) => `coords.${allGlChannels[d + rankADiff]} = 0;`).join('\n');
  const coordsBSnippet = broadcastBDims.map((d) => `coords.${allGlChannels[d + rankBDiff]} = 0;`).join('\n');
  const swapDimSnippet = `int lastDim = coords.${allGlChannels[outRank - 1]};
  coords.${allGlChannels[outRank - 1]} = coords.${allGlChannels[outRank - 2]};
  coords.${allGlChannels[outRank - 2]} = lastDim;`;

  // NOTE(review): the swap-then-zero statement order and the `return outputValue;` lines were
  // restored around the surviving shader lines — confirm against the upstream shader.
  const getBcastSamplerMatmulSource = `
vec4 getAAtOutCoordsMatmul(int i) {
  ${coordsDataType} coords = getOutputCoords();
  ${swapDimSnippet}
  ${coordsASnippet}
  vec4 outputValue = getA(${unpackedACoordsSnippet});
  return outputValue;
}

vec4 getBAtOutCoordsMatmul(int i) {
  ${coordsDataType} coords = getOutputCoords();
  ${swapDimSnippet}
  ${coordsBSnippet}
  vec4 outputValue = getB(${unpackedBCoordsSnippet});
  return outputValue;
}`;

  return getBcastSamplerMatmulSource;
}
160
/**
 * Builds the argument list for sampling `A` inside the shader loop in the non-broadcast path.
 *
 * The leading `rank - 2` components and the second-to-last component are taken from `rc`;
 * the last coordinate is the loop-driven `i*2`.
 *
 * @param allGlChannels - component names used to index `rc`.
 * @param rank - rank of `A`.
 * @returns a comma-separated GLSL argument list, e.g. `rc.x, rc.y, i*2` for rank 3.
 */
function getA(allGlChannels: readonly string[], rank: number): string {
  let res = '';
  for (let i = 0; i < rank - 2; i++) {
    res += `rc.${allGlChannels[i]}, `;
  }
  res += `rc.${allGlChannels[rank - 2]}, ` + 'i*2';
  return res;
}
169
/**
 * Builds the argument list for sampling `B` inside the shader loop in the non-broadcast path.
 *
 * The leading `rank - 2` components and the last component are taken from `rc`;
 * the second-to-last coordinate is the loop-driven `i*2`.
 *
 * @param allGlChannels - component names used to index `rc`.
 * @param rank - rank of `B`.
 * @returns a comma-separated GLSL argument list, e.g. `rc.x, i*2, rc.z` for rank 3.
 */
function getB(allGlChannels: readonly string[], rank: number): string {
  let res = '';
  for (let i = 0; i < rank - 2; i++) {
    res += `rc.${allGlChannels[i]}, `;
  }
  res += 'i*2, ' + `rc.${allGlChannels[rank - 1]}`;
  return res;
}