// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
/**
 * Attributes of the ONNX `ImageScaler` operator as parsed from the graph node.
 */
export interface ImageScalerAttributes extends AttributeWithCacheKey {
  /** Multiplier applied to every input element (read via `getFloat('scale')`). */
  scale: number;
  /** Per-channel additive bias (read via `getFloats('bias')`); indexed by the channel axis. */
  bias: number[];
}
/**
 * WebGL implementation of `ImageScaler`: for every element of the single input,
 * computes `x * scale + bias[channel]`, where `channel` is index 1 of the element.
 *
 * @param inferenceHandler - Handler used to compile/run the WebGL program.
 * @param inputs - Exactly one 4-D float32/float64 tensor (enforced by `validateInputs`).
 * @param attributes - Parsed `scale`/`bias` attributes.
 * @returns A one-element array containing the scaled tensor.
 * @throws Error if the inputs fail validation.
 */
export const imageScaler: OperatorImplementation<ImageScalerAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ImageScalerAttributes,
): Tensor[] => {
  validateInputs(inputs);
  const output = inferenceHandler.run(createImageScalerProgramInfoLoader(inferenceHandler, inputs, attributes), inputs);
  return [output];
};
/**
 * Reads the `scale` and `bias` attributes from an `ImageScaler` graph node and
 * wraps them with a cache key so identical configurations share a compiled program.
 *
 * @param node - The graph node describing the operator.
 * @returns The parsed attributes, including `cacheKey`.
 */
export const parseImageScalerAttributes: OperatorInitialization<ImageScalerAttributes> = (
  node: Graph.Node,
): ImageScalerAttributes => {
  const scale = node.attributes.getFloat('scale');
  const bias = node.attributes.getFloats('bias');
  return createAttributeWithCacheKey({ scale, bias });
};
/**
 * Static metadata shared by every ImageScaler program. The single input is
 * exposed to the shader as `X` (sampled with `_X(indices)`) from an unpacked texture.
 */
const imageScalerProgramMetadata = {
  name: 'ImageScaler',
  inputNames: ['X'],
  inputTypes: [TextureType.unpacked],
};
/**
 * Builds the WebGL program for ImageScaler.
 *
 * The output has the same shape and element type as the input. The shader
 * multiplies each sampled value by the `scale` uniform and adds the bias for
 * the element's channel (index 1 of `indices`).
 *
 * @param _handler - Unused; kept for the loader's call signature.
 * @param metadata - Program metadata including the cache hint.
 * @param inputs - The validated single input tensor.
 * @param attributes - `scale` and per-channel `bias`.
 * @returns The program description (output layout, uniforms, shader source).
 */
const createImageScalerProgramInfo = (
  _handler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  inputs: Tensor[],
  attributes: ImageScalerAttributes,
): ProgramInfo => {
  const outputShape = inputs[0].dims.slice();
  const rank = outputShape.length;
  const getBiasMethod = createGetBiasMethod(attributes.bias.length);
  const shaderSource = `
      ${getBiasMethod}
      float process(int indices[${rank}]) {
        return _X(indices) * scale + getBias(bias, indices[1]);
      }`;
  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    variables: [
      { name: 'bias', type: 'float', arrayLength: attributes.bias.length, data: attributes.bias },
      { name: 'scale', type: 'float', data: attributes.scale },
    ],
    shaderSource,
  };
};
/**
 * Creates a lazy program loader for ImageScaler. The attribute cache key is
 * attached as `cacheHint` so programs are reused across identical attributes;
 * the program itself is only built when `get()` is called.
 *
 * @param handler - Inference handler forwarded to the program builder.
 * @param inputs - The validated single input tensor.
 * @param attributes - Parsed operator attributes.
 * @returns Metadata plus a `get` thunk producing the `ProgramInfo`.
 */
const createImageScalerProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ImageScalerAttributes,
): ProgramInfoLoader => {
  const metadata = { ...imageScalerProgramMetadata, cacheHint: attributes.cacheKey };
  return { ...metadata, get: () => createImageScalerProgramInfo(handler, metadata, inputs, attributes) };
};
/**
 * Generates the GLSL helper `float getBias(float bias[N], int channel)`.
 *
 * The bias is selected through an if/else-if chain rather than `bias[channel]`,
 * presumably to avoid dynamic uniform-array indexing, which GLSL ES 1.00
 * fragment shaders are not required to support. The last channel is always the
 * unconditional fallback, so every code path returns a value. This includes the
 * single-channel case, which previously emitted an `if` with no return after it.
 *
 * @param numChannels - Number of bias entries (expected to equal the channel count).
 * @returns GLSL source for the helper, one statement per line.
 * @throws Error if `numChannels < 1` (GLSL arrays cannot have size 0).
 */
const createGetBiasMethod = (numChannels: number): string => {
  if (numChannels < 1) {
    throw new Error(`ImageScaler requires at least one bias value, got ${numChannels}.`);
  }
  const lastChannel = numChannels - 1;
  const codeLines: string[] = [`float getBias(float bias[${numChannels}], int channel) {`];
  if (lastChannel === 0) {
    // Only one channel exists, so no comparison is needed.
    codeLines.push(`\treturn bias[0];`);
  } else {
    for (let i = 0; i < lastChannel; ++i) {
      const keyword = i === 0 ? 'if' : 'else if';
      codeLines.push(`\t${keyword} (channel == ${i}) { return bias[${i}]; }`);
    }
    codeLines.push(`\telse { return bias[${lastChannel}]; }`);
  }
  codeLines.push('\t}');
  return codeLines.join('\n');
};
/**
 * Checks that ImageScaler received exactly one 4-D tensor of type float32 or float64.
 *
 * @param inputs - The operator inputs.
 * @throws Error if the input count, rank, or element type is unsupported.
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 1) {
    throw new Error('ImageScaler requires 1 input.');
  }
  if (inputs[0].dims.length !== 4) {
    throw new Error('Invalid input shape.');
  }
  if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
    throw new Error('Invalid input type.');
  }
};