onnxruntime

Форк
0
818 строк · 33.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo } from '../types';
9

10
import {
11
  createTensorShapeVariables,
12
  getElementAt,
13
  IndicesHelper,
14
  inputVariable,
15
  outputVariable,
16
  ShaderHelper,
17
} from './common';
18

19
type CoordinateTransformMode =
20
  | 'half_pixel'
21
  | 'asymmetric'
22
  | 'pytorch_half_pixel'
23
  | 'tf_half_pixel_for_nn'
24
  | 'align_corners'
25
  | 'tf_crop_and_resize'
26
  | 'half_pixel_symmetric';
27

28
type KeepAspectRatioPolicy = 'stretch' | 'not_smaller' | 'not_larger';
29

30
type Mode = 'nearest' | 'linear' | 'cubic';
31

32
type NearestMode = 'round_prefer_floor' | 'round_prefer_ceil' | 'floor' | 'ceil' | 'simple';
33

34
export interface ResizeAttributes extends AttributeWithCacheKey {
35
  antialias: number;
36
  axes: number[];
37
  coordinateTransformMode: CoordinateTransformMode;
38
  cubicCoeffA: number;
39
  excludeOutside: boolean;
40
  extrapolationValue: number;
41
  keepAspectRatioPolicy: KeepAspectRatioPolicy;
42
  mode: Mode;
43
  nearestMode: NearestMode;
44
}
45

46
const validateScales = (scales: number[], attributes: ResizeAttributes): void => {
47
  scales.every(
48
    (value) =>
49
      value > 0 ||
50
      (() => {
51
        throw new Error('Resize requires scales input values to be positive');
52
      }),
53
  );
54
  // Check scales dims based on mode: LINEAR, CUBIC
55
  if (scales.length > 0) {
56
    if (attributes.mode === 'linear') {
57
      if (
58
        !(
59
          scales.length === 2 ||
60
          scales.length === 3 ||
61
          (scales.length === 4 && scales[0] === 1 && scales[1] === 1) ||
62
          (scales.length === 4 && scales[0] === 1 && scales[3] === 1) ||
63
          (scales.length === 5 && scales[0] === 1 && scales[1] === 1)
64
        )
65
      ) {
66
        throw new Error(
67
          `For linear mode, Resize requires scales to be 2D, 3D, 4D with either two outermost or one innermost and
68
            one outermost scale values equal to 1, or 5D with two outermost scale values equal to 1`,
69
        );
70
      }
71
    } else if (attributes.mode === 'cubic') {
72
      if (
73
        !(
74
          scales.length === 2 ||
75
          (scales.length === 4 && scales[0] === 1 && scales[1] === 1) ||
76
          (scales.length === 4 && scales[0] === 1 && scales[3] === 1)
77
        )
78
      ) {
79
        throw new Error('Resize requires scales input size to be 2 or 4 for cubic mode');
80
      }
81
    }
82
  }
83
};
84

85
const updateScales = (scales: readonly number[], axes: readonly number[], rank: number): number[] => {
86
  axes.every(
87
    (value) =>
88
      (value >= 0 && value < rank) ||
89
      (() => {
90
        throw new Error('Resize requires axes input values to be positive and less than rank');
91
      }),
92
  );
93
  const newScales = new Array(rank).fill(1.0);
94
  axes.forEach((value, index) => (newScales[value] = scales[index]));
95
  return newScales;
96
};
97

