onnxruntime

Форк
0
208 строк · 7.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';
11

12
/**
 * Parsed node attributes shared by every Reduce* kernel in this file.
 */
export interface ReduceAttributes extends AttributeWithCacheKey {
  /** true: reduced axes remain in the output shape with extent 1; false: they are dropped from it. */
  readonly keepDims: boolean;
  /**
   * Axes to reduce, as read from the node's `axes` attribute. They are passed through
   * `ShapeUtil.normalizeAxes` before use; an empty list means "reduce over every axis".
   */
  readonly axes: number[];
}
16

17
/**
 * Supplies the three GLSL snippets that specialize the generic reduce shader, in order
 * `[init, accumulate, finalize]`:
 *  - `init` is emitted once per output element, before the reduction loops;
 *  - `accumulate` is emitted in the innermost reduction loop, once per reduced input element;
 *  - `finalize` is emitted once after all reduction loops.
 * The snippets operate on the shader-local `float value` and may read the input via `_A(inputIdx)`.
 * `reducedAxes` holds the normalized axes; an empty list means every axis is reduced.
 */
type ReduceOp = (tensors: Tensor[], reducedAxes: number[]) => string[];
19

20
/**
 * Shared driver for every Reduce* operator in this file.
 *
 * Validates the inputs, then asks the inference handler to run (or fetch from its cache,
 * keyed by `attributes.cacheKey`) the program produced by `createReduceProgramInfo`.
 *
 * @param inferenceHandler WebGL handler that compiles and runs the program.
 * @param inputs operator inputs; exactly one numeric tensor is accepted.
 * @param attributes reduction axes / keepDims plus the cache key.
 * @param name program name, e.g. 'ReduceSum'.
 * @param reduceOp supplies the init / accumulate / finalize GLSL snippets.
 * @returns a single-element array holding the reduced tensor.
 */
const reduce = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
  name: string,
  reduceOp: ReduceOp,
): Tensor[] => {
  validateInputs(inputs);

  const metadata = { name, inputNames: ['A'], inputTypes: [TextureType.unpacked] };
  const programLoader = {
    ...metadata,
    cacheHint: attributes.cacheKey,
    // Built lazily: only invoked when the handler has no cached program for this cache hint.
    get: () => createReduceProgramInfo(inferenceHandler, inputs, attributes, name, reduceOp, metadata),
  };

  return [inferenceHandler.run(programLoader, inputs)];
};
45

46
/**
 * Reads the `axes` (default: empty, i.e. all axes) and `keepdims` (default: 1) attributes of a
 * Reduce* node. Only an explicit `keepdims` value of 1 yields `keepDims: true`.
 */
export const parseReduceAttributes: OperatorInitialization<ReduceAttributes> = (node: Graph.Node): ReduceAttributes =>
  createAttributeWithCacheKey({
    axes: node.attributes.getInts('axes', []),
    keepDims: node.attributes.getInt('keepdims', 1) === 1,
  });
51

52
/**
 * Builds the WebGL program (output shape + GLSL source) for one Reduce* operator.
 *
 * The generated `process(outputIdx)` function:
 *  1. copies every kept (non-reduced) output coordinate into `inputIdx`,
 *  2. runs the op's init snippet,
 *  3. runs one nested `for` loop per reduced axis around the op's accumulate snippet,
 *  4. runs the op's finalize snippet and returns `value`.
 *
 * @param _handler unused here.
 * @param inputs single-element list holding the tensor to reduce.
 * @param attributes reduction axes (empty = all axes) and the keepDims flag.
 * @param _name unused here; the program name comes from `reduceProgramMetadata`.
 * @param reduceOp supplies the `[init, accumulate, finalize]` GLSL snippets.
 * @param reduceProgramMetadata name / input metadata spread into the returned ProgramInfo.
 * @returns program info whose output keeps the input's data type in an unpacked texture.
 */
