onnxruntime

Форк
0
153 строки · 5.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { GemmUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
11

12
/**
 * Parsed attributes of a Gemm node, bundled with a cache key.
 *
 * The generated shader computes `alpha * (A' * B') + beta * C`, where `A'` / `B'`
 * are the optionally transposed inputs.
 */
export interface GemmAttributes extends AttributeWithCacheKey {
  /** When true, A is read through its transposed accessor. */
  transA: boolean;
  /** When true, B is read through its transposed accessor. */
  transB: boolean;
  /** Scalar multiplied into the A*B product. */
  alpha: number;
  /** Scalar multiplied into C before it is added to the result. */
  beta: number;
  /** True when the operator version allows C to be omitted (in opset 11, C becomes optional). */
  isOptionalC: boolean;
}
19

20
/**
 * WebGL implementation of the Gemm operator.
 *
 * Validates the inputs first (throwing on malformed input), then runs the Gemm
 * program on the inference handler.
 *
 * @param inferenceHandler handler used to execute the generated program
 * @param inputs `[A, B]` or `[A, B, C]`
 * @param attributes parsed Gemm attributes
 * @returns a single-element array holding the output tensor
 * @throws Error when `validateInputs` rejects the inputs
 */
export const gemm: OperatorImplementation<GemmAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: GemmAttributes,
): Tensor[] => {
  validateInputs(inputs, attributes);
  const programInfoLoader = createGemmProgramInfoLoader(inputs, attributes);
  return [inferenceHandler.run(programInfoLoader, inputs)];
};
29

30
/**
 * Reads the Gemm attributes from a graph node and wraps them with a cache key.
 *
 * `transA` / `transB` are stored as integers on the node (default 0); any
 * non-zero value is treated as true. `alpha` and `beta` default to 1.0.
 *
 * @param node graph node whose attributes are read
 * @param isOptionalC whether the operator version being parsed allows C to be omitted
 * @returns the parsed attributes, passed through `createAttributeWithCacheKey`
 */
const parseGemmAttributes = (node: Graph.Node, isOptionalC: boolean): GemmAttributes => {
  const { attributes } = node;
  return createAttributeWithCacheKey({
    transA: attributes.getInt('transA', 0) !== 0,
    transB: attributes.getInt('transB', 0) !== 0,
    alpha: attributes.getFloat('alpha', 1.0),
    beta: attributes.getFloat('beta', 1.0),
    isOptionalC,
  });
};
37

38
/**
 * Attribute parser for the Gemm variant in which C is a required input
 * (`isOptionalC` is false).
 */
export const parseGemmAttributesV7: OperatorInitialization<GemmAttributes> = (node: Graph.Node): GemmAttributes =>
  parseGemmAttributes(node, false);
40

41
/**
 * Attribute parser for the Gemm variant in which C may be omitted
 * (`isOptionalC` is true).
 */
export const parseGemmAttributesV11: OperatorInitialization<GemmAttributes> = (node: Graph.Node): GemmAttributes =>
  parseGemmAttributes(node, true);
43

44
/**
 * Creates the lazy program loader for Gemm.
 *
 * The metadata (name, input names/types, cache key) is built eagerly; the full
 * program info, including the shader source, is only built when `get()` is called.
 *
 * @param inputs `[A, B]` or `[A, B, C]`; C is treated as present only when there are exactly 3 inputs
 * @param attributes parsed Gemm attributes; `cacheKey` is used as the program key
 * @returns the metadata plus a `get` factory for the program info
 */
const createGemmProgramInfoLoader = (inputs: Tensor[], attributes: GemmAttributes): ProgramInfoLoader => {
  const hasC = inputs.length === 3;
  const inputNames = hasC ? ['A', 'B', 'C'] : ['A', 'B'];
  const metadata = {
    name: 'Gemm',
    inputNames,
    // Every input is read as an unpacked texture, one entry per input name.
    inputTypes: inputNames.map(() => TextureType.unpacked),
    key: attributes.cacheKey,
  };

  return { ...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes) };
};
57

