1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo } from '../types';
11
createTensorShapeVariables,
19
/**
 * Values of the ONNX Resize `coordinate_transformation_mode` attribute. Each one selects how an output
 * coordinate is mapped back to an input coordinate (see getOriginalCoordinateFromResizedCoordinate).
 */
type CoordinateTransformMode =
  | 'half_pixel'
  | 'asymmetric'
  | 'pytorch_half_pixel'
  | 'tf_half_pixel_for_nn'
  | 'align_corners'
  | 'tf_crop_and_resize'
  | 'half_pixel_symmetric';
/**
 * Resize `keep_aspect_ratio_policy`: 'stretch' keeps independent per-axis scales, the other two collapse
 * them into a single scale (see adjustOutputShape).
 */
type KeepAspectRatioPolicy =
  | 'stretch'
  | 'not_smaller'
  | 'not_larger';
/** Interpolation mode of Resize (ONNX `mode` attribute). */
type Mode =
  | 'nearest'
  | 'linear'
  | 'cubic';
/**
 * Rounding rule for nearest-neighbour Resize. 'simple' is substituted when the attribute is empty and is only
 * accepted for opset < 11 (see getNearestPixelFromOriginal).
 */
type NearestMode =
  | 'round_prefer_floor'
  | 'round_prefer_ceil'
  | 'floor'
  | 'ceil'
  | 'simple';
/** Parsed attributes of the ONNX Resize operator. */
export interface ResizeAttributes extends AttributeWithCacheKey {
  /** Antialias flag; the kernel entry point rejects anything other than 0. */
  antialias: number;
  /** Axes that scales / sizes / roi refer to; an empty array means "all axes". */
  axes: number[];
  coordinateTransformMode: CoordinateTransformMode;
  /** Coefficient `a` of the cubic convolution kernel (cubic mode only). */
  cubicCoeffA: number;
  /** Cubic mode: give zero weight to taps that fall outside the input. */
  excludeOutside: boolean;
  /** Value written for coordinates outside the input in 'tf_crop_and_resize' mode. */
  extrapolationValue: number;
  keepAspectRatioPolicy: KeepAspectRatioPolicy;
  mode: Mode;
  nearestMode: NearestMode;
}
/**
 * Validates the values of the Resize `scales` input against the interpolation mode.
 *
 * @param scales scale factors as read from the scales input (empty when `sizes` is used instead).
 * @param attributes parsed Resize attributes; only `mode` is consulted.
 * @throws Error if any scale is not strictly positive, or if the number of scales / position of the
 *   non-unit scales is not supported by the linear or cubic implementation.
 */
const validateScales = (scales: number[], attributes: ResizeAttributes): void => {
  // An explicit check is required: the `value > 0 || (() => { throw ... })` idiom never invokes the
  // closure and therefore accepts everything. `!(value > 0)` also rejects NaN.
  if (scales.some((value) => !(value > 0))) {
    throw new Error('Resize requires scales input values to be positive');
  }
  // Check scales dims based on mode: LINEAR, CUBIC
  if (scales.length === 0) {
    return;
  }
  if (attributes.mode === 'linear') {
    const supported =
      scales.length === 2 ||
      scales.length === 3 ||
      (scales.length === 4 && scales[0] === 1 && scales[1] === 1) ||
      (scales.length === 4 && scales[0] === 1 && scales[3] === 1) ||
      (scales.length === 5 && scales[0] === 1 && scales[1] === 1);
    if (!supported) {
      throw new Error(
        `For linear mode, Resize requires scales to be 2D, 3D, 4D with either two outermost or one innermost and
one outermost scale values equal to 1, or 5D with two outermost scale values equal to 1`,
      );
    }
  } else if (attributes.mode === 'cubic') {
    const supported =
      scales.length === 2 ||
      (scales.length === 4 && scales[0] === 1 && scales[1] === 1) ||
      (scales.length === 4 && scales[0] === 1 && scales[3] === 1);
    if (!supported) {
      throw new Error('Resize requires scales input size to be 2 or 4 for cubic mode');
    }
  }
};
/**
 * Expands scales given for a subset of axes (the `axes` attribute) into a full-rank scale vector.
 *
 * @param scales one scale per entry of `axes`.
 * @param axes target axis of each scale value.
 * @param rank rank of the input tensor.
 * @returns a new array of length `rank`; axes not listed get scale 1.
 * @throws Error if any axis lies outside [0, rank).
 */
const updateScales = (scales: readonly number[], axes: readonly number[], rank: number): number[] => {
  // Explicit check: `ok || (() => { throw ... })` never calls the throwing closure.
  if (axes.some((value) => !(value >= 0 && value < rank))) {
    throw new Error('Resize requires axes input values to be positive and less than rank');
  }
  const newScales: number[] = new Array<number>(rank).fill(1.0);
  axes.forEach((value, index) => (newScales[value] = scales[index]));
  return newScales;
};
/**
 * Reads the optional roi / scales / sizes inputs of Resize into caller-provided arrays and validates them.
 *
 * From opset 11 the inputs are (X, roi, scales, sizes); before that the only optional input is scales at
 * position 1.
 *
 * @param inputs operator inputs; inputs[0] is the data tensor.
 * @param attributes parsed Resize attributes.
 * @param opsetVersion opset version of the node.
 * @param scales out-param: receives the scales (expanded to full rank when `axes` is set).
 * @param sizes out-param: receives the requested sizes as given (one per entry of `axes` when set).
 * @param roi out-param: receives the RoI values.
 * @throws Error for a missing RoI in tf_crop_and_resize mode, ill-sized scales/sizes, or both given.
 */
