onnxruntime

Форк
0
292 строки · 11.5 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Graph } from '../../../graph';
5
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6
import { Tensor } from '../../../tensor';
7
import { getGlsl } from '../glsl-source';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, TextureType } from '../types';
10
import { getCoordsDataType } from '../utils';
11

12
import { unpackFromChannel } from './packing-utils';
13
import { parseUpsampleAttributes, scalesValidation, UpsampleAttributes, validateInputs } from './upsample';
14

15
/**
 * Metadata shared by every program built in this file: one packed input texture
 * named `A` (the shaders read it through `getA` / the `A` sampler).
 */
const resizeProgramMetadata = { name: 'Resize', inputNames: ['A'], inputTypes: [TextureType.packed] };
20

21
/**
 * Packed WebGL implementation of the Resize operator.
 *
 * Validates the inputs, then asks the inference handler to run the packed resize program,
 * passing `attributes.cacheKey` as the cache hint and a `get` callback that builds the
 * program info.
 *
 * @returns a one-element array holding the resized tensor.
 */
export const resize: OperatorImplementation<UpsampleAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: UpsampleAttributes,
): Tensor[] => {
  validateInputs(inputs, attributes);
  const programInfoLoader = {
    ...resizeProgramMetadata,
    cacheHint: attributes.cacheKey,
    get: () => createPackedResizeProgramInfo(inferenceHandler, inputs, attributes),
  };
  return [inferenceHandler.run(programInfoLoader, inputs)];
};
37

38
/** Parses the attributes of an opset-10 Resize node using the shared Upsample parser. */
export const parseResizeAttributesV10: OperatorInitialization<UpsampleAttributes> = (node) =>
  parseUpsampleAttributes(node, 10);
41

42
/** Parses the attributes of an opset-11 Resize node using the shared Upsample parser. */
export const parseResizeAttributesV11: OperatorInitialization<UpsampleAttributes> = (node) =>
  parseUpsampleAttributes(node, 11);
45

46
/**
 * Builds the GLSL definition of `getSourceFracIndex(ivec4 coords)`, which maps the output
 * coordinates handled by one packed texel to fractional source coordinates for the given
 * coordinate transform mode. The snippet relies on the `scaleWHWH` constant declared by the
 * main shader.
 *
 * @throws Error for coordinate transform modes the packed shader does not implement.
 */
const createSourceFracIndexSnippet = (
  coordinateTransformMode: string,
  inputHeight: number,
  inputWidth: number,
  outputHeight: number,
  outputWidth: number,
): string => {
  switch (coordinateTransformMode) {
    case 'asymmetric':
      return `
                    vec4 getSourceFracIndex(ivec4 coords) {
                        return vec4(coords) / scaleWHWH;
                    }
                `;
    case 'half_pixel':
      return `
                    vec4 getSourceFracIndex(ivec4 coords) {
                        return (vec4(coords) + 0.5) / scaleWHWH - 0.5;
                    }
                `;
    case 'pytorch_half_pixel':
      // NOTE(review): the main shader scales .x/.z by scalesHeight and clamps them against the
      // input height, yet here .x/.z are gated on outputWidth (and .y/.w on outputHeight).
      // Confirm the intended axis order before relying on this mode for non-square outputs.
      return `
                    vec4 getSourceFracIndex(ivec4 coords) {
                        vec4 fcoords = vec4(coords);
                        return vec4(
                            ${outputWidth}.0 > 1.0 ? (fcoords.x + 0.5) / scaleWHWH.x - 0.5 : 0.0,
                            ${outputHeight}.0 > 1.0 ? (fcoords.y + 0.5) / scaleWHWH.y - 0.5 : 0.0,
                            ${outputWidth}.0 > 1.0 ? (fcoords.z + 0.5) / scaleWHWH.z - 0.5 : 0.0,
                            ${outputHeight}.0 > 1.0 ? (fcoords.w + 0.5) / scaleWHWH.w - 0.5 : 0.0
                          );
                    }
                `;
    case 'align_corners':
      // NOTE(review): same axis-order question as pytorch_half_pixel: .x/.z use width here but
      // height elsewhere in the shader. Also, an output extent of 1 makes `resized` 0 (division by zero).
      return `
                    vec4 getSourceFracIndex(ivec4 coords) {
                        vec4 resized = vec4(${outputWidth}.0 - 1.0, ${outputHeight}.0 - 1.0, ${outputWidth}.0 - 1.0,
                            ${outputHeight}.0 - 1.0);
                        vec4 original = vec4(${inputWidth}.0 - 1.0, ${inputHeight}.0 - 1.0, ${inputWidth}.0 - 1.0,
                            ${inputHeight}.0 - 1.0);
                        vec4 new_scale = original / resized;
                        return vec4(coords) * new_scale;
                    }
                `;
    default:
      // TODO: support the remaining coordinate transform modes.
      throw new Error(`resize (packed) does not support coordinateTransformMode: '${coordinateTransformMode}'`);
  }
};