const createReduceProgramInfo = (
  _handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
  _name: string,
  reduceOp: ReduceOp,
  reduceProgramMetadata: ProgramMetadata,
): ProgramInfo => {
  const inputDims = inputs[0].dims;
  const outputShape: number[] = [];
  // A zero-length GLSL array is not valid, so a rank-0 input still gets a one-element index array.
  const iRank = inputDims.length || 1;

  const idxCopy: string[] = []; // copy output indexes to input indexes
  if (inputDims.length === 0) {
    // A rank-0 input has no axes, so neither the copy statements nor the reduce loops below
    // ever write inputIdx[0]; set it explicitly instead of reading an uninitialized GLSL local.
    idxCopy.push('inputIdx[0] = 0;');
  }

  const axes = ShapeUtil.normalizeAxes(attributes.axes, inputDims.length);
  const ops = reduceOp(inputs, axes);
  let reduceOps = ops[1];

  for (let k = 0; k < inputDims.length; k++) {
    // if this axis is reduced (an empty axes list reduces every axis)
    if (axes.indexOf(k) >= 0 || axes.length === 0) {
      if (attributes.keepDims) {
        outputShape.push(1);
      } // else { remove the axis from outputShape; }

      // wrap everything accumulated so far in a loop over the k-th axis
      reduceOps = `
          for(int j${k} = 0; j${k} < ${inputDims[k]}; j${k}++) {
            inputIdx[${k}] = j${k};
            ${reduceOps}
          }`;
    } else {
      // kept axis: its input coordinate is the output coordinate at the current output position
      idxCopy.push(`inputIdx[${k}] = outputIdx[${outputShape.length}];`);

      outputShape.push(inputDims[k]);
    }
  }

  const oRank = outputShape.length || 1;

  const shaderSource = `
      float process(int outputIdx[${oRank}]) {
        float value;                 // final result
        int inputIdx[${iRank}];      // addressing input data
        ${idxCopy.join('\n')}
        ${ops[0]}       // init ops for reduce max/min
        ${reduceOps}
        ${ops[2]}       // final computation for reduce mean
        return value;
      }`;

  return {
    ...reduceProgramMetadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
108

109
/**
 * Checks the preconditions shared by all Reduce* kernels in this file.
 *
 * @throws Error when there is not exactly one input, or when that input's type is not in NUMBER_TYPES.
 */
const validateInputs = (inputs: Tensor[]): void => {
  // TODO: support Reduce* operators with 2 inputs.
  const hasExactlyOneInput = !!inputs && inputs.length === 1;
  if (!hasExactlyOneInput) {
    throw new Error('Reduce op requires 1 input.');
  }

  const isNumericInput = NUMBER_TYPES.indexOf(inputs[0].type) !== -1;
  if (!isNumericInput) {
    throw new Error('Invalid input type.');
  }
};
119

120
/**
 * ReduceSum: adds up the input elements along the reduced axes.
 */
export const reduceSum: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const init = 'value = 0.0;';
  const accumulate = 'value += _A(inputIdx);';
  const finalize = '';
  const reduceOp: ReduceOp = (): string[] => [init, accumulate, finalize];
  return reduce(inferenceHandler, inputs, attributes, 'ReduceSum', reduceOp);
};
128

129
/**
 * ReduceMean: sums the elements along the reduced axes, then divides by how many were summed.
 */
export const reduceMean: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = (tensors: Tensor[], reducedAxes: number[]): string[] => {
    const reduceEveryAxis = reducedAxes.length === 0;
    // Number of input elements folded into each output element: product of the reduced extents.
    const size = tensors[0].dims.reduce(
      (count: number, extent: number, axis: number) =>
        reduceEveryAxis || reducedAxes.indexOf(axis) >= 0 ? count * extent : count,
      1.0,
    );

    return ['value = 0.0;', 'value += _A(inputIdx);', `value /= ${size}.;`]; // ensure real number with `.`
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceMean', reduceOp);
};
146

147
/**
 * ReduceMax: largest element along the reduced axes.
 * `value` is seeded with the first element of the reduced slice (every reduced coordinate
 * set to 0) so no sentinel "minus infinity" constant is needed.
 */
export const reduceMax: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = (tensors: Tensor[], reducedAxes: number[]): string[] => {
    const rank = tensors[0].dims.length;
    const reduceEveryAxis = reducedAxes.length === 0;
    const seedLines: string[] = [];
    for (let axis = 0; axis < rank; axis++) {
      if (!reduceEveryAxis && reducedAxes.indexOf(axis) < 0) {
        continue; // kept axis: coordinate already copied from outputIdx
      }
      seedLines.push(`inputIdx[${axis}] = 0;`); // first element
    }

    const init = `${seedLines.join('\n')}\nvalue = _A(inputIdx);`;
    return [init, 'value = max(value, _A(inputIdx));', ''];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceMax', reduceOp);
};
164

165
/**
 * ReduceMin: smallest element along the reduced axes.
 * `value` is seeded with the first element of the reduced slice (every reduced coordinate
 * set to 0) before the comparison loop runs.
 */
export const reduceMin: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = (tensors: Tensor[], reducedAxes: number[]): string[] => {
    const reduceEveryAxis = reducedAxes.length === 0;
    const zeroReducedCoords = Array.from({ length: tensors[0].dims.length }, (_unused, axis) => axis)
      .filter((axis) => reduceEveryAxis || reducedAxes.indexOf(axis) >= 0)
      .map((axis) => `inputIdx[${axis}] = 0;`); // first element

    const init = `${zeroReducedCoords.join('\n')}\nvalue = _A(inputIdx);`;
    return [init, 'value = min(value, _A(inputIdx));', ''];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceMin', reduceOp);
};
182

183
/**
 * ReduceProd: multiplies together the input elements along the reduced axes.
 */
export const reduceProd: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = (): string[] => {
    const startAtOne = 'value = 1.0;'; // multiplicative identity
    const multiplyIn = 'value *= _A(inputIdx);';
    return [startAtOne, multiplyIn, ''];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceProd', reduceOp);
};
191

192
/**
 * ReduceLogSum: natural logarithm of the sum of the elements along the reduced axes.
 */
export const reduceLogSum: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = (): string[] => {
    const snippets: string[] = [];
    snippets.push('value = 0.0;'); // start the sum at zero
    snippets.push('value += _A(inputIdx);'); // add each reduced element
    snippets.push('value = log(value);'); // take the log once the sum is complete
    return snippets;
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSum', reduceOp);
};
200

201
/**
 * Sum-of-squares reduction: accumulates `x * x` over the reduced axes.
 *
 * NOTE(review): despite the "LogSumSquare" name, the finalize snippet is empty, so no `log` is
 * applied and the result is sum(x^2). This kernel is presumably registered for ONNX
 * ReduceSumSquare (which has no log) — confirm against the operator resolve rules before
 * renaming it or adding a `log`.
 */
export const reduceLogSumSquare: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = (): string[] => {
    // `t` is declared in the init snippet so the accumulate snippet can reuse it inside the loops.
    const declareAndZero = 'float t; value = 0.0;';
    const addSquare = 't = _A(inputIdx); value += t * t;';
    const noFinalize = '';
    return [declareAndZero, addSquare, noFinalize];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSumSquare', reduceOp);
};
209

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.