const validateInputs = (
  inputs: readonly TensorView[],
  attributes: ResizeAttributes,
  opsetVersion: number,
  scales: number[],
  sizes: number[],
  roi: number[],
): void => {
  const [roiInputIndex, scalesInputIndex, sizesInputIndex] =
    opsetVersion > 10 ? [1, 2, 3] : [-1, inputs.length > 1 ? 1 : -1, -1];
  const rank = inputs[0].dims.length;
  const axesLength = attributes.axes.length;
  // Lengths accepted for scales/sizes: 0 (not provided), the input rank, or (opset 18+) the number of axes.
  const hasValidLength = (length: number): boolean =>
    length === 0 || length === rank || (opsetVersion >= 18 && length === axesLength);

  if (roiInputIndex > 0 && inputs.length > roiInputIndex && inputs[roiInputIndex].dims.length > 0) {
    inputs[roiInputIndex].getFloat32Array().forEach((value) => roi.push(value));
  } else if (attributes.coordinateTransformMode === 'tf_crop_and_resize') {
    throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize');
  }

  if (scalesInputIndex > 0 && inputs.length > scalesInputIndex && inputs[scalesInputIndex].dims.length > 0) {
    inputs[scalesInputIndex].getFloat32Array().forEach((value) => scales.push(value));
    if (!hasValidLength(scales.length)) {
      throw new Error('Resize requires scales input size to be same as input rank or axes size for opset 18 and up');
    }
    // Must be checked before the values are expanded to full rank below.
    if (axesLength > 0 && scales.length !== 0 && scales.length !== axesLength) {
      throw new Error('Resize requires "scales" input size to be of axes rank when axes attributes is specified');
    }
    validateScales(scales, attributes);
    if (axesLength > 0) {
      updateScales(scales, attributes.axes, rank).forEach((value, index) => (scales[index] = value));
    }
  }

  if (sizesInputIndex > 0 && inputs.length > sizesInputIndex) {
    inputs[sizesInputIndex].getBigInt64Array().forEach((value) => sizes.push(Number(value)));
    // The old check (`!== rank || === axes.length`) was inverted and also rejected an empty sizes tensor.
    if (!hasValidLength(sizes.length)) {
      throw new Error('Resize requires sizes input size to be same as input rank or axes size for opset 18 and up');
    }
    if (axesLength > 0 && sizes.length !== 0 && sizes.length !== axesLength) {
      throw new Error('Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified');
    }
  }

  if (scales.length > 0 && sizes.length > 0) {
    throw new Error('Resize requires only one of scales or sizes to be specified');
  }
};
/**
 * Emits the WGSL function `getOriginalCoordinateFromResizedCoordinate`, which maps an output coordinate on
 * one axis back to a (fractional) input coordinate according to the ONNX coordinate transformation mode.
 *
 * @param coordinateTransferMode the Resize `coordinate_transformation_mode`.
 * @param dType WGSL scalar type used for the result (e.g. 'f32' or 'f16').
 * @returns WGSL source for the function.
 * @throws Error for an unknown mode.
 */
const getOriginalCoordinateFromResizedCoordinate = (
  coordinateTransferMode: CoordinateTransformMode,
  dType: string,
): string =>
  `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: f32, lengthResized: u32,
     lengthOriginal: u32, roiStart: f32, roiEnd: f32) -> ${dType} { ` +
  (() => {
    switch (coordinateTransferMode) {
      case 'asymmetric':
        return `return ${dType}(xResized) / ${dType}(xScale);`;
      case 'pytorch_half_pixel':
        return `if (lengthResized > 1) {
                    return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5;
                  } else {
                    return 0.0;
                  }`;
      case 'tf_half_pixel_for_nn':
        return `return (${dType}(xResized) + 0.5) / ${dType}(xScale);`;
      case 'align_corners':
        return `if (lengthResized == 1) {
                    return 0.0;
                  } else {
                    // The whole part and the fractional part are calculated separately due to inaccuracy of floating
                    // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
                    // offset-by-one error later in floor().
                    let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1));
                    let fract =
                        ${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1);
                    return whole + fract;
                  }`;
      case 'tf_crop_and_resize':
        return `if (lengthResized > 1) {
                    return ${dType}(roiStart) * ${dType}(lengthOriginal - 1) +
                        (${dType}(xResized) * ${dType}(roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) /
                        ${dType}(lengthResized - 1);
                  } else {
                    return 0.5 * ${dType}(roiStart + roiEnd) * ${dType}(lengthOriginal - 1);
                  }`;
      case 'half_pixel_symmetric':
        // ONNX: output_width = scale * input_width (real-valued), adjustment = output_width_int / output_width,
        // offset = (input_width / 2) * (1 - adjustment). `let` is required because WGSL `const` only accepts
        // const-expressions, and these values depend on function parameters.
        return `let outputWidth = ${dType}(xScale) * ${dType}(lengthOriginal);
                  let adjustment = ${dType}(lengthResized) / outputWidth;
                  let center = ${dType}(lengthOriginal) / 2;
                  let offset = center * (1 - adjustment);
                  return offset + ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`;
      case 'half_pixel':
        return `return ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`;
      default:
        throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
    }
  })() +
  '}';
