onnxruntime

Форк
0
176 строк · 6.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { BroadcastUtil, ShapeUtil } from '../../../util';
6
import { getGlsl } from '../glsl-source';
7
import { WebGLInferenceHandler } from '../inference-handler';
8
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
9
import { getCoordsDataType, getGlChannels } from '../utils';
10

11
import { getActivationSnippet, InternalActivationAttributes } from './fuse-utils';
12
import { getBiasForMatmul } from './matmul';
13

14
/**
 * Builds the program metadata for the packed MatMul shader.
 *
 * @param hasBias - whether a third `Bias` input is present in addition to `A` and `B`.
 * @param cacheHint - forwarded unchanged as the metadata's `cacheHint`.
 * @returns an object with the program name, the input names, one packed texture type per
 *     input, and the cache hint.
 */
const createPackedMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
  name: 'MatMul (packed)',
  inputNames: hasBias ? ['A', 'B', 'Bias'] : ['A', 'B'],
  inputTypes: hasBias
    ? [TextureType.packed, TextureType.packed, TextureType.packed]
    : [TextureType.packed, TextureType.packed],
  cacheHint,
});
22

23
/**
 * Generates the GLSL source and output description for a packed MatMul.
 *
 * @param inferenceHandler - used only to look up the GL context version for `getGlsl`.
 * @param metadata - program metadata that is spread into the returned program info.
 * @param inputs - `[A, B]` or `[A, B, Bias]`; a third entry enables the bias snippet.
 * @param activationAttributes - passed to `getActivationSnippet` to obtain the fused
 *     activation function and its application statement.
 * @returns the program info: `metadata` plus a packed output whose dims are the broadcast
 *     output shape and whose type is `inputs[0].type`, the shader source, and `hasMain: true`.
 * @throws Error when `BroadcastUtil.calcShape` yields no output shape for the two inputs.
 */
const createPackedMatmulProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  inputs: Tensor[],
  activationAttributes: InternalActivationAttributes,
): ProgramInfo => {
  const aShape = inputs[0].dims;
  const bShape = inputs[1].dims;

  const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
  if (!outputShape) {
    throw new Error("Can't use matmul on the given tensors");
  }

  const hasBias = inputs.length > 2;
  const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
  // Any difference between the two input shapes selects the broadcast-aware samplers.
  const isBroadcast = !ShapeUtil.areEqual(aShape, bShape);

  const aRank = aShape.length;
  const bRank = bShape.length;
  const outRank = outputShape.length;

  // The loop below reads one vec4 per iteration and advances the shared coordinate by
  // `i*2`, so it runs over half of A's last dimension, rounded up.
  const sharedDim = aShape[aRank - 1];
  const sharedDimIndex = Math.ceil(sharedDim / 2);

  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
  const coordsDataType = getCoordsDataType(outRank);
  const allGlChannels = getGlChannels();
  const { activationFunction, applyActivation } = getActivationSnippet(activationAttributes);

  const getBiasForMatmulSnippet = hasBias
    ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, true)}`
    : '';

  const getBcastedSamplerForMatmulSnippet = isBroadcast
    ? getBcastSamplerForMatmul(coordsDataType, allGlChannels, inputs, outputShape)
    : '';

  const getSamplerAInLoopSnippet = isBroadcast ? 'getAAtOutCoordsMatmul(i)' : `getA(${getA(allGlChannels, aRank)})`;
  const getSamplerBInLoopSnippet = isBroadcast ? 'getBAtOutCoordsMatmul(i)' : `getB(${getB(allGlChannels, bRank)})`;

  // Non-broadcast path: declare `rc` in main() and swap its last two components. The
  // broadcast helpers fetch and swap their own coordinates, so nothing is emitted there.
  const getOutputCoordsSnippet = isBroadcast
    ? ''
    : `${coordsDataType} rc =
          getOutputCoords(); int lastDim = rc.${allGlChannels[outRank - 1]}; rc.${allGlChannels[outRank - 1]} =
          rc.${allGlChannels[outRank - 2]}; rc.${allGlChannels[outRank - 2]} = lastDim;
      `;

  const shaderSource = `
            ${getBcastedSamplerForMatmulSnippet}
            ${getBiasForMatmulSnippet}
            ${activationFunction}
            void main() {
              ${getOutputCoordsSnippet}

              vec4 value = vec4(0);
              for (int i = 0; i < ${sharedDimIndex}; i++) {
                vec4 a = ${getSamplerAInLoopSnippet};
                vec4 b = ${getSamplerBInLoopSnippet};

                value += (a.rrbb * b.rgrg);
                value += (a.ggaa * b.baba);
              }
              ${processBias}
              ${applyActivation}
              ${glsl.output} = value;
            }`;

  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed },
    shaderSource,
    hasMain: true,
  };
};
92

93
/**
 * Creates the lazy program-info loader for the packed MatMul.
 *
 * The metadata is computed immediately; the shader itself is only generated when the
 * returned loader's `get()` is called.
 *
 * @param inferenceHandler - forwarded to `createPackedMatmulProgramInfo`.
 * @param inputs - `[A, B]` or `[A, B, Bias]`; more than two entries marks the program as biased.
 * @param activationAttributes - supplies `activationCacheKey` as the cache hint and is
 *     forwarded to `createPackedMatmulProgramInfo`.
 * @returns the metadata fields plus a `get` callback that builds the program info.
 */
export const createPackedMatmulProgramInfoLoader = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  activationAttributes: InternalActivationAttributes,
): ProgramInfoLoader => {
  const metadata = createPackedMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
  return {
    ...metadata,
    get: () => createPackedMatmulProgramInfo(inferenceHandler, metadata, inputs, activationAttributes),
  };
};
104

105
/**
 * Emits the GLSL helpers `getAAtOutCoordsMatmul(int i)` and `getBAtOutCoordsMatmul(int i)`
 * used when the two MatMul inputs have different shapes.
 *
 * Each helper fetches the output coordinates, swaps the last two components, zeroes the
 * components reported by `BroadcastUtil.getBroadcastDims` for that input, and samples the
 * input with `i*2` substituted for its shared coordinate (last for A, second-to-last for B).
 *
 * @param coordsDataType - GLSL type used to declare the `coords` variable.
 * @param allGlChannels - component names indexed by dimension position.
 * @param inputs - the MatMul inputs; only `inputs[0]` (A) and `inputs[1]` (B) are read.
 * @param outShape - the broadcast output shape.
 * @returns GLSL source defining both helper functions.
 */
function getBcastSamplerForMatmul(
  coordsDataType: string,
  allGlChannels: readonly string[],
  inputs: Tensor[],
  outShape: readonly number[],
): string {
  const inAShape = inputs[0].dims;
  const inBShape = inputs[1].dims;

  const inARank = inAShape.length;
  const inBRank = inBShape.length;

  const outRank = outShape.length;
  // Inputs of lower rank are aligned to the trailing output dimensions.
  const rankADiff = outRank - inARank;
  const rankBDiff = outRank - inBRank;

  // The original discarded the result of `.join(', ')` and relied on the array's implicit
  // toString() inside the template; joining explicitly only changes whitespace in the
  // generated argument list.
  const unpackedACoords = inAShape.map((_s, i) => `coords.${allGlChannels[i + rankADiff]}`);
  unpackedACoords[inARank - 1] = 'i*2';
  const unpackedACoordsSnippet = unpackedACoords.join(', ');

  const unpackedBCoords = inBShape.map((_s, i) => `coords.${allGlChannels[i + rankBDiff]}`);
  unpackedBCoords[inBRank - 2] = 'i*2';
  const unpackedBCoordsSnippet = unpackedBCoords.join(', ');

  const broadcastADims = BroadcastUtil.getBroadcastDims(inAShape, outShape);
  const broadcastBDims = BroadcastUtil.getBroadcastDims(inBShape, outShape);

  const coordsASnippet = broadcastADims.map((d) => `coords.${allGlChannels[d + rankADiff]} = 0;`).join('\n');
  const coordsBSnippet = broadcastBDims.map((d) => `coords.${allGlChannels[d + rankBDiff]} = 0;`).join('\n');
  const swapDimSnippet = `int lastDim = coords.${allGlChannels[outRank - 1]};
  coords.${allGlChannels[outRank - 1]} = coords.${allGlChannels[outRank - 2]};
  coords.${allGlChannels[outRank - 2]} = lastDim;`;

  const getBcastSamplerMatmulSource = `