98
const validateInputs = (
99
  inputs: readonly TensorView[],
100
  attributes: ResizeAttributes,
101
  opsetVersion: number,
102
  scales: number[],
103
  sizes: number[],
104
  roi: number[],
105
): void => {
106
  const [roiInputIndex, scalesInputIndex, sizesInputIndex] =
107
    opsetVersion > 10 ? [1, 2, 3] : [-1, inputs.length > 1 ? 1 : -1, -1];
108
  const rank = inputs[0].dims.length;
109
  if (roiInputIndex > 0 && inputs.length > roiInputIndex && inputs[roiInputIndex].dims.length > 0) {
110
    inputs[roiInputIndex].getFloat32Array().forEach((value) => roi.push(value));
111
  } else if (attributes.coordinateTransformMode === 'tf_crop_and_resize') {
112
    throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize');
113
  }
114

115
  if (scalesInputIndex > 0 && inputs.length > scalesInputIndex && inputs[scalesInputIndex].dims.length > 0) {
116
    inputs[scalesInputIndex].getFloat32Array().forEach((value) => scales.push(value));
117
    if (
118
      scales.length !== 0 &&
119
      scales.length !== rank &&
120
      opsetVersion >= 18 &&
121
      scales.length !== attributes.axes.length
122
    ) {
123
      throw new Error('Resize requires scales input size to be same as input rank or axes size for opset 18 and up');
124
    }
125
    validateScales(scales, attributes);
126
    if (attributes.axes.length > 0) {
127
      updateScales(scales, attributes.axes, rank).forEach((value, index) => (scales[index] = value));
128
    }
129
  }
130
  if (sizesInputIndex > 0 && inputs.length > sizesInputIndex) {
131
    inputs[sizesInputIndex].getBigInt64Array().forEach((value) => sizes.push(Number(value)));
132
    if (sizes.length !== rank || (opsetVersion >= 18 && sizes.length === attributes.axes.length)) {
133
      throw new Error('Resize requires sizes input size to be same as input rank or axes size for opset 18 and up');
134
    }
135
  }
136

137
  if (attributes.axes.length > 0) {
138
    if (scales.length !== attributes.axes.length) {
139
      throw new Error('Resize requires "scales" input size to be of axes rank when axes attributes is specified');
140
    }
141
    if (sizes.length !== attributes.axes.length) {
142
      throw new Error('Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified');
143
    }
144
  }
145
  if (typeof scales !== 'undefined' && typeof sizes !== 'undefined' && scales.length > 0 && sizes.length > rank) {
146
    throw new Error('Resize requires only of scales or sizes to be specified');
147
  }
148
};
149

150
const getOriginalCoordinateFromResizedCoordinate = (
151
  coordinateTransferMode: CoordinateTransformMode,
152
  dType: string,
153
): string =>
154
  `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: f32, lengthResized: u32,
155
     lengthOriginal: u32, roiStart: f32, roiEnd: f32) -> ${dType} { ` +
156
  (() => {
157
    switch (coordinateTransferMode) {
158
      case 'asymmetric':
159
        return `return ${dType}(xResized) / ${dType}(xScale);`;
160
      case 'pytorch_half_pixel':
161
        return `if (lengthResized > 1) {
162
                    return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5;
163
                  } else {
164
                    return 0.0;
165
                  }`;
166
      case 'tf_half_pixel_for_nn':
167
        return `return (${dType}(xResized) + 0.5) / ${dType}(xScale);`;
168
      case 'align_corners':
169
        return `if (lengthResized == 1) {
170
                    return 0.0;
171
                  } else {
172
                    // The whole part and the fractional part are calculated separately due to inaccuracy of floating
173
                    // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
174
                    // offset-by-one error later in floor().
175
                    let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1));
176
                    let fract =
177
                        ${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1);
178
                    return whole + fract;
179
                  }`;
180
      case 'tf_crop_and_resize':
181
        return `if (lengthResized > 1) {
182
                    return ${dType}(roiStart) * ${dType}(lengthOriginal - 1) +
183
                        (${dType}(xResized) * ${dType}(roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) /
184
                        ${dType}(lengthResized - 1);
185
                  } else {
186
                    return 0.5 * ${dType}(roiStart + roiEnd) * ${dType}(lengthOriginal - 1);
187
                  }`;
188
      case 'half_pixel_symmetric':
189
        return `const outputWidth = ${dType}xScale * ${dType}(lengthResized);
190
                  const adjustment = ${dType}(lengthResized) / outputWidth;
191
                  const center = ${dType}(lengthOriginal) / 2;
192
                  const offset = center * (1 - adjustment);
193
                  return offset + ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`;
194
      case 'half_pixel':
195
        return `return ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`;
196
      default:
197
        throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
198
    }
199
  })() +
200
  '}';
201

202
const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string =>
203
  `fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` +