/**
 * Emits the WGSL function `getNearestPixelFromOriginal`, which rounds a fractional input coordinate to a
 * pixel position according to the Resize `nearest_mode`.
 *
 * @param nearestMode rounding rule.
 * @param opsetVersion node opset; the 'simple' rule is only accepted below opset 11.
 * @param dType WGSL scalar type of the coordinate.
 * @throws Error for 'simple' (or an unknown mode) at opset 11 and above.
 */
const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string => {
  const signature = `fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {`;
  // Template literals replace the former backslash line-continuation strings (ESLint no-multi-str).
  const body = (() => {
    switch (nearestMode) {
      case 'round_prefer_ceil':
        // WGSL round() breaks ties to even, so exact halves are resolved explicitly.
        return `if (fract(xOriginal) == 0.5) {
            return ceil(xOriginal);
          } else {
            return round(xOriginal);
          }`;
      case 'floor':
        return 'return floor(xOriginal);';
      case 'ceil':
        return 'return ceil(xOriginal);';
      case 'round_prefer_floor':
        return `if (fract(xOriginal) == 0.5) {
            return floor(xOriginal);
          } else {
            return round(xOriginal);
          }`;
      case 'simple':
      default:
        if (opsetVersion < 11) {
          return `if (isDownSample) {
            return ceil(xOriginal);
          } else {
            return xOriginal;
          }`;
        }
        throw new Error(`Nearest mode ${nearestMode} is not supported`);
    }
  })();
  return `${signature}${body}}`;
};
/**
 * Builds a full-rank RoI laid out as [start_0 .. start_{rank-1}, end_0 .. end_{rank-1}].
 *
 * @param roi RoI values: empty, already full rank (no `axes`), or [starts per axis in `axes`...,
 *   ends per axis in `axes`...].
 * @param axes axes the RoI refers to; empty means all axes.
 * @param rank input rank.
 * @returns a new array; axes not covered default to start 0 and end 1.
 */
const updateRoI = (roi: readonly number[], axes: readonly number[], rank: number): number[] => {
  const fullRoI: number[] = new Array<number>(rank).fill(0).concat(new Array<number>(rank).fill(1));
  if (roi.length === 0) {
    // Default RoI needs no remapping (remapping it in place would read already-overwritten entries).
    return fullRoI;
  }
  if (axes.length === 0) {
    return roi.slice();
  }
  axes.forEach((axis, i) => {
    fullRoI[axis] = roi[i];
    // The end for `axis` belongs at axis + rank (it was previously written to i + rank).
    fullRoI[axis + rank] = roi[axes.length + i];
  });
  return fullRoI;
};
/**
 * Computes the Resize output shape from either `sizes` or `scales`.
 *
 * @param inputShape input dims.
 * @param scales full-rank scales (used only when `sizes` is empty).
 * @param sizes requested sizes: full rank, or one per entry of `axes` when `axes` is set.
 * @param axes axes that `sizes` refers to; empty means all axes.
 * @returns the output dims; scaled dims are rounded to the nearest integer.
 * @throws Error if an axis is out of range, or if neither scales nor sizes is given.
 */
const initOutputShape = (
  inputShape: readonly number[],
  scales: readonly number[],
  sizes: readonly number[],
  axes: readonly number[],
): number[] => {
  if (sizes.length > 0) {
    if (axes.length === 0) {
      return sizes.slice();
    }
    // Valid axes are 0 .. rank-1; the old `> rank` test let `axis === rank` through.
    if (axes.some((axis) => axis < 0 || axis >= inputShape.length)) {
      throw new Error('axes is out of bound');
    }
    const outputShape = inputShape.slice();
    axes.forEach((axis, i) => (outputShape[axis] = sizes[i]));
    return outputShape;
  }
  if (scales.length === 0) {
    throw new Error('Resize requires either scales or sizes.');
  }
  return inputShape.map((value, index) => Math.round(value * scales[index]));
};
/**
 * Applies a non-'stretch' keep_aspect_ratio_policy. It picks one scale ('not_larger': the smallest,
 * 'not_smaller': the largest) over the relevant axes and recomputes the output shape with it.
 *
 * Mutates `scales` in place: relevant axes get the chosen scale, all other entries become 1.
 *
 * @param inputShape input dims.
 * @param scales per-axis scales (rewritten in place).
 * @param attributes parsed Resize attributes (`axes`, `keepAspectRatioPolicy`).
 * @returns the adjusted output dims (rounded).
 * @throws Error for any policy other than 'not_larger' / 'not_smaller'.
 */