vec4 getAAtOutCoordsMatmul(int i) {
  ${coordsDataType} coords = getOutputCoords();
  ${swapDimSnippet}
  ${coordsASnippet}
  vec4 outputValue = getA(${unpackedACoordsSnippet});
  return outputValue;
}

vec4 getBAtOutCoordsMatmul(int i) {
  ${coordsDataType} coords = getOutputCoords();
  ${swapDimSnippet}
  ${coordsBSnippet}
  vec4 outputValue = getB(${unpackedBCoordsSnippet});
  return outputValue;
}`;

  return getBcastSamplerMatmulSource;
}
159

160
/**
 * Builds the comma-separated GLSL argument list for sampling input A inside the
 * non-broadcast MatMul loop.
 *
 * Every coordinate except the last is read from `rc`; the last one is the loop-driven
 * expression `i*2`.
 *
 * @param allGlChannels - component names indexed by dimension position.
 * @param rank - rank of A; callers are expected to pass a rank of at least 2.
 * @returns e.g. `rc.x, rc.y, i*2` for rank 3 with channels `x, y, z, ...`.
 */
function getA(allGlChannels: readonly string[], rank: number): string {
  const args: string[] = [];
  for (let i = 0; i < rank - 2; i++) {
    args.push(`rc.${allGlChannels[i]}`);
  }
  args.push(`rc.${allGlChannels[rank - 2]}`, 'i*2');
  return args.join(', ');
}
168

169
/**
 * Builds the comma-separated GLSL argument list for sampling input B inside the
 * non-broadcast MatMul loop.
 *
 * Leading coordinates and the last coordinate are read from `rc`; the second-to-last one
 * is the loop-driven expression `i*2`.
 *
 * @param allGlChannels - component names indexed by dimension position.
 * @param rank - rank of B; callers are expected to pass a rank of at least 2.
 * @returns e.g. `rc.x, i*2, rc.z` for rank 3 with channels `x, y, z, ...`.
 */
function getB(allGlChannels: readonly string[], rank: number): string {
  const args: string[] = [];
  for (let i = 0; i < rank - 2; i++) {
    args.push(`rc.${allGlChannels[i]}`);
  }
  args.push('i*2', `rc.${allGlChannels[rank - 1]}`);
  return args.join(', ');
}
177

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.