204
  (() => {
205
    switch (nearestMode) {
206
      case 'round_prefer_ceil':
207
        return 'if (fract(xOriginal) == 0.5) { \
208
            return ceil(xOriginal); \
209
          } else { \
210
            return round(xOriginal); \
211
          }';
212
      case 'floor':
213
        return 'return floor(xOriginal);';
214
      case 'ceil':
215
        return 'return ceil(xOriginal);';
216
      case 'round_prefer_floor':
217
        return 'if (fract(xOriginal) == 0.5) { \
218
                    return floor(xOriginal); \
219
                  } else { \
220
                    return round(xOriginal); \
221
                  }';
222
      case 'simple':
223
      default:
224
        if (opsetVersion < 11) {
225
          return 'if (isDownSample) \
226
                    { \
227
                      return ceil(xOriginal); \
228
                    } else { \
229
                      return xOriginal; \
230
                    }';
231
        }
232
        throw new Error(`Nearest mode ${nearestMode} is not supported`);
233
    }
234
  })() +
235
  '}';
236

237
const updateRoI = (roi: readonly number[], axes: readonly number[], rank: number): number[] => {
238
  const roiTmp = new Array(rank).fill(0).concat(new Array(rank).fill(1));
239
  const roiLocal = roi.length === 0 ? roiTmp : roi.slice();
240
  if (axes.length > 0) {
241
    axes.forEach((v, i) => {
242
      roiTmp[v] = roiLocal[i];
243
      roiTmp[i + rank] = roiLocal[axes.length + i];
244
    });
245
    return roiTmp;
246
  }
247
  return roiLocal;
248
};
249

250
const initOutputShape = (
251
  inputShape: readonly number[],
252
  scales: readonly number[],
253
  sizes: readonly number[],
254
  axes: readonly number[],
255
): number[] => {
256
  let outputShape: number[] = [];
257
  if (sizes.length > 0) {
258
    if (axes.length > 0) {
259
      inputShape.forEach((v) => outputShape.push(v));
260
      if (Math.max(...axes) > inputShape.length) {
261
        throw new Error('axes is out of bound');
262
      }
263
      axes.forEach((v, i) => (outputShape[v] = sizes[i]));
264
    } else {
265
      sizes.forEach((v) => outputShape.push(v));
266
    }
267
  } else {
268
    if (scales.length === 0) {
269
      throw new Error('Resize requires either scales or sizes.');
270
    } else {
271
      outputShape = inputShape.map((value, index) => Math.round(value * scales[index]));
272
    }
273
  }
274
  return outputShape;
275
};
276

277
const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes) => {
278
  const scaleInPolicy = (() => {
279
    switch (attributes.keepAspectRatioPolicy) {
280
      case 'not_larger':
281
        return attributes.axes.length > 0
282
          ? Math.min(...attributes.axes.map((i) => scales[i]), Number.MAX_VALUE)
283
          : Math.min(...scales, Number.MAX_VALUE);
284
      case 'not_smaller':
285
        return attributes.axes.length > 0
286
          ? Math.max(...attributes.axes.map((i) => scales[i]), Number.MIN_VALUE)
287
          : Math.max(...scales, Number.MIN_VALUE);
288
      default:
289
        throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`);
290
    }
291
  })();
292
  scales.fill(1.0, 0, scales.length);
293
  const adjustedOutputShape = inputShape.slice();
294
  if (attributes.axes.length > 0) {
295
    attributes.axes.forEach((v) => (scales[v] = scaleInPolicy));
296
    attributes.axes.forEach((v) => (adjustedOutputShape[v] = Math.round(inputShape[v] * scales[v])));
297
  } else {
298
    scales.fill(scaleInPolicy, 0, scales.length);
299
    adjustedOutputShape.forEach((v, i) => (adjustedOutputShape[i] = Math.round(v * scales[i])));
300
  }
301
  return adjustedOutputShape;
302
};
303

304
const calculateOriginalIndicesFromOutputIndices = (
305
  output: IndicesHelper,
306
  inputShape: readonly number[],
307
  outputShape: readonly number[],
308
  scalesLength: number,
309
  roiLength: number,
310
): string => `
311
    fn calculateOriginalIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> array<${
312
      output.type.value
313
    }, ${outputShape.length}> {
314
      var original_indices: array<${output.type.value}, ${outputShape.length}>;
315
      for (var i:u32 = 0; i < ${outputShape.length}; i++) {
316
        var output_index = ${output.indicesGet('output_indices', 'i')};
317
        var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
318
        var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
319
        var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
320
        if (scale == 1.0) {
321
          original_indices[i] = ${output.type.value}(output_index);
322
        } else {
323
          var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
324
          var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)};
325
          original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
326
                                                                           input_shape_i, roi_low, roi_hi);
327
        }
328
      }
329
      return original_indices;
330
    }`;
331

332
const calculateInputIndicesFromOutputIndices = (
333
  input: IndicesHelper,
334
  output: IndicesHelper,
335
  inputShape: readonly number[],
336
  outputShape: readonly number[],
337
  scalesLength: number,
338
  roiLength: number,
339
  useExtrapolation: boolean,
340
): string => `
341
    fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} {
342
      var input_indices: ${input.type.indices};
343
      for (var i:u32 = 0; i < ${outputShape.length}; i++) {
344
        var output_index = ${output.indicesGet('output_indices', 'i')};
345
        var input_index: u32;
346
        var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
347
        if (scale == 1.0) {
348
          input_index = output_index;
349
        } else {
350
          var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
351
          var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
352
          var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
353
          var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)};
354
          var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
355
                                                                        input_shape_i, roi_low, roi_hi);
