onnxruntime

Форк
0
162 строки · 5.5 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Graph } from '../../../graph';
5
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6
import { Tensor } from '../../../tensor';
7
import { BroadcastUtil, ShapeUtil } from '../../../util';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
10
import { getCoordsDataType, getGlChannels } from '../utils';
11

12
import { getActivationSnippet, InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils';
13
import { createPackedMatmulProgramInfoLoader } from './matmul-pack';
14

15
/**
 * WebGL implementation of the MatMul operator.
 *
 * The inputs are validated first; afterwards either the packed or the unpacked program loader is
 * selected according to `inferenceHandler.session.pack` and executed.
 *
 * @param inferenceHandler handler used to run the generated program.
 * @param inputs operand tensors (checked by `validateInputs`).
 * @param attributes fused-activation attributes forwarded to the program loader.
 * @returns a one-element array containing the tensor produced by `inferenceHandler.run`.
 * @throws Error if `validateInputs` rejects the inputs.
 */
export const matMul: OperatorImplementation<InternalActivationAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: InternalActivationAttributes,
): Tensor[] => {
  validateInputs(inputs);

  // Pick the loader once, then share a single run call between both paths.
  const programInfoLoader = inferenceHandler.session.pack
    ? createPackedMatmulProgramInfoLoader(inferenceHandler, inputs, attributes)
    : createMatmulProgramInfoLoader(inputs, attributes);

  return [inferenceHandler.run(programInfoLoader, inputs)];
};
28

29
/**
 * Operator initialization for MatMul: derives the internal activation attributes from the
 * attributes of the given graph node.
 *
 * @param node graph node whose `attributes` are parsed.
 * @returns the result of `parseInternalActivationAttributes` for that node.
 */
export const parseMatMulAttributes: OperatorInitialization<InternalActivationAttributes> = (
  node: Graph.Node,
): InternalActivationAttributes => {
  const { attributes } = node;
  return parseInternalActivationAttributes(attributes);
};
32

33
/**
 * Creates the metadata object for the unpacked MatMul program.
 *
 * @param hasBias when true, a third input named 'Bias' is declared after 'A' and 'B'.
 * @param cacheHint value stored verbatim in the `cacheHint` field.
 * @returns an object with `name`, `inputNames`, `inputTypes` (one unpacked texture type per
 *   input name) and `cacheHint`.
 */
const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => {
  const inputNames: string[] = ['A', 'B'];
  if (hasBias) {
    inputNames.push('Bias');
  }
  // Every declared input is read from an unpacked texture.
  const inputTypes: TextureType[] = inputNames.map(() => TextureType.unpacked);

  return {
    name: 'MatMul',
    inputNames,
    inputTypes,
    cacheHint,
  };
};
41

42
/**
 * Builds the program info (output description + GLSL source) for the unpacked MatMul shader.
 *
 * The generated `process` function maps the output indices to indices into A and B via
 * `bcastMatmulIndices_A` / `bcastMatmulIndices_B`, then accumulates `_A(a) * _B(b)` over the
 * shared dimension (the last dimension of A). When a third input is present it is added as a
 * bias, and the activation snippet is applied before the value is returned.
 *
 * @param metadata program metadata that is spread into the returned program info.
 * @param inputs `inputs[0]` is A, `inputs[1]` is B, and an optional `inputs[2]` is the bias.
 * @param activationAttributes attributes passed to `getActivationSnippet`.
 * @returns the program info; the output uses the broadcast shape, the type of `inputs[0]`, and
 *   an unpacked texture.
 * @throws Error when `BroadcastUtil.calcShape` yields no output shape for the two operands.
 */
function createMatmulProgramInfo(
  metadata: ProgramMetadata,
  inputs: Tensor[],
  activationAttributes: InternalActivationAttributes,
): ProgramInfo {
  const aShape = inputs[0].dims;
  const bShape = inputs[1].dims;
  // NOTE(review): the third argument presumably selects matmul broadcasting rules — confirm
  // against BroadcastUtil.calcShape.
  const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
  if (!outputShape) {
    throw new Error("Can't use matmul on the given tensors");
  }
  const coordsDataType = getCoordsDataType(outputShape.length);
  const allGlChannels = getGlChannels();
  const { activationFunction, applyActivation } = getActivationSnippet(activationAttributes);

  const hasBias = inputs.length > 2;
  const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
  const getBiasForMatmulSnippet = hasBias
    ? getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, false)
    : '';

  const rank = outputShape.length;
  const arank = aShape.length;
  const brank = bShape.length;
  const sharedDim = aShape[aShape.length - 1];
  // `value` is explicitly zero-initialized: GLSL leaves uninitialized locals undefined, so the
  // accumulation below must not rely on the driver clearing it.
  const shaderSource = `
    ${activationFunction}
    ${getBiasForMatmulSnippet}
    float process(int indices[${rank}]) {
        int a[${arank}];
        int b[${brank}];
        bcastMatmulIndices_A(indices, a);
        bcastMatmulIndices_B(indices, b);

        float value = 0.0;
        for (int k=0; k<${sharedDim}; ++k) {
            a[${arank - 1}] = k;
            b[${brank - 2}] = k;
            value += _A(a) * _B(b);
        }
        ${processBias}
        ${applyActivation}
        return value;
    }`;
  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
}
92