/**
 * Creates the packed-texture program that resizes input `A` (only `linear` mode is supported).
 *
 * When every scale is 1 (and the transform is not `tf_crop_and_resize`) the program is a
 * straight texel copy. Otherwise each packed output texel holds four neighbouring output
 * values (gated by `hasNextRow` / `hasNextCol`). Each value is bilinearly interpolated from four
 * input samples taken at the floor/ceil of its fractional source coordinate.
 *
 * @throws Error if the output rank is below 2 or differs from the input rank, if the mode is
 *   not `linear`, or if the coordinate transform mode is unsupported.
 */
const createPackedResizeProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: UpsampleAttributes,
): ProgramInfo => {
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
  const [scales, outputShape] = prepareInputs(inputs, attributes);

  const isSame = scales.every((s: number) => s === 1) && attributes.coordinateTransformMode !== 'tf_crop_and_resize';
  if (isSame) {
    // Identity resize: copy the packed input texel-for-texel. The sampler must be `A`, the only
    // declared input name; the previous `X` referenced an undeclared sampler.
    return {
      ...resizeProgramMetadata,
      output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed },
      hasMain: true,
      shaderSource: `void main() {
                    vec4 v = ${glsl.texture2D}(A, TexCoords);
                    ${glsl.output} = v;
                }`,
    };
  }

  const dim = outputShape.length;
  if (dim < 2) {
    throw new Error(`output dimension should be at least 2, but got ${dim}`);
  }

  const outputHeight = outputShape[dim - 2];
  const outputWidth = outputShape[dim - 1];

  const inputShape = inputs[0].dims;
  if (dim !== inputShape.length) {
    throw new Error(`output dimension should match input ${inputShape.length}, but got ${dim}`);
  }
  const inputHeight = inputShape[dim - 2];
  const inputWidth = inputShape[dim - 1];

  const scalesHeight = scales[dim - 2];
  const scalesWidth = scales[dim - 1];

  if (attributes.mode !== 'linear') {
    // TODO: support other modes
    throw new Error(`resize (packed) does not support mode: '${attributes.mode}'`);
  }
  const getSourceFracIndex = createSourceFracIndexSnippet(
    attributes.coordinateTransformMode,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
  );

  const coordsDataType = getCoordsDataType(dim);
  const unpackChannel = unpackFromChannel();
  // Despite their names, `inputWH` holds (height, width) and `scaleWHWH` holds
  // (scaleH, scaleW, scaleH, scaleW): components .x/.z follow the height axis, .y/.w the width axis.
  const shaderSource = `
            const vec2 inputWH = vec2(${inputHeight}.0, ${inputWidth}.0);
            const vec4 scaleWHWH = vec4(float(${scalesHeight}), float(${scalesWidth}), float(${scalesHeight}), float(${
              scalesWidth
            }));
            ${unpackChannel}
            ${getSourceFracIndex}
            float getAValue(int x10, int r, int c, int d) {
                return getChannel(getA(x10, r, c, d), vec2(c, d));
            }
            void main() {
                ${coordsDataType} rc = getOutputCoords();

                int batch = rc[0];
                int depth = rc[1];

                // retrieve the 4 coordinates that is used in the 4 packed output values.
                ivec4 coords = ivec4(rc.wz, rc.w + 1, rc.z + 1);

                // calculate the source index in fraction
                vec4 sourceFrac = getSourceFracIndex(coords);

                // get the lower and upper bound of the 4 values that will be packed into one texel.
                ivec4 x00 = ivec4(max(sourceFrac.xy, vec2(0.0)), min(inputWH - 1.0, ceil(sourceFrac.xy)));
                ivec4 x01 = ivec4(max(sourceFrac.xw, vec2(0.0)), min(inputWH - 1.0, ceil(sourceFrac.xw)));
                ivec4 x10 = ivec4(max(sourceFrac.zy, vec2(0.0)), min(inputWH - 1.0, ceil(sourceFrac.zy)));
                ivec4 x11 = ivec4(max(sourceFrac.zw, vec2(0.0)), min(inputWH - 1.0, ceil(sourceFrac.zw)));

                bool hasNextRow = rc.w < ${outputHeight - 1};
                bool hasNextCol = rc.z < ${outputWidth - 1};

                // pack x00, x01, x10, x11's top-left corner into one vec4 structure
                vec4 topLeft = vec4(
                    getAValue(batch, depth, x00.x, x00.y),
                    hasNextCol ? getAValue(batch, depth, x01.x, x01.y) : 0.0,
                    hasNextRow ? getAValue(batch, depth, x10.x, x10.y) : 0.0,
                    (hasNextRow && hasNextCol) ? getAValue(batch, depth, x11.x, x11.y) : 0.0);

                // pack x00, x01, x10, x11's top-right corner into one vec4 structure
                vec4 topRight = vec4(
                    getAValue(batch, depth, x00.x, x00.w),
                    hasNextCol ? getAValue(batch, depth, x01.x, x01.w) : 0.0,
                    hasNextRow ? getAValue(batch, depth, x10.x, x10.w) : 0.0,
                    (hasNextRow && hasNextCol) ? getAValue(batch, depth, x11.x, x11.w) : 0.0);

                // pack x00, x01, x10, x11's bottom-left corner into one vec4 structure
                vec4 bottomLeft = vec4(
                    getAValue(batch, depth, x00.z, x00.y),
                    hasNextCol ? getAValue(batch, depth, x01.z, x01.y) : 0.0,
                    hasNextRow ? getAValue(batch, depth, x10.z, x10.y) : 0.0,
                    (hasNextRow && hasNextCol) ? getAValue(batch, depth, x11.z, x11.y) : 0.0);

                // pack x00, x01, x10, x11's bottom-right corner into one vec4 structure
                vec4 bottomRight = vec4(
                    getAValue(batch, depth, x00.z, x00.w),
                    hasNextCol ? getAValue(batch, depth, x01.z, x01.w) : 0.0,
                    hasNextRow ? getAValue(batch, depth, x10.z, x10.w) : 0.0,
                    (hasNextRow && hasNextCol) ? getAValue(batch, depth, x11.z, x11.w) : 0.0);

                // calculate the interpolation fraction on u and v direction
                vec4 frac = vec4(sourceFrac) - floor(sourceFrac);
                vec4 clampFrac = clamp(frac, vec4(0.0), vec4(1.0));

                vec4 top = mix(topLeft, topRight, clampFrac.ywyw);
                vec4 bottom = mix(bottomLeft, bottomRight, clampFrac.ywyw);
                vec4 newValue = mix(top, bottom, clampFrac.xxzz);

                ${glsl.output} = vec4(newValue);
            }
        `;
  return {
    ...resizeProgramMetadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed },
    hasMain: true,
    shaderSource,
  };
};
216

