// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { Graph } from '../../../graph';
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
import { Tensor } from '../../../tensor';
import { getGlsl } from '../glsl-source';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, TextureType } from '../types';
import { getCoordsDataType } from '../utils';

import { unpackFromChannel } from './packing-utils';
import { parseUpsampleAttributes, scalesValidation, UpsampleAttributes, validateInputs } from './upsample';
15
const resizeProgramMetadata = {
18
inputTypes: [TextureType.packed],
21
export const resize: OperatorImplementation<UpsampleAttributes> = (
22
inferenceHandler: WebGLInferenceHandler,
24
attributes: UpsampleAttributes,
26
validateInputs(inputs, attributes);
27
const output = inferenceHandler.run(
29
...resizeProgramMetadata,
30
cacheHint: attributes.cacheKey,
31
get: () => createPackedResizeProgramInfo(inferenceHandler, inputs, attributes),
38
export const parseResizeAttributesV10: OperatorInitialization<UpsampleAttributes> = (
40
): UpsampleAttributes => parseUpsampleAttributes(node, 10);
42
export const parseResizeAttributesV11: OperatorInitialization<UpsampleAttributes> = (
44
): UpsampleAttributes => parseUpsampleAttributes(node, 11);
46
const createPackedResizeProgramInfo = (
47
inferenceHandler: WebGLInferenceHandler,
49
attributes: UpsampleAttributes,
51
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
52
const [scales, outputShape] = prepareInputs(inputs, attributes);
54
const isSame = scales.every((s: number) => s === 1) && attributes.coordinateTransformMode !== 'tf_crop_and_resize';
57
...resizeProgramMetadata,
58
output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed },
60
shaderSource: `void main() {
61
vec4 v = ${glsl.texture2D}(X, TexCoords);
67
const dim = outputShape.length;
69
throw new Error(`output dimension should be at least 2, but got ${dim}`);
72
const outputHeight = outputShape[dim - 2];
73
const outputWidth = outputShape[dim - 1];
75
const inputShape = inputs[0].dims;
76
if (dim !== inputShape.length) {
77
throw new Error(`output dimension should match input ${inputShape.length}, but got ${dim}`);
79
const inputHeight = inputShape[dim - 2];
80
const inputWidth = inputShape[dim - 1];
82
const scalesHeight = scales[dim - 2];
83
const scalesWidth = scales[dim - 1];
85
let getSourceFracIndex = '';
87
if (attributes.mode !== 'linear') {
88
// TODO: support other modes
89
throw new Error(`resize (packed) does not support mode: '${attributes.mode}'`);
91
switch (attributes.coordinateTransformMode) {
93
getSourceFracIndex = `
94
vec4 getSourceFracIndex(ivec4 coords) {
95
return vec4(coords) / scaleWHWH;
100
getSourceFracIndex = `
101
vec4 getSourceFracIndex(ivec4 coords) {
102
return (vec4(coords) + 0.5) / scaleWHWH - 0.5;
106
case 'pytorch_half_pixel':
107
getSourceFracIndex = `
108
vec4 getSourceFracIndex(ivec4 coords) {
109
vec4 fcoords = vec4(coords);
111
${outputWidth}.0 > 1.0 ? (fcoords.x + 0.5) / scaleWHWH.x - 0.5 : 0.0,
112
${outputHeight}.0 > 1.0 ? (fcoords.y + 0.5) / scaleWHWH.y - 0.5 : 0.0,
113
${outputWidth}.0 > 1.0 ? (fcoords.z + 0.5) / scaleWHWH.z - 0.5 : 0.0,
114
${outputHeight}.0 > 1.0 ? (fcoords.w + 0.5) / scaleWHWH.w - 0.5 : 0.0
119
case 'align_corners':
120
getSourceFracIndex = `
121
vec4 getSourceFracIndex(ivec4 coords) {
122
vec4 resized = vec4(${outputWidth}.0 - 1.0, ${outputHeight}.0 - 1.0, ${outputWidth}.0 - 1.0,
123
${outputHeight}.0 - 1.0);
124
vec4 original = vec4(${inputWidth}.0 - 1.0, ${inputHeight}.0 - 1.0, ${inputWidth}.0 - 1.0,
125
${inputHeight}.0 - 1.0);
126
vec4 new_scale = original / resized;
127
return vec4(coords) * new_scale;
132
// TODO:supporting other coordinateTransformModes
133
throw new Error(`resize (packed) does not support coordinateTransformMode: \
134
'${attributes.coordinateTransformMode}'`);
137
const coordsDataType = getCoordsDataType(dim);
138
const unpackChannel = unpackFromChannel();
139
const shaderSource = `
140
const vec2 inputWH = vec2(${inputHeight}.0, ${inputWidth}.0);
141
const vec4 scaleWHWH = vec4(float(${scalesHeight}), float(${scalesWidth}), float(${scalesHeight}), float(${
145
${getSourceFracIndex}
146
float getAValue(int x10, int r, int c, int d) {
147
return getChannel(getA(x10, r, c, d), vec2(c, d));
150
${coordsDataType} rc = getOutputCoords();
155
// retrieve the 4 coordinates that is used in the 4 packed output values.
156
ivec4 coords = ivec4(rc.wz, rc.w + 1, rc.z + 1);
158
// calculate the source index in fraction
159
vec4 sourceFrac = getSourceFracIndex(coords);
161
// get the lower and upper bound of the 4 values that will be packed into one texel.
162
ivec4 x00 = ivec4(max(sourceFrac.xy, vec2(0.0)), min(inputWH - 1.0, ceil(sourceFrac.xy)));
163
ivec4 x01 = ivec4(max(sourceFrac.xw, vec2(0.0)), min(inputWH - 1.0, ceil(sourceFrac.xw)));
164
ivec4 x10 = ivec4(max(sourceFrac.zy, vec2(0.0)), min(inputWH - 1.0, ceil(sourceFrac.zy)));
165
ivec4 x11 = ivec4(max(sourceFrac.zw, vec2(0.0)), min(inputWH - 1.0, ceil(sourceFrac.zw)));
167
bool hasNextRow = rc.w < ${outputHeight - 1};
168
bool hasNextCol = rc.z < ${outputWidth - 1};
170
// pack x00, x01, x10, x11's top-left corner into one vec4 structure
172
getAValue(batch, depth, x00.x, x00.y),
173
hasNextCol ? getAValue(batch, depth, x01.x, x01.y) : 0.0,
174
hasNextRow ? getAValue(batch, depth, x10.x, x10.y) : 0.0,
175
(hasNextRow && hasNextCol) ? getAValue(batch, depth, x11.x, x11.y) : 0.0);
177
// pack x00, x01, x10, x11's top-right corner into one vec4 structure
178
vec4 topRight = vec4(
179
getAValue(batch, depth, x00.x, x00.w),
180
hasNextCol ? getAValue(batch, depth, x01.x, x01.w) : 0.0,
181
hasNextRow ? getAValue(batch, depth, x10.x, x10.w) : 0.0,
182
(hasNextRow && hasNextCol) ? getAValue(batch, depth, x11.x, x11.w) : 0.0);
184
// pack x00, x01, x10, x11's bottom-left corner into one vec4 structure
185
vec4 bottomLeft = vec4(
186
getAValue(batch, depth, x00.z, x00.y),
187
hasNextCol ? getAValue(batch, depth, x01.z, x01.y) : 0.0,
188
hasNextRow ? getAValue(batch, depth, x10.z, x10.y) : 0.0,
189
(hasNextRow && hasNextCol) ? getAValue(batch, depth, x11.z, x11.y) : 0.0);
191
// pack x00, x01, x10, x11's bottom-right corner into one vec4 structure
192
vec4 bottomRight = vec4(
193
getAValue(batch, depth, x00.z, x00.w),
194
hasNextCol ? getAValue(batch, depth, x01.z, x01.w) : 0.0,
195
hasNextRow ? getAValue(batch, depth, x10.z, x10.w) : 0.0,
196
(hasNextRow && hasNextCol) ? getAValue(batch, depth, x11.z, x11.w) : 0.0);
198
// calculate the interpolation fraction on u and v direction
199
vec4 frac = vec4(sourceFrac) - floor(sourceFrac);
200
vec4 clampFrac = clamp(frac, vec4(0.0), vec4(1.0));
202
vec4 top = mix(topLeft, topRight, clampFrac.ywyw);
203
vec4 bottom = mix(bottomLeft, bottomRight, clampFrac.ywyw);
204
vec4 newValue = mix(top, bottom, clampFrac.xxzz);
206
${glsl.output} = vec4(newValue);
210
...resizeProgramMetadata,
211
output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed },
217
const prepareInputs = (inputs: Tensor[], attributes: UpsampleAttributes): [readonly number[], readonly number[]] => {
219
const xDims = x.dims;
221
let scales = attributes.scales;
222
let outputSizes: number[] | undefined;
223
if (scales.length === 0) {
224
const scalesTensor = inputs[attributes.scalesInputIdx];
225
if (scalesTensor && scalesTensor.size !== 0) {
226
if (inputs[attributes.sizesInputIdx]) {
227
throw new Error('Only one of scales or sizes must be provided as input.');
229
scales = parseScalesData(scalesTensor, attributes.mode, attributes.isResize);
231
const sizesTensor = inputs[attributes.sizesInputIdx];
232
if (!sizesTensor || sizesTensor.size === 0) {
233
throw new Error('Either scales or sizes MUST be provided as input.');
236
outputSizes = Array.from(sizesTensor.integerData);
237
scales = parseScalesDataFromOutputSize(outputSizes, xDims, attributes.mode, attributes.isResize);
240
if (inputs[attributes.sizesInputIdx]) {
241
throw new Error('Only one of scales or sizes must be provided as input.');
245
const yDims = outputSizes || xDims.map((dim, i) => Math.floor(dim * scales[i]));
247
return [scales, yDims];
250
const parseScalesData = (scale: Tensor, mode: string, isResize: boolean): number[] => {
251
const scales = Array.from(scale.floatData);
252
scalesValidation(scales, mode, isResize);
256
const parseScalesDataFromOutputSize = (
257
yDims: readonly number[],
258
xDims: readonly number[],
262
const length = xDims.length;
263
const scales = new Array<number>(length);
265
for (let i = 0, end = length; i < end; i++) {
266
if (xDims[i] === 0) {
267
if (yDims[i] !== 0) {
268
throw new Error('Input dim is zero but required output dim is non-zero.');
272
scales[i] = yDims[i] / xDims[i];
275
scalesValidation(scales, mode, isResize);
279
// roi data is not used yet. but leave here for future usage.
// const getRoi = (inputs: Tensor[], attributes: UpsampleAttributes) : number[] => {
//   let roi: number[] = [];
//   if (attributes.needRoiInput) {
//     if (attributes.roiInputIdx <= 0) {
//       throw new Error('Invalid roi input index.');
//     }
//     const roiTensor = inputs[attributes.roiInputIdx];
//     roi = roiTensor.size > 0 ? Array.from(roiTensor.floatData) : [];
//   } else {
//     roi = new Array(inputs[0].dims.length * 2).fill(0);