356
          if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${output.type.value}(input_shape_i))) {
357
            if (original_idx < 0) {
358
              input_index = 0;
359
            } else if (original_idx > ${output.type.value}(input_shape_i - 1)) {
360
              input_index = input_shape_i - 1;
361
            } else {
362
              input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1));
363
            }
364
          } else {
365
            input_index = u32(original_idx);
366
          }
367
        }
368
        ${input.indicesSet('input_indices', 'i', ' input_index')}
369
      }
370
      return input_indices;
371
    }`;
372
const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): string => `
373
    fn checkInputIndices(input_indices: ${input.type.indices}) -> bool {
374
      for (var i:u32 = 0; i < ${inputShape.length}; i++) {
375
        var input_index = ${input.indicesGet('input_indices', 'i')};
376
        if (input_index < 0 || input_index >= ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}) {
377
          return false;
378
        }
379
      }
380
      return true;
381
    }`;
382

383
const setChannelAndBatchIndices = (
384
  input: IndicesHelper,
385
  channelIdx: number,
386
  batchIdx: number,
387
  spacialDims: number,
388
): string =>
389
  input.rank > spacialDims
390
    ? `
391
    ${input.indicesSet('input_indices', channelIdx, 'channel')};
392
    ${input.indicesSet('input_indices', batchIdx, 'batch')};
393
`
394
    : '';
395

396
const bilinearInterpolation = (
397
  input: IndicesHelper,
398
  output: IndicesHelper,
399
  inputShape: readonly number[],
400
  useExtrapolation: boolean,
401
  extrapolationValue: number,
402
): string => {
403
  const isNchw = true;
404
  const [batchIdx, heightIdx, widthIdx, channelIdx] =
405
    inputShape.length === 2 ? [-1, 0, 1, -1] : isNchw ? [0, 2, 3, 1] : [0, 1, 2, 3];
406
  const dType = input.type.value;
407
  return `
408
    fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
409
      var input_indices: ${input.type.indices};
410
      ${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)};
411
      ${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)};
412
      ${setChannelAndBatchIndices(input, channelIdx, batchIdx, 2)}
413
      return ${input.getByIndices('input_indices')};
414
    }
415

416
    fn bilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
417
      var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
418
      var row:${dType} = originalIndices[${heightIdx}];
419
      var col:${dType} = originalIndices[${widthIdx}];
420
      ${
421
        useExtrapolation
422
          ? `if (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > (${inputShape[widthIdx]} - 1)) {
423
        return ${extrapolationValue};
424
      }`
425
          : ''
426
      };
427
      row = max(0, min(row, ${inputShape[heightIdx]} - 1));
428
      col = max(0, min(col, ${inputShape[widthIdx]} - 1));
429
      var row1: u32 = u32(row);
430
      var col1: u32 = u32(col);
431
      var row2: u32 = u32(row + 1);
432
      var col2: u32 = u32(col + 1);