const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes) => {
  const { axes, keepAspectRatioPolicy } = attributes;
  const hasAxes = axes.length > 0;
  const candidates = hasAxes ? axes.map((axis) => scales[axis]) : scales.slice();

  let uniformScale: number;
  if (keepAspectRatioPolicy === 'not_larger') {
    uniformScale = Math.min(...candidates, Number.MAX_VALUE);
  } else if (keepAspectRatioPolicy === 'not_smaller') {
    uniformScale = Math.max(...candidates, Number.MIN_VALUE);
  } else {
    throw new Error(`Keep aspect ratio policy ${keepAspectRatioPolicy} is not supported`);
  }

  scales.fill(1.0, 0, scales.length);
  const adjusted = inputShape.slice();
  if (hasAxes) {
    for (const axis of axes) {
      scales[axis] = uniformScale;
      adjusted[axis] = Math.round(inputShape[axis] * uniformScale);
    }
  } else {
    scales.fill(uniformScale, 0, scales.length);
    for (let i = 0; i < adjusted.length; i++) {
      adjusted[i] = Math.round(adjusted[i] * scales[i]);
    }
  }
  return adjusted;
};
/**
 * Emits WGSL `calculateOriginalIndicesFromOutputIndices`, which maps every component of an output index to a
 * fractional input coordinate. Axes with scale 1 are passed through unchanged; other axes go through
 * `getOriginalCoordinateFromResizedCoordinate`, which must be emitted separately.
 *
 * @param output helper for the output tensor.
 * @param inputShape input dims (only the rank is used, to locate RoI ends and shape uniforms).
 * @param outputShape output dims (only the rank is used).
 * @param scalesLength number of entries in the `scales` uniform.
 * @param roiLength number of entries in the `roi` uniform.
 */
const calculateOriginalIndicesFromOutputIndices = (
  output: IndicesHelper,
  inputShape: readonly number[],
  outputShape: readonly number[],
  scalesLength: number,
  roiLength: number,
): string => {
  const rank = outputShape.length;
  const valueType = output.type.value;
  const scaleAt = getElementAt('uniforms.scales', 'i', scalesLength);
  const roiLowAt = getElementAt('uniforms.roi', 'i', roiLength);
  const roiHighAt = getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength);
  const inputDimAt = getElementAt('uniforms.input_shape', 'i', inputShape.length);
  const outputDimAt = getElementAt('uniforms.output_shape', 'i', rank);
  return `
    fn calculateOriginalIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> array<${valueType}, ${rank}> {
      var original_indices: array<${valueType}, ${rank}>;
      for (var i:u32 = 0; i < ${rank}; i++) {
        var output_index = ${output.indicesGet('output_indices', 'i')};
        var scale = ${scaleAt};
        var roi_low = ${roiLowAt};
        var roi_hi = ${roiHighAt};
        if (scale == 1.0) {
          original_indices[i] = ${valueType}(output_index);
        } else {
          var input_shape_i = ${inputDimAt};
          var output_shape_i = ${outputDimAt};
          original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
                                                                           input_shape_i, roi_low, roi_hi);
        }
      }
      return original_indices;
    }`;
};
/**
 * Emits WGSL `calculateInputIndicesFromOutputIndices` for nearest-neighbour Resize: maps an output index to
 * the input index to sample. It relies on `getOriginalCoordinateFromResizedCoordinate` and
 * `getNearestPixelFromOriginal`, which must be emitted separately.
 *
 * When `useExtrapolation` is true and a coordinate falls outside the input, the returned component is set to
 * the input dimension itself, so that `checkInputIndices` rejects it and the caller writes the extrapolation value.
 *
 * @param input helper for the input tensor.
 * @param output helper for the output tensor.
 * @param inputShape input dims (rank used to locate RoI ends and shape uniforms).
 * @param outputShape output dims (rank used for the loop bound).
 * @param scalesLength entries in the `scales` uniform.
 * @param roiLength entries in the `roi` uniform.
 * @param useExtrapolation whether out-of-range coordinates should be flagged rather than clamped.
 */
const calculateInputIndicesFromOutputIndices = (
  input: IndicesHelper,
  output: IndicesHelper,
  inputShape: readonly number[],
  outputShape: readonly number[],
  scalesLength: number,
  roiLength: number,
  useExtrapolation: boolean,
): string => `
    fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} {
      var input_indices: ${input.type.indices};
      for (var i:u32 = 0; i < ${outputShape.length}; i++) {
        var output_index = ${output.indicesGet('output_indices', 'i')};
        var input_index: u32;
        var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
        if (scale == 1.0) {
          input_index = output_index;
        } else {
          var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
          var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
          var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
          var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)};
          var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
                                                                        input_shape_i, roi_low, roi_hi);
          if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${output.type.value}(input_shape_i))) {
            if (original_idx < 0) {
              input_index = 0;
            } else if (original_idx > ${output.type.value}(input_shape_i - 1)) {
              input_index = input_shape_i - 1;
            } else {
              input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1));
            }
          } else {
            // Out of range with extrapolation: use an index checkInputIndices rejects. u32() of a negative
            // coordinate saturates to 0, which would wrongly sample the first element.
            input_index = input_shape_i;
          }
        }
        ${input.indicesSet('input_indices', 'i', ' input_index')}
      }
      return input_indices;
    }`;
/**
 * Emits WGSL `checkInputIndices`, which returns false as soon as any component of an input index lies
 * outside the corresponding `uniforms.input_shape` dimension.
 *
 * @param input helper for the input tensor.
 * @param inputShape input dims (only the rank is used).
 */