58
/**
 * Builds the full program info for Gemm: GLSL source, output description and uniforms.
 *
 * The shader computes, for each output element, the sum over the shared
 * dimension of A*B (each read through its plain or transposed accessor),
 * scales it by `alpha`, and, when C is supplied, adds `beta * C` with C
 * broadcast to the output indices.
 *
 * @param metadata program metadata spread into the result
 * @param inputs `[A, B]` or `[A, B, C]`; C is used only when there are exactly 3 inputs
 * @param attributes parsed Gemm attributes
 * @returns the program info with output dims `[M, N]` and the element type of A
 */
const createGemmProgramInfo = (
  metadata: ProgramMetadata,
  inputs: Tensor[],
  attributes: GemmAttributes,
): ProgramInfo => {
  const hasC = inputs.length === 3;
  const aShape = inputs[0].dims.slice();
  const bShape = inputs[1].dims.slice();

  // NOTE(review): shape compatibility is presumably enforced inside
  // getShapeOfGemmResult — confirm in util. The previous `if (!outputShape)`
  // guard here was unreachable (an array literal is always truthy) and has
  // been removed.
  const [M, N] = GemmUtil.getShapeOfGemmResult(
    aShape,
    attributes.transA,
    bShape,
    attributes.transB,
    hasC ? inputs[2].dims : undefined,
  );
  const outputShape = [M, N];
  const rank = outputShape.length;

  // Length of the reduction: first dim of A when A is transposed, last dim otherwise.
  const sharedDim = attributes.transA ? aShape[0] : aShape[aShape.length - 1];

  // Pick the plain or transposed texture accessor for each operand.
  const readA = attributes.transA ? '_A_T' : '_A';
  const readB = attributes.transB ? '_B_T' : '_B';
  const line = `value += ${readA}(a) * ${readB}(b);`;

  const declareC = hasC ? `int c[${inputs[2].dims.length}];` : '';
  const broadcastC = hasC ? 'bcastIndices_C(indices, c);' : '';
  const calculateC = hasC ? 'value += beta * _C(c);' : '';
  const shaderSource = `
      float process(int indices[${rank}]) {
          int a[${rank}];
          int b[${rank}];
          ${declareC}

          copyVec(indices, a);
          copyVec(indices, b);
          ${broadcastC}

          float value = 0.0;
          for (int k=0; k<${sharedDim}; ++k) {
              a[${rank - 1}] = k;
              b[${rank - 2}] = k;
              ${line}
          }

          value = value * alpha;
          ${calculateC}
          return value;
      }`;
  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    variables: [
      { name: 'alpha', type: 'float', data: attributes.alpha },
      { name: 'beta', type: 'float', data: attributes.beta },
    ],
    shaderSource,
  };
};
125

126
/**
 * Checks that the inputs are acceptable for Gemm and throws otherwise.
 *
 * Rules, checked in this order:
 *  - `inputs` must be provided;
 *  - when C is optional there must be 2 or 3 inputs, otherwise exactly 3;
 *  - C, when present, must have rank 1 or 2;
 *  - every input must be `float32` or `float64`;
 *  - all inputs must share the same element type.
 *
 * @param inputs `[A, B]` or `[A, B, C]`
 * @param attributes parsed Gemm attributes; only `isOptionalC` is read
 * @throws Error describing the first rule that is violated
 */
const validateInputs = (inputs: Tensor[], attributes: GemmAttributes): void => {
  if (!inputs) {
    throw new Error('Input is missing');
  }

  // Input-count check. After this point inputs[0] and inputs[1] always exist.
  if (attributes.isOptionalC) {
    if (inputs.length < 2 || inputs.length > 3) {
      throw new Error('Invalid input shape.');
    }
  } else if (inputs.length !== 3) {
    throw new Error('Gemm requires 3 inputs');
  }

  // 'C' can be of dimensionality 1 or 2 only
  if (inputs.length === 3) {
    const cRank = inputs[2].dims.length;
    if (cRank !== 1 && cRank !== 2) {
      throw new Error('Invalid input shape of C');
    }
  }

  if (inputs.some((tensor) => tensor.type !== 'float32' && tensor.type !== 'float64')) {
    throw new Error('Invalid input type.');
  }

  const expectedType = inputs[0].type;
  if (inputs.some((tensor) => tensor.type !== expectedType)) {
    throw new Error('Input types are mismatched');
  }
};
154

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.