433
      var channel: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${channelIdx}])` : '0'};
434
      var batch: u32 =  ${inputShape.length > 2 ? `u32(originalIndices[${batchIdx}])` : '0'};
435
      var x11: ${dType} = getInputValue(batch, channel, row1, col1);
436
      var x12: ${dType} = getInputValue(batch, channel, row1, col2);
437
      var x21: ${dType} = getInputValue(batch, channel, row2, col1);
438
      var x22: ${dType} = getInputValue(batch, channel, row2, col2);
439
      var dx1: ${dType} = abs(row - ${dType}(row1));
440
      var dx2: ${dType} = abs(${dType}(row2) - row);
441
      var dy1: ${dType} = abs(col - ${dType}(col1));
442
      var dy2: ${dType} = abs(${dType}(col2) - col);
443
      if (row1 == row2) {
444
        dx1 = 0.5;
445
        dx2 = 0.5;
446
      }
447
      if (col1 == col2) {
448
        dy1 = 0.5;
449
        dy2 = 0.5;
450
      }
451
      return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
452
    }`;
453
};
454

455
const bicubicInterpolation = (
456
  input: IndicesHelper,
457
  output: IndicesHelper,
458
  inputShape: readonly number[],
459
  outputShape: readonly number[],
460
  scales: readonly number[],
461
  roi: readonly number[],
462
  cubicCoeffA: number,
463
  useExtrapolation: boolean,
464
  extrapolationValue: number,
465
  excludeOutside: boolean,
466
): string => {
467
  const is2D = inputShape.length === 2;
468
  const isNchw = true;
469
  const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2];
470
  const dType = input.type.value;
471
  const createCubicInterpolationFunction = (idx: number): string => {
472
    const direction = idx === heightIdx ? 'row' : 'col';
473
    return `
474
      fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${
475
        output.type.indices
476
      }) -> ${dType} {
477
        var output_index = ${output.indicesGet('output_indices', idx)};
478
        var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(output_index, ${scales[idx]},
479
        ${outputShape[idx]}, ${inputShape[idx]}, ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
480
        var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx);
481
        var coefs = getCubicInterpolationCoefs(fractOriginalIdx);
482

483
        if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) {
484
          return ${extrapolationValue};
485
        }
486
        var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
487
        for (var i: i32 = -1; i < 3; i++) {
488
          var ${direction}: ${dType} = originalIdx + ${dType}(i);
489
          if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
490
            ${(() => {
491
              if (excludeOutside) {
492
                return `coefs[i + 1] = 0.0;
493
                        continue;`;
494
              } else if (useExtrapolation) {
495
                return `return ${extrapolationValue};`;
496
              } else {
497
                return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`;
498
              }
499
            })()};
500
          }
501
        var input_indices_copy: ${input.type.indices} = input_indices;
502
          ${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)};
503
          data[i + 1] = ${
504
            idx === heightIdx
505
              ? input.getByIndices('input_indices_copy')
506
              : 'rowCubicInterpolation(input_indices_copy, output_indices)'
507
          };
508
        }
509
        return cubicInterpolation1D(data, coefs);
510
      }`;
511
  };
512

513
  return `
514
    ${createCubicInterpolationFunction(heightIdx)};
515
    ${createCubicInterpolationFunction(widthIdx)};
516
  fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> {
517
    var absS = abs(s);
518
    var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
519
    var oneMinusAbsS: ${dType} = 1.0 - absS;
520
    var twoMinusAbsS: ${dType} = 2.0 - absS;
521
    var onePlusAbsS: ${dType} = 1.0 + absS;
522
    coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${
523
      cubicCoeffA
524
    }) * onePlusAbsS - 4 * ${cubicCoeffA};
525
    coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1;
526
    coeffs[2] = ((${cubicCoeffA} + 2) * oneMinusAbsS - (${cubicCoeffA} + 3)) * oneMinusAbsS * oneMinusAbsS + 1;
527
    coeffs[3] = ((${cubicCoeffA} * twoMinusAbsS - 5 * ${cubicCoeffA}) * twoMinusAbsS + 8 * ${
528
      cubicCoeffA
529
    }) * twoMinusAbsS - 4 * ${cubicCoeffA};
530
    return coeffs;
531
  }
532

533
  fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} {
534
    var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3];
535
    return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum;
536
  }
537

538
  fn bicubicInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
539
    var input_indices: ${input.type.indices} = output_indices;
540
    return colCubicInterpolation(input_indices, output_indices);
541
  }
