onnxruntime

Форк
0
30 строк · 1013.0 Байт
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Graph } from '../../../graph';
5
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6
import { Tensor } from '../../../tensor';
7
import { ProtoUtil } from '../../../util';
8
import { WebGLInferenceHandler } from '../inference-handler';
9

10
export const cast: OperatorImplementation<Tensor.DataType> = (
11
  handler: WebGLInferenceHandler,
12
  inputs: Tensor[],
13
  to: Tensor.DataType,
14
): Tensor[] => {
15
  validateInputs(inputs);
16
  return [handler.cast(inputs[0], to)];
17
};
18

19
export const parseCastAttributes: OperatorInitialization<Tensor.DataType> = (node: Graph.Node): Tensor.DataType =>
20
  ProtoUtil.tensorDataTypeFromProto(node.attributes.getInt('to'));
21

22
const validateInputs = (inputs: Tensor[]): void => {
23
  if (!inputs || inputs.length !== 1) {
24
    throw new Error('Cast requires 1 input.');
25
  }
26

27
  if (inputs[0].type === 'string') {
28
    throw new Error('Invalid input type.');
29
  }
30
};
31

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.