const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): string => {
  const rank = inputShape.length;
  const dimAt = getElementAt('uniforms.input_shape', 'i', rank);
  return `
    fn checkInputIndices(input_indices: ${input.type.indices}) -> bool {
      for (var i:u32 = 0; i < ${rank}; i++) {
        var input_index = ${input.indicesGet('input_indices', 'i')};
        if (input_index < 0 || input_index >= ${dimAt}) {
          return false;
        }
      }
      return true;
    }`;
};
/**
 * Emits WGSL statements that copy the `channel` and `batch` locals into `input_indices`, but only when the
 * input has more axes than the interpolated spatial ones.
 *
 * @param input helper for the input tensor.
 * @param channelIdx axis index of the channel dimension.
 * @param batchIdx axis index of the batch dimension.
 * @param spacialDims number of interpolated spatial axes (2 for bilinear, 3 for trilinear).
 * @returns the statements, or '' when the input rank does not exceed `spacialDims`.
 */
const setChannelAndBatchIndices = (
  input: IndicesHelper,
  channelIdx: number,
  batchIdx: number,
  spacialDims: number,
): string => {
  if (input.rank <= spacialDims) {
    return '';
  }
  const setChannel = input.indicesSet('input_indices', channelIdx, 'channel');
  const setBatch = input.indicesSet('input_indices', batchIdx, 'batch');
  return `
    ${setChannel};
    ${setBatch};
`;
};
/**
 * Emits WGSL for bilinear interpolation over the two spatial axes of a 2-D or 4-D input.
 *
 * Produces `getInputValue` (reads one element, clamping row/col into range) and `bilinearInterpolation`.
 * The latter relies on `calculateOriginalIndicesFromOutputIndices` being emitted separately. Height and width
 * are baked into the shader as literals.
 *
 * @param input helper for the input tensor.
 * @param output helper for the output tensor.
 * @param inputShape static input dims.
 * @param useExtrapolation when true, coordinates outside the input return `extrapolationValue`.
 * @param extrapolationValue value returned for out-of-range coordinates when extrapolating.
 */
const bilinearInterpolation = (
  input: IndicesHelper,
  output: IndicesHelper,
  inputShape: readonly number[],
  useExtrapolation: boolean,
  extrapolationValue: number,
): string => {
  // Only the NCHW axis mapping is currently used.
  const isNchw = true;
  const [batchIdx, heightIdx, widthIdx, channelIdx] =
    inputShape.length === 2 ? [-1, 0, 1, -1] : isNchw ? [0, 2, 3, 1] : [0, 1, 2, 3];
  const dType = input.type.value;
  const height = inputShape[heightIdx];
  const width = inputShape[widthIdx];
  const hasOuterAxes = inputShape.length > 2;
  const channelExpr = hasOuterAxes ? `u32(originalIndices[${channelIdx}])` : '0';
  const batchExpr = hasOuterAxes ? `u32(originalIndices[${batchIdx}])` : '0';
  const extrapolationCheck = useExtrapolation
    ? `if (row < 0 || row > (${height} - 1) || col < 0 || col > (${width} - 1)) {
        return ${extrapolationValue};
      }`
    : '';

  return `
    fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
      var input_indices: ${input.type.indices};
      ${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${height} - 1))`)};
      ${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${width} - 1))`)};
      ${setChannelAndBatchIndices(input, channelIdx, batchIdx, 2)}
      return ${input.getByIndices('input_indices')};
    }

    fn bilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
      var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
      var row:${dType} = originalIndices[${heightIdx}];
      var col:${dType} = originalIndices[${widthIdx}];
      ${extrapolationCheck};
      row = max(0, min(row, ${height} - 1));
      col = max(0, min(col, ${width} - 1));
      var row1: u32 = u32(row);
      var col1: u32 = u32(col);
      var row2: u32 = u32(row + 1);
      var col2: u32 = u32(col + 1);
      var channel: u32 = ${channelExpr};
      var batch: u32 = ${batchExpr};
      var x11: ${dType} = getInputValue(batch, channel, row1, col1);
      var x12: ${dType} = getInputValue(batch, channel, row1, col2);
      var x21: ${dType} = getInputValue(batch, channel, row2, col1);
      var x22: ${dType} = getInputValue(batch, channel, row2, col2);
      var dx1: ${dType} = abs(row - ${dType}(row1));
      var dx2: ${dType} = abs(${dType}(row2) - row);
      var dy1: ${dType} = abs(col - ${dType}(col1));
      var dy2: ${dType} = abs(${dType}(col2) - col);
      if (row1 == row2) {
        dx1 = 0.5;
        dx2 = 0.5;
      }
      if (col1 == col2) {
        dy1 = 0.5;
        dy2 = 0.5;
      }
      return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
    }`;
};
/**
 * Emits WGSL for bicubic interpolation on 2-D or 4-D inputs.
 *
 * `colCubicInterpolation` walks the four column taps and, for each one, calls `rowCubicInterpolation`, which
 * walks the four row taps and reads the input. Scales, RoI and shapes are baked into the shader as literals.
 * The shader relies on `getOriginalCoordinateFromResizedCoordinate` being emitted separately.
 *
 * @param input helper for the input tensor.
 * @param output helper for the output tensor.
 * @param inputShape static input dims.
 * @param outputShape static output dims.
 * @param scales full-rank scales.
 * @param roi full-rank RoI laid out as [starts..., ends...] (length 2 * rank).
 * @param cubicCoeffA coefficient `a` of the cubic kernel.
 * @param useExtrapolation when true, out-of-range coordinates return `extrapolationValue`.
 * @param extrapolationValue value used when extrapolating.
 * @param excludeOutside when true, taps outside the input get weight 0 (weights are renormalised).
 */