542
    `;
543
};
544

545
const trilinearInterpolation = (
546
  input: IndicesHelper,
547
  output: IndicesHelper,
548
  inputShape: readonly number[],
549
  useExtrapolation: boolean,
550
  extrapolationValue: number,
551
): string => {
552
  const isNchw = true;
553
  const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] =
554
    inputShape.length === 3 ? [-1, 0, 1, 2, -1] : isNchw ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4];
555
  const dType = input.type.value;
556
  return `
557
    fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} {
558
      var input_indices: ${input.type.indices};
559
      ${input.indicesSet('input_indices', depthIdx, `max(0, min(depth, ${inputShape[depthIdx]} - 1))`)};
560
      ${input.indicesSet('input_indices', heightIdx, `max(0, min(height, ${inputShape[heightIdx]} - 1))`)};
561
      ${input.indicesSet('input_indices', widthIdx, `max(0, min(width, ${inputShape[widthIdx]} - 1))`)};
562
      ${setChannelAndBatchIndices(input, channelIdx, batchIdx, 3)}
563
      return ${input.getByIndices('input_indices')};
564
    }
565

566
    fn trilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
567
      var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
568
      var depth:${dType} = originalIndices[${depthIdx}];
569
      var height:${dType} = originalIndices[${heightIdx}];
570
      var width:${dType} = originalIndices[${widthIdx}];
571
      ${
572
        useExtrapolation
573
          ? `if (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
574
              inputShape[heightIdx]
575
            } - 1) || width < 0 || (width > ${inputShape[widthIdx]} - 1)) {
576
      return ${extrapolationValue};
577
        }`
578
          : ''
579
      };
580

581
    depth = max(0, min(depth, ${inputShape[depthIdx]} - 1));
582
      height = max(0, min(height, ${inputShape[heightIdx]} - 1));
583
      width = max(0, min(width, ${inputShape[widthIdx]} - 1));
584
      var depth1: u32 = u32(depth);
585
      var height1: u32 = u32(height);
586
      var width1: u32 = u32(width);
587
      var depth2: u32 = u32(depth + 1);
588
      var height2: u32 = u32(height + 1);
589
      var width2: u32 = u32(width + 1);
590
      var channel: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${channelIdx}])` : '0'};
591
      var batch: u32 =  ${inputShape.length > 3 ? `u32(originalIndices[${batchIdx}])` : '0'};
592

593
      var x111: ${dType} = getInputValue(batch, channel, depth1, height1, width1);
594
      var x112: ${dType} = getInputValue(batch, channel, depth1, height1, width2);
595
      var x121: ${dType} = getInputValue(batch, channel, depth1, height2, width1);
596
      var x122: ${dType} = getInputValue(batch, channel, depth1, height2, width2);
597
      var x211: ${dType} = getInputValue(batch, channel, depth2, height1, width1);
598
      var x212: ${dType} = getInputValue(batch, channel, depth2, height1, width2);
599
      var x221: ${dType} = getInputValue(batch, channel, depth2, height2, width1);
600
      var x222: ${dType} = getInputValue(batch, channel, depth2, height2, width2);
601
      var dx1: ${dType} = abs(depth - ${dType}(depth1));
602
      var dx2: ${dType} = abs(${dType}(depth2) - depth);
603
      var dy1: ${dType} = abs(height - ${dType}(height1));
604
      var dy2: ${dType} = abs(${dType}(height2) - height);
605
      var dz1: ${dType} = abs(width - ${dType}(width1));
606
      var dz2: ${dType} = abs(${dType}(width2) - width);
607
      if (depth1 == depth2) {
608
        dx1 = 0.5;
609
        dx2 = 0.5;
610
      }
611
      if (height1 == height2) {
612
        dy1 = 0.5;
613
        dy2 = 0.5;
614
      }
615
      if (width1 == width2) {
616
        dz1 = 0.5;
617
        dz2 = 0.5;
618
      }
619
      return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 +
620
              x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1);
621
    }`;
622
};
623