217
/**
 * Resolves the per-axis scales and the output shape for a Resize node.
 *
 * Scales are taken from, in order:
 *  1. `attributes.scales` when non-empty (no `sizes` input may then be present);
 *  2. the `scales` input tensor when present and non-empty (no `sizes` input may then be present);
 *  3. the `sizes` input tensor, from which scales are derived as outputSize / inputSize.
 *
 * @returns `[scales, outputShape]`, both with one entry per input dimension. When `sizes` is
 *   used the output shape is `sizes` itself; otherwise it is `floor(inputDim * scale)` per axis.
 * @throws Error when both or neither of scales/sizes are supplied, or when their length does
 *   not match the input rank (which would otherwise produce NaN dimensions).
 */
const prepareInputs = (inputs: Tensor[], attributes: UpsampleAttributes): [readonly number[], readonly number[]] => {
  const x = inputs[0];
  const xDims = x.dims;
  const rank = xDims.length;

  let scales = attributes.scales;
  let outputSizes: number[] | undefined;
  if (scales.length === 0) {
    const scalesTensor = inputs[attributes.scalesInputIdx];
    if (scalesTensor && scalesTensor.size !== 0) {
      if (inputs[attributes.sizesInputIdx]) {
        throw new Error('Only one of scales or sizes must be provided as input.');
      }
      scales = parseScalesData(scalesTensor, attributes.mode, attributes.isResize);
    } else {
      const sizesTensor = inputs[attributes.sizesInputIdx];
      if (!sizesTensor || sizesTensor.size === 0) {
        throw new Error('Either scales or sizes MUST be provided as input.');
      }

      outputSizes = Array.from(sizesTensor.integerData);
      // Check before deriving scales: a short `sizes` would otherwise yield NaN scales.
      if (outputSizes.length !== rank) {
        throw new Error(`sizes length (${outputSizes.length}) must equal the input rank (${rank}).`);
      }
      scales = parseScalesDataFromOutputSize(outputSizes, xDims, attributes.mode, attributes.isResize);
    }
  } else if (inputs[attributes.sizesInputIdx]) {
    throw new Error('Only one of scales or sizes must be provided as input.');
  }

  // Every later consumer indexes scales by input axis, so a mismatched length is an error.
  if (scales.length !== rank) {
    throw new Error(`scales length (${scales.length}) must equal the input rank (${rank}).`);
  }

  const yDims = outputSizes ?? xDims.map((dim, i) => Math.floor(dim * scales[i]));

  return [scales, yDims];
};
249

