onnxruntime

Форк
0
162 строки · 5.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { GemmUtil, ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
9

10
import {
11
  createTensorShapeVariables,
12
  IndicesHelper,
13
  inputVariable,
14
  outputVariable,
15
  ShaderHelper,
16
  UniformsArrayType,
17
} from './common';
18

19
/**
 * Validates the tensors handed to the Gemm operator before a program is built.
 *
 * `inputs[0]` and `inputs[1]` are the two matrix operands; the optional `inputs[2]` is 'C'.
 *
 * @param inputs - Tensors supplied to the operator.
 * @throws Error if `inputs` is missing, if there are fewer than 2 or more than 3 inputs,
 *   if 'C' has more than two dimensions, or if the inputs do not all share one data type.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs) {
    throw new Error('Input is missing');
  }
  if (inputs.length < 2 || inputs.length > 3) {
    throw new Error('Invalid input number.');
  }

  // 'C' can be of dimensionality 0, 1 or 2 only
  if (inputs.length === 3 && inputs[2].dims.length > 2) {
    throw new Error('Invalid input shape of C');
  }

  // All provided inputs (including the optional 'C') must share the data type of the first input.
  if (inputs[0].dataType !== inputs[1].dataType || (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType)) {
    throw new Error('Input types are mismatched');
  }
};
36

37
/**
 * Attributes of a Gemm node after parsing, plus the cache key inherited from
 * {@link AttributeWithCacheKey}.
 */
export interface GemmAttributes extends AttributeWithCacheKey {
  /** When set, the first input is indexed as a transposed matrix. */
  transA: boolean;
  /** When set, the second input is indexed as a transposed matrix. */
  transB: boolean;
  /** Scale applied to the accumulated matrix product. */
  alpha: number;
  /** Scale applied to the optional third input before it is added to the result. */
  beta: number;
}
43

44
/**
 * Builds the WebGPU program description for Gemm.
 *
 * Each shader invocation produces one output element `(m, n)`:
 * it accumulates the product of the (optionally transposed) first two inputs over `k`,
 * scales the sum by `alpha` (skipped when `alpha === 1`), and, when a third input is present,
 * adds `beta` times the broadcast element of that input.
 *
 * @param inputs - The two matrix operands and the optional third input 'C'.
 * @param attributes - Parsed Gemm attributes; `attributes.cacheKey` is used as the shader cache hint.
 * @returns The program info consumed by `ComputeContext.compute`.
 */
const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfo => {
  const aShape = inputs[0].dims.slice();
  const bShape = inputs[1].dims.slice();
  const [M, N, K] = GemmUtil.getShapeOfGemmResult(
    aShape,
    attributes.transA,
    bShape,
    attributes.transB,
    inputs.length === 3 ? inputs[2].dims : undefined,
  );
  // The array literal is always defined, so no validity check is needed here.
  const outputShape = [M, N];
  const outputSize = ShapeUtil.size(outputShape);

  // Listed in the same order as the uniforms registered in the shader source below.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: M },
    { type: DataType.uint32, data: N },
    { type: DataType.uint32, data: K },
    { type: DataType.float, data: attributes.alpha },
    { type: DataType.float, data: attributes.beta },
  ];
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
  if (inputs.length === 3) {
    programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
    inputDependencies.push('rank');
  }
  programUniforms.push(...createTensorShapeVariables(outputShape));

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    // Flat element index into each operand; the transposed variants swap the roles of row and column.
    const aIndex = attributes.transA ? 'k * uniforms.M + m' : 'm * uniforms.K + k';
    const bIndex = attributes.transB ? 'n * uniforms.K + k' : 'k * uniforms.N + n';
    const line = `value += a[${aIndex}] * b[${bIndex}];`;

    const a = inputVariable('a', inputs[0].dataType, inputs[0].dims);
    const b = inputVariable('b', inputs[1].dataType, inputs[1].dims);
    const dataType = a.type.value;

    // `alpha` is an f32 uniform while `value` has the tensor element type. WGSL performs no implicit
    // scalar conversion, so convert explicitly (as is done for `beta` below).
    const calculateAlpha = attributes.alpha === 1 ? '' : `value *= ${dataType}(uniforms.alpha);`;

    const c: IndicesHelper | null =
      inputs.length === 3 ? inputVariable('c', inputs[2].dataType, inputs[2].dims.length) : null;
    const variables = [a, b];
    if (c != null) {
      variables.push(c);
    }
    const output = outputVariable('output', inputs[0].dataType, outputShape.length);
    variables.push(output);

    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'M', type: 'u32' },
      { name: 'N', type: 'u32' },
      { name: 'K', type: 'u32' },
      { name: 'alpha', type: 'f32' },
      { name: 'beta', type: 'f32' },
    ];
    return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}

    let m = global_idx / uniforms.N;
    let n = global_idx % uniforms.N;

    var value = ${dataType}(0);
    for (var k: u32 = 0u; k < uniforms.K; k++) {
      ${line}
    }

    ${calculateAlpha}
    ${
      c != null
        ? `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${
            dataType
          }(uniforms.beta) * ${c.getByOffset('cOffset')};`
        : ''
    }
    output[global_idx] = value;
  }`;
  };

  return {
    name: 'Gemm',
    shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
144

145
/**
 * Converts the raw attribute bag of a Gemm node into {@link GemmAttributes}.
 *
 * The values are taken as-is (no runtime conversion or validation is performed).
 *
 * @param attributes - Raw attributes; `transA`, `transB`, `alpha` and `beta` are read from it.
 * @returns The four attributes together with a cache key built from `transA`, `transB`
 *   and whether `alpha` is exactly 1.
 */
export const parseGemmAttributes = (attributes: Record<string, unknown>): GemmAttributes => {
  const { transA, transB, alpha, beta } = attributes as {
    transA: boolean;
    transB: boolean;
    alpha: number;
    beta: number;
  };
  // Only the fact that alpha equals 1 enters the key, not the alpha value itself.
  const cacheKey = `${transA};${transB};${alpha === 1}`;
  return { transA, transB, alpha, beta, cacheKey };
};
158

159
/**
 * Operator entry point for Gemm: validates the context's inputs, builds the program info
 * for them and submits it through `context.compute`.
 *
 * @param context - Compute context providing the input tensors and the `compute` call.
 * @param attributes - Parsed Gemm attributes.
 * @throws Error when input validation fails.
 */
export const gemm = (context: ComputeContext, attributes: GemmAttributes): void => {
  validateInputs(context.inputs);
  const programInfo = createGemmProgramInfo(context.inputs, attributes);
  context.compute(programInfo);
};
163

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.