624
const createResizeProgramInfo = (
625
  inputTensor: TensorView,
626
  attributes: ResizeAttributes,
627
  opsetVersion: number,
628
  scalesInput: readonly number[],
629
  sizes: readonly number[],
630
  roiInput: readonly number[],
631
): ProgramInfo => {
632
  const inputShape = inputTensor.dims;
633
  const roi = updateRoI(roiInput, attributes.axes, inputShape.length);
634

635
  let outputShape = initOutputShape(inputShape, scalesInput, sizes, attributes.axes);
636
  let scales = scalesInput.slice();
637
  if (scalesInput.length === 0) {
638
    scales = inputShape.map((value, index) => (value === 0 ? 1.0 : outputShape[index] / value));
639
    if (attributes.keepAspectRatioPolicy !== 'stretch') {
640
      outputShape = adjustOutputShape(inputShape, scales, attributes);
641
    }
642
  }
643
  const output = outputVariable('output', inputTensor.dataType, outputShape.length);
644
  const input = inputVariable('input', inputTensor.dataType, inputShape.length);
645
  const outputSize = ShapeUtil.size(outputShape);
646
  const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
647
  const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
648
  const extrapolationValue = attributes.extrapolationValue;
649
  const dataType = input.type.value;
650
  const getShaderSource = (shaderHelper: ShaderHelper) => `
651
      ${
652
        noScale
653
          ? ''
654
          : `
655
      ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)};
