onnxruntime

Форк
0
110 строк · 3.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil, SplitUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, TextureType } from '../types';
11

12
export interface SplitAttributes extends AttributeWithCacheKey {
13
  readonly axis: number;
14
  readonly split: number[];
15
  readonly numOutputs: number;
16
}
17

18
const splitProgramMetadata = {
19
  name: 'Split',
20
  inputNames: ['A'],
21
  inputTypes: [TextureType.unpacked],
22
};
23

24
export const split: OperatorImplementation<SplitAttributes> = (
25
  inferenceHandler: WebGLInferenceHandler,
26
  inputs: Tensor[],
27
  attributes: SplitAttributes,
28
): Tensor[] => {
29
  validateInputs(inputs);
30

31
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
32
  const count = getProgramCount(inferenceHandler, inputs, axis, attributes);
33
  const output: Tensor[] = [];
34
  for (let i = 0; i < count; ++i) {
35
    output.push(
36
      inferenceHandler.run(
37
        {
38
          ...splitProgramMetadata,
39
          cacheHint: `${attributes.cacheKey};${i}`,
40
          get: () => createSplitProgramInfo(inferenceHandler, inputs[0], attributes, axis, i),
41
        },
42
        inputs,
43
      ),
44
    );
45
  }
46

47
  return output;
48
};
49

50
export const parseSplitAttributes: OperatorInitialization<SplitAttributes> = (node: Graph.Node): SplitAttributes => {
51
  const axis = node.attributes.getInt('axis', 0);
52
  const split = node.attributes.getInts('split', []);
53
  const numOutputs = node.outputs.length;
54
  return createAttributeWithCacheKey({ axis, split, numOutputs });
55
};
56

57
const getProgramCount = (
58
  _inferenceHandler: WebGLInferenceHandler,
59
  inputs: Tensor[],
60
  axis: number,
61
  attributes: SplitAttributes,
62
): number => {
63
  const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs);
64
  return offsets.length;
65
};
66

67
const createSplitProgramInfo = (
68
  _inferenceHandler: WebGLInferenceHandler,
69
  input: Tensor,
70
  attributes: SplitAttributes,
71
  axis: number,
72
  index: number,
73
): ProgramInfo => {
74
  const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs);
75
  const offset = offsets[index];
76
  const outputShape = shapes[index];
77
  const rank = outputShape.length;
78
  const shaderSource = `
79
      float process(int indices[${rank}]) {
80
        indices[${axis}] += ${offset};
81
        return _A(indices);
82
      }
83
    `;
84
  return {
85
    ...splitProgramMetadata,
86
    cacheHint: `${attributes.cacheKey}:${index}`,
87
    output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
88
    shaderSource,
89
  };
90
};
91

92
const validateInputs = (inputs: Tensor[]): void => {
93
  if (!inputs || inputs.length !== 1) {
94
    throw new Error('Split requires one input.');
95
  }
96

97
  if (
98
    inputs[0].type !== 'int8' &&
99
    inputs[0].type !== 'uint8' &&
100
    inputs[0].type !== 'int16' &&
101
    inputs[0].type !== 'uint16' &&
102
    inputs[0].type !== 'int32' &&
103
    inputs[0].type !== 'uint32' &&
104
    inputs[0].type !== 'float32' &&
105
    inputs[0].type !== 'float64' &&
106
    inputs[0].type !== 'bool'
107
  ) {
108
    throw new Error('Invalid input type.');
109
  }
110
};
111

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.