1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { GemmUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
12
/**
 * Attribute bag for the Gemm operator. It extends AttributeWithCacheKey, so every
 * instance also carries the members of that base interface.
 */
export interface GemmAttributes extends AttributeWithCacheKey {
17
// Set by the attribute parser rather than read from the graph node; when true the
// third input (C) may be left out.
isOptionalC: boolean; // in opset 11, C becomes optional
20
/**
 * Gemm operator implementation that runs through a WebGLInferenceHandler.
 * The inputs are validated first; the program described by
 * createGemmProgramInfoLoader is then executed on them by the handler.
 */
export const gemm: OperatorImplementation<GemmAttributes> = (
21
inferenceHandler: WebGLInferenceHandler,
23
attributes: GemmAttributes,
25
// Throws before any program is built or run if the inputs are unacceptable.
validateInputs(inputs, attributes);
26
// The same `inputs` array is used both to build the loader and as the run arguments.
const output = inferenceHandler.run(createGemmProgramInfoLoader(inputs, attributes), inputs);
30
// Reads the Gemm attributes (transA, transB, alpha, beta) from a graph node and bundles
// them, together with the caller-supplied isOptionalC flag, via createAttributeWithCacheKey.
const parseGemmAttributes = (node: Graph.Node, isOptionalC: boolean): GemmAttributes => {
31
// The transpose flags are read as integers and turned into booleans: any non-zero value
// counts as true. The second argument (0) is presumably the fallback used when the
// attribute is absent - confirm against getInt.
const transA = node.attributes.getInt('transA', 0) !== 0;
32
const transB = node.attributes.getInt('transB', 0) !== 0;
33
// Scalar factors; 1.0 is passed as the second argument (presumably the default - confirm
// against getFloat).
const alpha = node.attributes.getFloat('alpha', 1.0);
34
const beta = node.attributes.getFloat('beta', 1.0);
35
// isOptionalC is included in the object handed to createAttributeWithCacheKey alongside
// the four values read from the node.
return createAttributeWithCacheKey({ transA, transB, alpha, beta, isOptionalC });
38
/**
 * Attribute initializer for the "V7" flavor of Gemm.
 * Delegates to parseGemmAttributes with isOptionalC fixed to false.
 */
export const parseGemmAttributesV7: OperatorInitialization<GemmAttributes> = (node: Graph.Node): GemmAttributes => {
  const inputCIsOptional = false;
  return parseGemmAttributes(node, inputCIsOptional);
};
41
/**
 * Attribute initializer for the "V11" flavor of Gemm.
 * Delegates to parseGemmAttributes with isOptionalC fixed to true.
 */
export const parseGemmAttributesV11: OperatorInitialization<GemmAttributes> = (node: Graph.Node): GemmAttributes => {
  const inputCIsOptional = true;
  return parseGemmAttributes(node, inputCIsOptional);
};
44
// Builds the ProgramInfoLoader for Gemm: program metadata plus a `get` callback that
// produces the full program info via createGemmProgramInfo when invoked.
const createGemmProgramInfoLoader = (inputs: Tensor[], attributes: GemmAttributes): ProgramInfoLoader => {
47
// 'C' is declared as an input name only when a third tensor was actually supplied.
inputNames: inputs.length === 3 ? ['A', 'B', 'C'] : ['A', 'B'],
50
// One unpacked texture type per declared input (three or two entries).
? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked]
51
: [TextureType.unpacked, TextureType.unpacked],
52
// The attributes' cache key is used as the metadata key.
key: attributes.cacheKey,
55
// The metadata fields are copied into the loader; createGemmProgramInfo runs only when
// `get` is called.
return { ...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes) };
58
// Produces the program info for Gemm: computes the [M, N] output shape, picks the
// accumulation statement that matches the transA/transB combination, and assembles the
// shader source together with the output description and the alpha/beta variables.
const createGemmProgramInfo = (
59
metadata: ProgramMetadata,
61
attributes: GemmAttributes,
63
// Work on copies of the dims of A and B.
const aShape = inputs[0].dims.slice();
64
const bShape = inputs[1].dims.slice();
65
// Output shape comes from GemmUtil.getShapeOfGemmResult; the dims of C are passed only
// when a third input exists, otherwise undefined.
const [M, N] = GemmUtil.getShapeOfGemmResult(
70
inputs.length === 3 ? inputs[2].dims : undefined,
72
const outputShape = [M, N];
74
throw new Error("Can't use gemm on the given tensors");
76
// Length of the summation loop in the shader (`k < sharedDim`): the last dim of A, or
// its first dim when transA is set.
let sharedDim = aShape[aShape.length - 1];
78
if (attributes.transA) {
79
sharedDim = aShape[0];
81
// Pick the accumulate statement for each of the four transA/transB combinations.
// _A_T / _B_T are presumably generated accessors that read the transposed operand -
// TODO confirm where they are defined.
if (attributes.transA && attributes.transB) {
82
line = 'value += _A_T(a) * _B_T(b);';
83
} else if (attributes.transA && !attributes.transB) {
84
line = 'value += _A_T(a) * _B(b);';
85
} else if (!attributes.transA && attributes.transB) {
86
line = 'value += _A(a) * _B_T(b);';
87
} else if (!attributes.transA && !attributes.transB) {
88
line = 'value += _A(a) * _B(b);';
90
const rank = outputShape.length;
91
// The three C-related shader snippets are empty strings unless a third input was given;
// the index array `c` is sized by the rank of C.
const declareC = inputs.length === 3 ? `int c[${inputs[2].dims.length}];` : '';
92
const broadcastC = inputs.length === 3 ? 'bcastIndices_C(indices, c);' : '';
93
const calculateC = inputs.length === 3 ? 'value += beta * _C(c);' : '';
94
const shaderSource = `
95
float process(int indices[${rank}]) {
105
for (int k=0; k<${sharedDim}; ++k) {
111
value = value * alpha;
117
output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
119
{ name: 'alpha', type: 'float', data: attributes.alpha },
120
{ name: 'beta', type: 'float', data: attributes.beta },
126
// Checks the Gemm inputs and throws an Error on the first violation: missing inputs,
// wrong input count for the given isOptionalC setting, a C tensor whose rank is not
// 1 or 2, non-float element types, or element types that differ between inputs.
const validateInputs = (inputs: Tensor[], attributes: GemmAttributes): void => {
128
throw new Error('Input is missing');
130
// With optional C, two or three inputs are accepted.
// NOTE(review): the message below is misspelled ('Invaid') and talks about "shape"
// although the check is on the input count; left untouched because it is runtime text.
if (attributes.isOptionalC && (inputs.length < 2 || inputs.length > 3)) {
131
throw new Error('Invaid input shape.');
133
// Without optional C, exactly three inputs are required.
if (!attributes.isOptionalC && inputs.length !== 3) {
134
throw new Error('Gemm requires 3 inputs');
137
// 'C' can be of dimensionality 1 or 2 only
138
if (inputs.length === 3 && inputs[2].dims.length !== 1 && inputs[2].dims.length !== 2) {
139
throw new Error('Invalid input shape of C');
143
// Every supplied input must be float32 or float64; C is checked only when present.
(inputs[0].type !== 'float32' && inputs[0].type !== 'float64') ||
144
(inputs[1].type !== 'float32' && inputs[1].type !== 'float64') ||
145
(inputs.length === 3 && inputs[2].type !== 'float32' && inputs[2].type !== 'float64')
147
throw new Error('Invalid input type.');
150
// All supplied inputs must share the element type of A.
if (inputs[0].type !== inputs[1].type || (inputs.length === 3 && inputs[0].type !== inputs[2].type)) {
151
throw new Error('Input types are mismatched');