const bicubicInterpolation = (
  input: IndicesHelper,
  output: IndicesHelper,
  inputShape: readonly number[],
  outputShape: readonly number[],
  scales: readonly number[],
  roi: readonly number[],
  cubicCoeffA: number,
  useExtrapolation: boolean,
  extrapolationValue: number,
  excludeOutside: boolean,
): string => {
  const is2D = inputShape.length === 2;
  // Only the NCHW axis mapping is currently used.
  const isNchw = true;
  const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2];
  const dType = input.type.value;
  const createCubicInterpolationFunction = (idx: number): string => {
    const direction = idx === heightIdx ? 'row' : 'col';
    // The RoI end of axis `idx` is stored at idx + rank. Previously `${roi[idx]} + rank` was emitted,
    // i.e. the start plus the rank.
    const roiStart = roi[idx];
    const roiEnd = roi[idx + inputShape.length];
    return `
      fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${
        output.type.indices
      }) -> ${dType} {
        var output_index = ${output.indicesGet('output_indices', idx)};
        var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(output_index, ${scales[idx]},
          ${outputShape[idx]}, ${inputShape[idx]}, ${roiStart}, ${roiEnd});
        var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx);
        var coefs = getCubicInterpolationCoefs(fractOriginalIdx);

        if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) {
          return ${extrapolationValue};
        }
        var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
        for (var i: i32 = -1; i < 3; i++) {
          var ${direction}: ${dType} = originalIdx + ${dType}(i);
          if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
            ${(() => {
              if (excludeOutside) {
                return `coefs[i + 1] = 0.0;
                        continue;`;
              } else if (useExtrapolation) {
                return `return ${extrapolationValue};`;
              } else {
                return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`;
              }
            })()};
          }
          var input_indices_copy: ${input.type.indices} = input_indices;
          ${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)};
          data[i + 1] = ${
            idx === heightIdx
              ? input.getByIndices('input_indices_copy')
              : 'rowCubicInterpolation(input_indices_copy, output_indices)'
          };
        }
        return cubicInterpolation1D(data, coefs);
      }`;
  };

  return `
    ${createCubicInterpolationFunction(heightIdx)};
    ${createCubicInterpolationFunction(widthIdx)};
    fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> {
      var absS = abs(s);
      var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
      var oneMinusAbsS: ${dType} = 1.0 - absS;
      var twoMinusAbsS: ${dType} = 2.0 - absS;
      var onePlusAbsS: ${dType} = 1.0 + absS;
      coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${
        cubicCoeffA
      }) * onePlusAbsS - 4 * ${cubicCoeffA};
      coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1;
      coeffs[2] = ((${cubicCoeffA} + 2) * oneMinusAbsS - (${cubicCoeffA} + 3)) * oneMinusAbsS * oneMinusAbsS + 1;
      coeffs[3] = ((${cubicCoeffA} * twoMinusAbsS - 5 * ${cubicCoeffA}) * twoMinusAbsS + 8 * ${
        cubicCoeffA
      }) * twoMinusAbsS - 4 * ${cubicCoeffA};
      return coeffs;
    }

    fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} {
      var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3];
      return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum;
    }

    fn bicubicInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
      var input_indices: ${input.type.indices} = output_indices;
      return colCubicInterpolation(input_indices, output_indices);
    }
    `;
};
/**
 * Emits WGSL for trilinear interpolation over the three spatial axes of a 3-D or 5-D input.
 *
 * Produces `getInputValue` (reads one element, clamping each spatial coordinate into range) and
 * `trilinearInterpolation`. The latter relies on `calculateOriginalIndicesFromOutputIndices` being emitted
 * separately. Spatial dims are baked into the shader as literals.
 *
 * @param input helper for the input tensor.
 * @param output helper for the output tensor.
 * @param inputShape static input dims.
 * @param useExtrapolation when true, coordinates outside the input return `extrapolationValue`.
 * @param extrapolationValue value returned for out-of-range coordinates when extrapolating.
 */
const trilinearInterpolation = (
  input: IndicesHelper,
  output: IndicesHelper,
  inputShape: readonly number[],
  useExtrapolation: boolean,
  extrapolationValue: number,
): string => {
  // Only the NCDHW axis mapping is currently used.
  const isNchw = true;
  const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] =
    inputShape.length === 3 ? [-1, 0, 1, 2, -1] : isNchw ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4];
  const dType = input.type.value;
  const depthLen = inputShape[depthIdx];
  const heightLen = inputShape[heightIdx];
  const widthLen = inputShape[widthIdx];
  const hasOuterAxes = inputShape.length > 3;
  const channelExpr = hasOuterAxes ? `u32(originalIndices[${channelIdx}])` : '0';
  const batchExpr = hasOuterAxes ? `u32(originalIndices[${batchIdx}])` : '0';
  const clampTo = (name: string, len: number): string => `max(0, min(${name}, ${len} - 1))`;
  const extrapolationCheck = useExtrapolation
    ? `if (depth < 0 || depth > (${depthLen} - 1) || height < 0 || height > (${heightLen} - 1) || width < 0 || (width > ${widthLen} - 1)) {
        return ${extrapolationValue};
      }`
    : '';

  return `
    fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} {
      var input_indices: ${input.type.indices};
      ${input.indicesSet('input_indices', depthIdx, clampTo('depth', depthLen))};
      ${input.indicesSet('input_indices', heightIdx, clampTo('height', heightLen))};
      ${input.indicesSet('input_indices', widthIdx, clampTo('width', widthLen))};
      ${setChannelAndBatchIndices(input, channelIdx, batchIdx, 3)}
      return ${input.getByIndices('input_indices')};
    }

    fn trilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
      var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
      var depth:${dType} = originalIndices[${depthIdx}];
      var height:${dType} = originalIndices[${heightIdx}];
      var width:${dType} = originalIndices[${widthIdx}];
      ${extrapolationCheck};

      depth = ${clampTo('depth', depthLen)};
      height = ${clampTo('height', heightLen)};
      width = ${clampTo('width', widthLen)};
      var depth1: u32 = u32(depth);
      var height1: u32 = u32(height);
      var width1: u32 = u32(width);
      var depth2: u32 = u32(depth + 1);
      var height2: u32 = u32(height + 1);
      var width2: u32 = u32(width + 1);
      var channel: u32 = ${channelExpr};
      var batch: u32 = ${batchExpr};

      var x111: ${dType} = getInputValue(batch, channel, depth1, height1, width1);
      var x112: ${dType} = getInputValue(batch, channel, depth1, height1, width2);
      var x121: ${dType} = getInputValue(batch, channel, depth1, height2, width1);
      var x122: ${dType} = getInputValue(batch, channel, depth1, height2, width2);
      var x211: ${dType} = getInputValue(batch, channel, depth2, height1, width1);
      var x212: ${dType} = getInputValue(batch, channel, depth2, height1, width2);
      var x221: ${dType} = getInputValue(batch, channel, depth2, height2, width1);
      var x222: ${dType} = getInputValue(batch, channel, depth2, height2, width2);
      var dx1: ${dType} = abs(depth - ${dType}(depth1));
      var dx2: ${dType} = abs(${dType}(depth2) - depth);
      var dy1: ${dType} = abs(height - ${dType}(height1));
      var dy2: ${dType} = abs(${dType}(height2) - height);
      var dz1: ${dType} = abs(width - ${dType}(width1));
      var dz2: ${dType} = abs(${dType}(width2) - width);
      if (depth1 == depth2) {
        dx1 = 0.5;
        dx2 = 0.5;
      }
      if (height1 == height2) {
        dy1 = 0.5;
        dy2 = 0.5;
      }
      if (width1 == width2) {
        dz1 = 0.5;
        dz2 = 0.5;
      }
      return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 +
              x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1);
    }`;
};
/**
 * Builds the WebGPU program for Resize.
 *
 * @param inputTensor data input (X).
 * @param attributes parsed Resize attributes.
 * @param opsetVersion node opset (affects the nearest rounding rule).
 * @param scalesInput scales as read by validateInputs (full rank, or empty when sizes is used).
 * @param sizes requested output sizes (one per axis in `attributes.axes` when set, else full rank), or empty.
 * @param roiInput RoI values as read by validateInputs (possibly empty).
 * @throws Error for unsupported mode / input-rank combinations.
 */
const createResizeProgramInfo = (
  inputTensor: TensorView,
  attributes: ResizeAttributes,
  opsetVersion: number,
  scalesInput: readonly number[],
  sizes: readonly number[],
  roiInput: readonly number[],
): ProgramInfo => {
  const inputShape = inputTensor.dims;
  const roi = updateRoI(roiInput, attributes.axes, inputShape.length);

  let outputShape = initOutputShape(inputShape, scalesInput, sizes, attributes.axes);
  let scales = scalesInput.slice();
  if (scalesInput.length === 0) {
    // Derive scales from the requested sizes; a zero-length input dimension keeps scale 1.
    scales = inputShape.map((value, index) => (value === 0 ? 1.0 : outputShape[index] / value));
    if (attributes.keepAspectRatioPolicy !== 'stretch') {
      outputShape = adjustOutputShape(inputShape, scales, attributes);
    }
  }
  const output = outputVariable('output', inputTensor.dataType, outputShape.length);
  const input = inputVariable('input', inputTensor.dataType, inputShape.length);
  const outputSize = ShapeUtil.size(outputShape);
  // Identical input and output shapes: the shader degenerates to a plain copy.
  const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
  const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
  const extrapolationValue = attributes.extrapolationValue;
  const dataType = input.type.value;

  // WGSL helper functions required by the selected interpolation mode.
  const getModeHelpers = (): string => {
    switch (attributes.mode) {
      case 'nearest':
        return `
              ${checkInputIndices(input, inputShape)};
              ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)};
              ${calculateInputIndicesFromOutputIndices(
                input,
                output,
                inputShape,
                outputShape,
                scales.length,
                roi.length,
                useExtrapolation,
              )};
              `;
      case 'linear': {
        const originalIndices = calculateOriginalIndicesFromOutputIndices(
          output,
          inputShape,
          outputShape,
          scales.length,
          roi.length,
        );
        let interpolation: string;
        if (inputShape.length === 2 || inputShape.length === 4) {
          interpolation = bilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue);
        } else if (inputShape.length === 3 || inputShape.length === 5) {
          interpolation = trilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue);
        } else {
          throw Error('Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.');
        }
        return `
              ${originalIndices};
              ${interpolation};
            `;
      }
      case 'cubic':
        if (inputShape.length === 2 || inputShape.length === 4) {
          return `
            ${bicubicInterpolation(
              input,
              output,
              inputShape,
              outputShape,
              scales,
              roi,
              attributes.cubicCoeffA,
              useExtrapolation,
              attributes.extrapolationValue,
              attributes.excludeOutside,
            )};
            `;
        }
        // The previous message ended with "in linear mode", which was wrong for cubic.
        throw Error('Cubic mode only supports input dims 2 and 4.');
      default:
        throw Error('Invalid resize mode');
    }
  };

  // Per-invocation statement that writes output[global_idx] for the selected mode.
  const getMainStatement = (): string => {
    switch (attributes.mode) {
      case 'nearest':
        return `input_indices = calculateInputIndicesFromOutputIndices(output_indices);
                if (checkInputIndices(input_indices)) {
                  output[global_idx] = ${input.getByIndices('input_indices')};
                } else {
                  output[global_idx] = ${attributes.extrapolationValue};
                }`;
      case 'linear':
        return `output[global_idx] = ${
          inputShape.length === 2 || inputShape.length === 4 ? 'bilinearInterpolation' : 'trilinearInterpolation'
        }(output_indices);`;
      case 'cubic':
        return 'output[global_idx] = bicubicInterpolation(output_indices);';
      default:
        throw Error(`Unsupported resize mode: ${attributes.mode}`);
    }
  };

  const getShaderSource = (shaderHelper: ShaderHelper) => `
      ${
        noScale
          ? ''
          : `
      ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)};
      ${getModeHelpers()};
      `
      }
      ${shaderHelper
        .registerUniform('output_size', 'u32')
        .registerUniform('scales', 'f32', scales.length)
        .registerUniform('roi', 'f32', roi.length)
        .declareVariables(input, output)}
      ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
        ${
          noScale
            ? 'output[global_idx] = input[global_idx];'
            : `
        let output_indices = ${output.offsetToIndices('global_idx')};
        var input_indices: ${input.type.indices};
        ${getMainStatement()};
`
        }
      }`;

  return {
    name: 'Resize',
    shaderCache: {
      hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
        sizes.length > 0 ? sizes : ''
      }|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`,
      inputDependencies: ['rank'],
    },
    getShaderSource,
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms: [
        { type: DataType.uint32, data: outputSize },
        { type: DataType.float, data: scales },
        { type: DataType.float, data: roi },
        ...createTensorShapeVariables(inputShape, outputShape),
      ],
    }),
  };
};
/**
 * Reads the opset version stored as the first 32-bit word of the kernel's custom data.
 *
 * @param context compute context; `customDataBuffer` is a typed-array view over the custom data bytes.
 * @returns the opset version.
 */
