// onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4import { Graph } from '../../../graph';
5import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6import { Tensor } from '../../../tensor';
7import { ProtoUtil } from '../../../util';
8import { WebGLInferenceHandler } from '../inference-handler';
9
10export const cast: OperatorImplementation<Tensor.DataType> = (
11handler: WebGLInferenceHandler,
12inputs: Tensor[],
13to: Tensor.DataType,
14): Tensor[] => {
15validateInputs(inputs);
16return [handler.cast(inputs[0], to)];
17};
18
19export const parseCastAttributes: OperatorInitialization<Tensor.DataType> = (node: Graph.Node): Tensor.DataType =>
20ProtoUtil.tensorDataTypeFromProto(node.attributes.getInt('to'));
21
22const validateInputs = (inputs: Tensor[]): void => {
23if (!inputs || inputs.length !== 1) {
24throw new Error('Cast requires 1 input.');
25}
26
27if (inputs[0].type === 'string') {
28throw new Error('Invalid input type.');
29}
30};
31