// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { BroadcastUtil, ShapeUtil } from '../../util';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu';
import {
  createTensorShapeVariables,
  getBroadcastDims,
  getMaxComponents,
  IndicesHelper,
  inputVariable,
  internalVariable,
  outputVariable,
  ShaderHelper,
  tensorTypeToWsglStorageType,
  UniformsArrayType,
} from './common';
import {
  appendActivationUniforms,
  appendActivationUniformsData,
  getActivationSnippet,
  InternalActivationAttributes,
} from './fuse-utils';

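/**
 * Creates the ProgramInfo for a naive (non-tiled) MatMul WebGPU kernel.
 * Handles broadcasting over the batch dimensions, an optional bias tensor
 * (inputs[2]) and a fused activation described by `activationAttributes`.
 */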
export const createNaiveMatmulProgramInfo = (
  inputs: readonly TensorView[],
  activationAttributes: InternalActivationAttributes,
  outputShape: readonly number[],
  reshapedOutputShape?: readonly number[],
  isChannelsLast = false /* only used for conv2dByMatMul*/,
  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
  const aShape = inputs[0].dims;
  const bShape = inputs[1].dims;

  const M = aShape[aShape.length - 2];
  const N = bShape[bShape.length - 1];
  const K = aShape[aShape.length - 1];
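  // Vectorization widths: `components` packs the N axis of B and the output,
  // `aComponents` packs the K axis of A, and `outputNumber` is the number of
  // output rows produced per invocation; all are derived via getMaxComponents.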
  const components = getMaxComponents(N);
  const aComponents = getMaxComponents(K);
  const outputNumber = getMaxComponents(M);
  const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
  const hasBias = inputs.length > 2;
  const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
  const batchSize = ShapeUtil.size(outerDims);
  const outputShapeInShader = [batchSize, M, N];

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: M },
    { type: DataType.uint32, data: N },
    { type: DataType.uint32, data: K },
  ];
  appendActivationUniformsData(activationAttributes, programUniforms);
  programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape));
  if (hasBias) {
    programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
  }
  programUniforms.push(...createTensorShapeVariables(outputShapeInShader));

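  // Builds the WGSL source. Each invocation computes an `outputNumber` x
  // `components` block of the [batchSize, M, N] output.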
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length);
    const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
    const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
    const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
    const baseType = tensorTypeToWsglStorageType(output.type.tensor);
    const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType);
    const inputVariables = [a, b];
    let processBias = '';
    if (hasBias) {
      const biasComponents = isChannelsLast ? components : 1;
      inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
      processBias = `${
        isChannelsLast ? `value += bias[col / ${biasComponents}];` : `value += ${output.type.value}(bias[row + i]);`
      }`;
    }

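    // Determine which batch dimensions of A and B are broadcast against the
    // output batch dims; those indices are forced to 0 when indexing the inputs.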
    const outerDimsA = aShape.slice(0, -2);
    const outerDimsB = bShape.slice(0, -2);
    const broadCastADims = getBroadcastDims(outerDimsA, outerDims);
    const broadCastBDims = getBroadcastDims(outerDimsB, outerDims);
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'M', type: 'u32' },
      { name: 'N', type: 'u32' },
      { name: 'K', type: 'u32' },
    ];
    appendActivationUniforms(activationAttributes, uniforms);

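    // Emits WGSL that fills `<name>_indices` for an input tensor: batch dims
    // are copied from `batch_indices`, broadcast dims are reset to 0, and the
    // trailing row/column indices start at 0.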
    const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => {
      const rank = variable.rank;
      const name = variable.name;
      if (rank === 2) {
        return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`;
      }
      const batchRank = batchDims.rank;
      let resStr = `var ${name}_indices: ${variable.type.indices};`;
      for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
        resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`;
      }
      broadCastDims.forEach((i) => {
        resStr += `\n${name}_indices[${i}] = 0;`;
      });
      resStr += `${name}_indices[${rank - 2}] = 0u;
                     ${name}_indices[${rank - 1}] = 0u;`;
      return resStr;
    };

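    // Emits the body of the K loop: loads `aComponents`-wide slices of A (one
    // per output row) and the matching rows of B, then accumulates the partial
    // products into `values` with fma.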
    const calcResult = (): string => {
      let calcStr = `var a_data: ${a.type.value};`;
      for (let i = 0; i < aComponents; i++) {
        calcStr += `
              let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`;
      }
      for (let i = 0; i < outputNumber; i++) {
        calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`;

        for (let j = 0; j < aComponents; j++) {
          calcStr += `
            values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${i}]);\n`;
        }
      }
      return calcStr;
    };

    return `
  ${shaderHelper
    .registerUniforms(uniforms)
    .registerInternalVariables(batchDims)
    .declareVariables(...inputVariables, output)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let col = (global_idx % (uniforms.N / ${components})) * ${components};
    var index1 = global_idx / (uniforms.N / ${components});
    let stride1 = uniforms.M / ${outputNumber};
    let row = (index1 % stride1) * ${outputNumber};
    let batch = index1 / stride1;

    ${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`}
    ${getIndices(a, broadCastADims)}
    let a_offset = ${a.indicesToOffset('a_indices')};
    ${getIndices(b, broadCastBDims)}
    let b_offset = ${b.indicesToOffset('b_indices')};
    var values: array<${output.type.value}, ${outputNumber}>;
    for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) {
      ${calcResult()}
    }
    for (var i = 0u; i < ${outputNumber}u; i++) {
      var value = values[i];
      ${processBias}
      ${applyActivation}
      let cur_indices = ${output.type.indices}(batch, row + i, col);
      let offset = ${output.indicesToOffset('cur_indices')};
      ${output.setByOffset(`offset / ${components}`, 'value')};
    }
  }
  `;
  };
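  // The cache hint captures every value baked into the generated WGSL; concrete
  // shapes are passed as uniforms, so the program is cached per input rank
  // ('rank' input dependencies) rather than per shape.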
  return {
    name: 'MatMulNaive',
    shaderCache: {
      hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`,
      inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
    },
    getRunData: () => ({
      outputs: [
        {
          dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
          dataType: inputs[0].dataType,
        },
      ],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};

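// MatMul expects exactly two inputs whose shared dimension (last dim of A,
// second-to-last dim of B) must match.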
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('MatMul requires 2 inputs.');
  }

  if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) {
    throw new Error('shared dimension does not match.');
  }
};

export const matMul = (context: ComputeContext): void => {
  validateInputs(context.inputs);
  const outputShape = BroadcastUtil.calcShape(context.inputs[0].dims, context.inputs[1].dims, true);
  if (!outputShape) {
    throw new Error("Can't use matmul on the given tensors");
  }
  const N = outputShape[outputShape.length - 1];
  const K = context.inputs[0].dims[context.inputs[0].dims.length - 1];
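  // For small N and K the naive kernel is used; larger problems go through the
  // tiled program from matmul_packed_webgpu, which presumably amortizes its
  // extra setup better on bigger matrices.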
  if (N < 8 && K < 8) {
    context.compute(createNaiveMatmulProgramInfo(context.inputs, { activation: '' }, outputShape));
  } else {
    context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape));
  }
};