656
      ${(() => {
657
        switch (attributes.mode) {
658
          case 'nearest':
659
            return `
660
              ${checkInputIndices(input, inputShape)};
661
              ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)};
662
              ${calculateInputIndicesFromOutputIndices(
663
                input,
664
                output,
665
                inputShape,
666
                outputShape,
667
                scales.length,
668
                roi.length,
669
                useExtrapolation,
670
              )};
671
              `;
672
          case 'linear':
673
            return `
674
              ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)};
675
              ${(() => {
676
                if (inputShape.length === 2 || inputShape.length === 4) {
677
                  return `${bilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`;
678
                } else if (inputShape.length === 3 || inputShape.length === 5) {
679
                  return `${trilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`;
680
                } else {
681
                  throw Error('Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.');
682
                }
683
              })()};
684
            `;
685
          case 'cubic':
686
            return `
687
            ${(() => {
688
              if (inputShape.length === 2 || inputShape.length === 4) {
689
                return `${bicubicInterpolation(
690
                  input,
691
                  output,
692
                  inputShape,
693
                  outputShape,
694
                  scales,
695
                  roi,
696
                  attributes.cubicCoeffA,
697
                  useExtrapolation,
698
                  attributes.extrapolationValue,
699
                  attributes.excludeOutside,
700
                )}`;
701
              } else {
702
                throw Error('Cubic mode only supports input dims 2 and 4 are supported in linear mode.');
703
              }
704
            })()};
705
            `;
706
          default:
707
            throw Error('Invalid resize mode');
708
        }
709
      })()};
710
      `
711
      }
712
      ${shaderHelper
713
        .registerUniform('output_size', 'u32')
714
        .registerUniform('scales', 'f32', scales.length)
715
        .registerUniform('roi', 'f32', roi.length)
716
        .declareVariables(input, output)}
717
      ${shaderHelper.mainStart()}
718
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
719
        ${
720
          noScale
721
            ? 'output[global_idx] = input[global_idx];'
722
            : `
723
        let output_indices = ${output.offsetToIndices('global_idx')};
724
        var input_indices: ${input.type.indices};
725
        ${(() => {
726
          switch (attributes.mode) {
727
            case 'nearest':
728
              return `input_indices = calculateInputIndicesFromOutputIndices(output_indices);
729
                if (checkInputIndices(input_indices)) {
730
                  output[global_idx] = ${input.getByIndices('input_indices')};
731
                } else {
732
                  output[global_idx] = ${attributes.extrapolationValue};
733
                }`;
734
            case 'linear':
735
              return `output[global_idx] = ${
736
                inputShape.length === 2 || inputShape.length === 4 ? 'bilinearInterpolation' : 'trilinearInterpolation'
737
              }(output_indices);`;
738
            case 'cubic':
739
              return 'output[global_idx] = bicubicInterpolation(output_indices);';
740
            default:
741
              throw Error(`Unsupported resize mode: ${attributes.mode}`);
742
          }
743
        })()};
744
`
745
        }
746
      }`;
747

748
  return {
749
    name: 'Resize',
750
    shaderCache: {
751
      hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
752
        sizes.length > 0 ? sizes : ''
753
      }|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`,
754
      inputDependencies: ['rank'],
755
    },
756
    getShaderSource,
757
    getRunData: () => ({
758
      outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
759
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
760
      programUniforms: [
761
        { type: DataType.uint32, data: outputSize },
762
        { type: DataType.float, data: scales },
763
        { type: DataType.float, data: roi },
764
        ...createTensorShapeVariables(inputShape, outputShape),
765
      ],
766
    }),
767
  };
768
};
769

770
/**
 * Reads the opset version stored as a uint32 in the first four bytes of the kernel's custom data buffer.
 *
 * A `DataView` over the underlying `ArrayBuffer` is used on purpose:
 * - `new Uint32Array(typedArrayView, byteOffset, 1)` treats a typed-array argument as a source of
 *   *elements*. It copies each byte into its own uint32 slot and ignores the offset/length arguments,
 *   so only the low byte of the stored value was ever read (wrong for any value >= 256).
 * - `new Uint32Array(view.buffer, view.byteOffset, 1)` would throw a RangeError whenever the view's
 *   byteOffset is not 4-byte aligned, which a DataView does not require.
 *
 * NOTE(review): assumes the value is written little-endian, as WebAssembly linear memory is; confirm
 * against the native side that fills the custom data buffer.
 *
 * @param context - compute context whose `customDataBuffer` is a byte view over the custom data.
 * @returns the opset version.
 * @throws Error if the custom data buffer holds fewer than four bytes.
 */
const getOpsetVersionFromCustomDataBuffer = (context: ComputeContext): number => {
  const customDataBuffer = context.customDataBuffer;
  if (customDataBuffer.byteLength < 4) {
    throw new Error('Resize requires the opset version (uint32) in the custom data buffer');
  }
  const view = new DataView(customDataBuffer.buffer, customDataBuffer.byteOffset, 4);
  return view.getUint32(0, true);
};
776

777
/**
 * Entry point for the WebGPU Resize kernel.
 *
 * Reads the opset version, rejects unsupported antialiasing, validates the inputs, and dispatches the
 * generated Resize program. Only input 0 is bound to the GPU program; the remaining inputs are read
 * here solely through `validateInputs`.
 *
 * Scales are always f32. roi may be f32 or f16, but an f16 roi supplied as an optional input is not
 * supported yet (TODO).
 *
 * @param context - compute context providing the inputs and the compute dispatcher.
 * @param attributes - parsed Resize attributes.
 * @throws Error when `attributes.antialias` is not the default value 0.
 */
export const resize = (context: ComputeContext, attributes: ResizeAttributes): void => {
  const opsetVersion = getOpsetVersionFromCustomDataBuffer(context);
  if (attributes.antialias !== 0) {
    throw Error('Only default value (0) for Antialias attribute is supported');
  }

  // Presumably filled in place by validateInputs from the optional inputs; they start out empty.
  const scales: number[] = [];
  const sizes: number[] = [];
  const roi: number[] = [];
  validateInputs(context.inputs, attributes, opsetVersion, scales, sizes, roi);

  const programInfo = createResizeProgramInfo(context.inputs[0], attributes, opsetVersion, scales, sizes, roi);
  context.compute(programInfo, { inputs: [0] });
};
794

795
/**
 * Converts the raw attribute record into a typed, cache-keyed `ResizeAttributes` object.
 *
 * Values are cast to their declared types without runtime validation. `excludeOutside` is turned into
 * a boolean (any value other than 0 is true). An empty `nearestMode` string is treated as unspecified
 * and mapped to `'simple'`.
 *
 * @param attributes - raw attribute record keyed by camelCase attribute names.
 * @returns the typed attributes with a cache key attached by `createAttributeWithCacheKey`.
 */
export const parseResizeAttributes = (attributes: Record<string, unknown>): ResizeAttributes =>
  // Property order is kept stable because the attribute object is the basis for the cache key.
  createAttributeWithCacheKey({
    antialias: attributes.antialias as number,
    axes: attributes.axes as number[],
    coordinateTransformMode: attributes.coordinateTransformMode as CoordinateTransformMode,
    cubicCoeffA: attributes.cubicCoeffA as number,
    excludeOutside: (attributes.excludeOutside as number) !== 0,
    extrapolationValue: attributes.extrapolationValue as number,
    keepAspectRatioPolicy: attributes.keepAspectRatioPolicy as KeepAspectRatioPolicy,
    mode: attributes.mode as Mode,
    nearestMode: (attributes.nearestMode === '' ? 'simple' : attributes.nearestMode) as NearestMode,
  });
819

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.