const getOpsetVersionFromCustomDataBuffer = (context: ComputeContext): number => {
  const customDataBuffer = context.customDataBuffer;
  // `new Uint32Array(view, offset, 1)` treats a typed-array view as an array-like: it copies the bytes
  // element-wise and ignores offset/length, so only the low byte was ever returned. A DataView reads the full
  // word and also tolerates an unaligned byteOffset. Little-endian is assumed, matching WebAssembly memory.
  const view = new DataView(customDataBuffer.buffer, customDataBuffer.byteOffset, customDataBuffer.byteLength);
  return view.getUint32(0, true);
};
/**
 * Kernel entry point for ONNX Resize. It reads and validates the optional inputs, builds the program and
 * dispatches it with only input 0 uploaded to the GPU.
 *
 * @param context compute context of the kernel invocation.
 * @param attributes parsed Resize attributes.
 * @throws Error if `antialias` is non-zero or the inputs are invalid.
 */
export const resize = (context: ComputeContext, attributes: ResizeAttributes): void => {
  // Note that scales in resize are always f32. roi can be f32 or f16.
  // TODO: Currently this code does not support f16 for roi when passed as optional input.
  const opsetVersion = getOpsetVersionFromCustomDataBuffer(context);
  if (attributes.antialias !== 0) {
    throw Error('Only default value (0) for Antialias attribute is supported');
  }

  const scales: number[] = [];
  const sizes: number[] = [];
  const roi: number[] = [];
  validateInputs(context.inputs, attributes, opsetVersion, scales, sizes, roi);

  const programInfo = createResizeProgramInfo(context.inputs[0], attributes, opsetVersion, scales, sizes, roi);
  context.compute(programInfo, { inputs: [0] });
};
export const parseResizeAttributes = (attributes: Record<string, unknown>): ResizeAttributes => {
796
const antialias = attributes.antialias as number;
797
const axes = attributes.axes as number[];
798
const coordinateTransformMode: CoordinateTransformMode =
799
attributes.coordinateTransformMode as CoordinateTransformMode;
800
const cubicCoeffA = attributes.cubicCoeffA as number;
801
const excludeOutside = (attributes.excludeOutside as number) !== 0;
802
const extrapolationValue = attributes.extrapolationValue as number;
803
const keepAspectRatioPolicy: KeepAspectRatioPolicy = attributes.keepAspectRatioPolicy as KeepAspectRatioPolicy;
804
const mode: Mode = attributes.mode as Mode;
805
// If nearestMode is not specified, use simple mode.
806
const nearestMode: NearestMode = (attributes.nearestMode === '' ? 'simple' : attributes.nearestMode) as NearestMode;
807
return createAttributeWithCacheKey({
810
coordinateTransformMode,
814
keepAspectRatioPolicy,