93
/**
 * Creates the loader for the unpacked MatMul program.
 *
 * The loader carries the program metadata and a `get` callback that builds the full program
 * info on demand from the captured inputs and activation attributes.
 *
 * @param inputs operand tensors; a bias input is declared when more than two are supplied.
 * @param activationAttributes fused-activation attributes; `activationCacheKey` is used as the
 *   metadata cache hint.
 * @returns the program info loader.
 */
export function createMatmulProgramInfoLoader(
  inputs: Tensor[],
  activationAttributes: InternalActivationAttributes,
): ProgramInfoLoader {
  const hasBias = inputs.length > 2;
  const metadata = createMatmulProgramMetadata(hasBias, activationAttributes.activationCacheKey);
  const get = (): ProgramInfo => createMatmulProgramInfo(metadata, inputs, activationAttributes);

  return { ...metadata, get };
}
100

101
/**
 * Checks that the MatMul inputs are usable and throws otherwise.
 *
 * Checks are performed in this order: exactly two inputs; the last dimension of the first input
 * equals the second-to-last dimension of the second input; both inputs are 'float32' or
 * 'float64'; both inputs have the same type.
 *
 * @param inputs tensors to validate.
 * @throws Error describing the first check that fails.
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('MatMul requires 2 inputs.');
  }

  const [lhs, rhs] = inputs;

  const lhsSharedDim = lhs.dims[lhs.dims.length - 1];
  const rhsSharedDim = rhs.dims[rhs.dims.length - 2];
  if (lhsSharedDim !== rhsSharedDim) {
    throw new Error('shared dimension does not match.');
  }

  const isFloat = (tensor: Tensor): boolean => tensor.type === 'float32' || tensor.type === 'float64';
  if (!(isFloat(lhs) && isFloat(rhs))) {
    throw new Error('inputs should be float type');
  }

  if (lhs.type !== rhs.type) {
    throw new Error('inputs types should match');
  }
};
121

122
/**
 * Generates the GLSL source of `getBiasForMatmul()`, which fetches the bias value for the
 * current output coordinates.
 *
 * Coordinates along the dimensions reported by `BroadcastUtil.getBroadcastDims` are reset to 0
 * before the bias is sampled. The packed variant returns a `vec4` and samples the bias with the
 * trailing output coordinates; the unpacked variant returns a `float` and samples it with
 * `coords.x`.
 *
 * @param coordsDataType GLSL type used to declare the `coords` variable.
 * @param allGlChannels channel names used to address components of `coords`.
 * @param inShape shape of the bias input.
 * @param outShape shape of the matmul output.
 * @param isPacked selects the packed (`vec4`) or unpacked (`float`) variant.
 * @returns the GLSL function source.
 */
export function getBiasForMatmul(
  coordsDataType: string,
  allGlChannels: readonly string[],
  inShape: readonly number[],
  outShape: readonly number[],
  isPacked: boolean,
): string {
  const inRank = inShape.length;
  const outRank = outShape.length;
  const rankDiff = outRank - inRank;

  // Arguments for the packed getBias call: the whole `coords` value for outputs of rank < 2,
  // otherwise the trailing components that line up with the bias dimensions.
  const unpackedCoordsSnippet =
    outRank < 2 && inRank > 0
      ? 'coords'
      : inShape.map((_s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', ');

  // Statements that zero the coordinates of broadcast dimensions.
  const coordsSnippet = BroadcastUtil.getBroadcastDims(inShape, outShape)
    .map((d) => `coords.${allGlChannels[d + rankDiff]} = 0;`)
    .join('\n');

  // A bias holding exactly one element is replicated across all four lanes.
  const output = ShapeUtil.size(inShape) === 1 ? 'vec4(outputValue.x)' : 'vec4(outputValue.xx, outputValue.yy)';

  const declareCoords = `  ${coordsDataType} coords = getOutputCoords();`;
  const resetBroadcastCoords = `  ${coordsSnippet}`;

  // Each array starts with an empty entry so the joined source begins with a newline.
  if (isPacked) {
    return [
      '',
      'vec4 getBiasForMatmul() {',
      declareCoords,
      resetBroadcastCoords,
      `  vec4 outputValue = getBias(${unpackedCoordsSnippet});`,
      `  return ${output};`,
      '}',
    ].join('\n');
  }

  return [
    '',
    'float getBiasForMatmul() {',
    declareCoords,
    resetBroadcastCoords,
    '  return getBias(coords.x);',
    '}',
  ].join('\n');
}
163

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.