onnxruntime

Форк
0
92 строки · 3.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Graph } from '../../../graph';
5
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6
import { Tensor } from '../../../tensor';
7
import { WebGLInferenceHandler } from '../inference-handler';
8

9
import { transpose, TransposeAttributes } from './transpose';
10

11
export interface DepthToSpaceAttributes {
12
  mode: 'DCR' | 'CRD';
13
  blocksize: number;
14
}
15

16
export const depthToSpace: OperatorImplementation<DepthToSpaceAttributes> = (
17
  inferenceHandler: WebGLInferenceHandler,
18
  inputs: Tensor[],
19
  attributes: DepthToSpaceAttributes,
20
): Tensor[] => {
21
  validateInputs(inputs);
22
  const blocksize = attributes.blocksize;
23
  const blocksizeSqr = blocksize * blocksize;
24
  const transposePerm = attributes.mode === 'DCR' ? [0, 3, 4, 1, 5, 2] : [0, 1, 4, 2, 5, 3];
25
  const firstReshapeShape =
26
    attributes.mode === 'DCR'
27
      ? [
28
          inputs[0].dims[0],
29
          blocksize,
30
          blocksize,
31
          inputs[0].dims[1] / blocksizeSqr,
32
          inputs[0].dims[2],
33
          inputs[0].dims[3],
34
        ]
35
      : [
36
          inputs[0].dims[0],
37
          inputs[0].dims[1] / blocksizeSqr,
38
          blocksize,
39
          blocksize,
40
          inputs[0].dims[2],
41
          inputs[0].dims[3],
42
        ];
43

44
  // const transpose = new WebGLTranspose();
45
  // const attributes = new Attribute(undefined);
46
  // attributes.set('perm', 'ints', transposePerm);
47
  // transpose.initialize(attributes);
48

49
  // First reshape
50
  const firstReshapedTensor = inferenceHandler.reshapeUnpacked(inputs[0], firstReshapeShape);
51

52
  // transpose
53
  const transposeAttributes: TransposeAttributes = { perm: transposePerm, cacheKey: `${transposePerm}` };
54
  const [transposeOutput] = transpose(inferenceHandler, [firstReshapedTensor], transposeAttributes);
55

56
  // Second reshape
57
  const secondReshapeShape = [
58
    inputs[0].dims[0],
59
    inputs[0].dims[1] / blocksizeSqr,
60
    inputs[0].dims[2] * blocksize,
61
    inputs[0].dims[3] * blocksize,
62
  ];
63
  const result = inferenceHandler.reshapeUnpacked(transposeOutput, secondReshapeShape);
64
  return [result];
65
};
66

67
export const parseDepthToSpaceAttributes: OperatorInitialization<DepthToSpaceAttributes> = (
68
  node: Graph.Node,
69
): DepthToSpaceAttributes => {
70
  // processing node attributes
71
  const blocksize = node.attributes.getInt('blocksize');
72
  if (blocksize < 1) {
73
    throw new Error(`blocksize must be >= 1, but got : ${blocksize} for DepthToSpace`);
74
  }
75
  const mode = node.attributes.getString('mode', 'DCR');
76
  if (mode !== 'DCR' && mode !== 'CRD') {
77
    throw new Error(`unrecognized mode: ${mode} for DepthToSpace`);
78
  }
79
  return { mode, blocksize };
80
};
81

82
const validateInputs = (inputs: Tensor[]): void => {
83
  if (inputs.length !== 1) {
84
    throw new Error(`DepthToSpace expect 1 inputs, but got ${inputs.length}`);
85
  }
86

87
  // Input has to be a 4-D tensor
88
  // TODO: Support string depth-to-space.
89
  if (inputs[0].type === 'string' || inputs[0].dims.length !== 4) {
90
    throw new TypeError('DepthToSpace input should be a 4-D numeric tensor');
91
  }
92
};
93

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.