onnxruntime

Форк
0
42 строки · 1.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Graph } from '../../../graph';
5
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6
import { Tensor } from '../../../tensor';
7
import { ShapeUtil } from '../../../util';
8
import { WebGLInferenceHandler } from '../inference-handler';
9

10
export const flatten: OperatorImplementation<number> = (
11
  inferenceHandler: WebGLInferenceHandler,
12
  inputs: Tensor[],
13
  axis: number,
14
): Tensor[] => {
15
  validateInputs(inputs, axis);
16

17
  const outputDims = ShapeUtil.flattenShape(inputs[0].dims, axis);
18
  return [inferenceHandler.reshapeUnpacked(inputs[0], outputDims)];
19
};
20

21
export const parseFlattenAttributes: OperatorInitialization<number> = (node: Graph.Node): number =>
22
  node.attributes.getInt('axis', 1); // default axis is 1
23

24
const validateInputs = (inputs: Tensor[], axis: number): void => {
25
  if (!inputs || inputs.length !== 1) {
26
    throw new Error('Flatten requires 1 input.');
27
  }
28

29
  const r = inputs[0].dims.length;
30
  if (r === 0) {
31
    throw new Error('scalar tensor is not supported.');
32
  }
33

34
  if (axis < -r || axis > r) {
35
    throw new Error('Invalid axis');
36
  }
37

38
  // TODO: Support string type
39
  if (inputs[0].type === 'string') {
40
    throw new Error('string tensor is not supported.');
41
  }
42
};
43

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.