250
/**
 * Copies the float contents of a `scales` input tensor into a plain array and validates it
 * with `scalesValidation` for the given mode.
 *
 * @returns the validated scales, one per axis as stored in the tensor.
 */
const parseScalesData = (scale: Tensor, mode: string, isResize: boolean): number[] => {
  const result = Array.from(scale.floatData);
  scalesValidation(result, mode, isResize);
  return result;
};
255

256
/**
 * Derives per-axis scales from requested output sizes: `scale[i] = yDims[i] / xDims[i]`.
 * A zero-sized input axis gets a scale of 1, and is only allowed when the requested
 * output axis is also zero.
 *
 * @throws Error when an input axis is zero but the requested output axis is not.
 */
const parseScalesDataFromOutputSize = (
  yDims: readonly number[],
  xDims: readonly number[],
  mode: string,
  isResize: boolean,
): number[] => {
  const scales = xDims.map((inDim, axis) => {
    if (inDim !== 0) {
      return yDims[axis] / inDim;
    }
    if (yDims[axis] !== 0) {
      throw new Error('Input dim is zero but required output dim is non-zero.');
    }
    return 1;
  });
  scalesValidation(scales, mode, isResize);
  return scales;
};
278

279
// ROI data is not used yet, but this helper is kept here for future use.
280
// const getRoi = (inputs: Tensor[], attributes: UpsampleAttributes) : number[] => {
281
//     let roi: number[] = [];
282
//     if (attributes.needRoiInput) {
283
//         if (attributes.roiInputIdx <= 0) {
284
//             throw new Error('Invalid roi input index.');
285
//         }
286
//         const roiTensor = inputs[attributes.roiInputIdx];
287
//         roi = roiTensor.size > 0 ? Array.from(roiTensor.floatData) : [];
288
//     } else {
289
//         roi = new Array(inputs[0].dims.length * 2).fill(0);
290
//     }
291
//     return roi;
292
// };
293

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.