onnxruntime

Форк
0
116 строк · 4.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
11

12
interface GatherAttributes extends AttributeWithCacheKey {
13
  readonly axis: number;
14
}
15

16
export const gather: OperatorImplementation<GatherAttributes> = (
17
  inferenceHandler: WebGLInferenceHandler,
18
  inputs: Tensor[],
19
  attributes: GatherAttributes,
20
): Tensor[] => {
21
  validateInputs(inputs, attributes.axis);
22
  const output = inferenceHandler.run(createGatherProgramInfoLoader(inferenceHandler, inputs, attributes), inputs);
23
  return [output];
24
};
25

26
export const parseGatherAttributes: OperatorInitialization<GatherAttributes> = (node: Graph.Node): GatherAttributes =>
27
  createAttributeWithCacheKey({ axis: node.attributes.getInt('axis', 0) });
28

29
const gatherProgramMetadata = {
30
  name: 'Gather',
31
  inputNames: ['A', 'B'],
32
  inputTypes: [TextureType.unpacked, TextureType.unpacked],
33
};
34

35
const createGatherProgramInfo = (
36
  _handler: WebGLInferenceHandler,
37
  metadata: ProgramMetadata,
38
  inputs: Tensor[],
39
  axis: number,
40
): ProgramInfo => {
41
  const inputShape = inputs[0].dims.slice();
42
  const indexDataShape = inputs[1].dims.slice();
43
  const outputShape = new Array(inputShape.length + indexDataShape.length - 1);
44

45
  axis = ShapeUtil.normalizeAxis(axis, inputShape.length);
46
  const indexCopyOps: string[] = [];
47
  for (let i = 0; i < outputShape.length; i++) {
48
    // outputShape is divided into three parts: A, B, C
49
    // |0        axis|  axis + indexDataShape.length |          end|
50
    // |     A       |             B                 |      C      |
51
    //
52
    // inputIdx: [A, inputs[1][B], C]
53
    if (i < axis) {
54
      // A
55
      outputShape[i] = inputShape[i];
56
      indexCopyOps.push(`inputIdx[${i}] = outputIdx[${i}];`);
57
    } else {
58
      if (i < axis + indexDataShape.length) {
59
        // B
60
        outputShape[i] = indexDataShape[i - axis];
61
        indexCopyOps.push(`indexDataIdx[${i - axis}] = outputIdx[${i}];`);
62
      } else {
63
        // C
64
        outputShape[i] = inputShape[i - indexDataShape.length + 1]; // skip 1 for axis
65
        indexCopyOps.push(`inputIdx[${i - indexDataShape.length + 1}] = outputIdx[${i}];`);
66
      }
67
    }
68
  }
69

70
  const orank = outputShape.length || 1;
71
  const irank = inputShape.length;
72
  const iDrank = indexDataShape.length || 1;
73
  const shaderSource = `
74
      float process(int outputIdx[${orank}]) {
75
        int inputIdx[${irank}];
76
        int indexDataIdx[${iDrank}];
77
        indexDataIdx[0] = 0;
78
        ${indexCopyOps.join('\n        ')}
79
        int idx = int(_B(indexDataIdx));
80
        inputIdx[${axis}] = idx < 0 ? idx + ${inputShape[axis]} : idx;
81
        return _A(inputIdx);
82
      }`;
83
  return {
84
    ...metadata,
85
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
86
    shaderSource,
87
  };
88
};
89

90
const createGatherProgramInfoLoader = (
91
  handler: WebGLInferenceHandler,
92
  inputs: Tensor[],
93
  attributes: GatherAttributes,
94
): ProgramInfoLoader => {
95
  const metadata = { ...gatherProgramMetadata, cacheHint: attributes.cacheKey };
96
  return { ...metadata, get: () => createGatherProgramInfo(handler, metadata, inputs, attributes.axis) };
97
};
98

99
const validateInputs = (inputs: Tensor[], axis: number): void => {
100
  if (!inputs || inputs.length !== 2) {
101
    throw new Error('Gather requires 2 inputs.');
102
  }
103
  const tensorRank = inputs[0].dims.length;
104
  if (tensorRank < 1) {
105
    throw new Error('Invalid input shape.');
106
  }
107
  if (axis < -tensorRank || axis > tensorRank - 1) {
108
    throw new Error('Invalid axis.');
109
  }
110
  if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
111
    throw new Error('Invaid input type.');
112
  }
113
  if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') {
114
    throw new Error('Invaid input type.');
115
  }
116
};
117

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.