onnxruntime

Форк
0
140 строк · 3.8 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
8
import { getCoordsDataType } from '../utils';
9

10
import { getChannels } from './packing-utils';
11

12
const packProgramMetadata = {
13
  name: 'pack',
14
  inputNames: ['A'],
15
  inputTypes: [TextureType.unpackedReversed],
16
};
17

18
const createPackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => {
19
  const glsl = getGlsl(handler.session.backend.glContext.version);
20
  const inputShape = input.dims;
21

22
  const inputRank = inputShape.length;
23
  // createTextureLayoutFromShape won't change output rank. Need to verify by running tests
24
  const outputRank = input.dims.length;
25

26
  const coordsDataType = getCoordsDataType(outputRank);
27
  const channels = getChannels('rc', outputRank);
28
  const setup = getSetup(outputRank, channels, inputShape[inputShape.length - 2], inputShape[inputShape.length - 1]);
29

30
  let reversedInputWH;
31
  if (inputRank === 0) {
32
    reversedInputWH = [1, 1];
33
  } else if (inputRank === 1) {
34
    reversedInputWH = [inputShape[0], 1];
35
  } else {
36
    reversedInputWH = [inputShape[outputRank - 1], inputShape[outputRank - 2]];
37
  }
38
  const outOfBoundsCondition = getOutOfBoundsCondition(outputRank, reversedInputWH, channels);
39
  const output = getOutput(inputShape, channels);
40

41
  const shaderSource = `
42
        void main() {
43
          ${coordsDataType} rc = getOutputCoords();
44

45
          if(${outOfBoundsCondition}) {
46
            ${glsl.output} = vec4(0);
47
          } else {
48
            ${setup}
49

50
            ${glsl.output} = vec4(${output});
51
          }
52
        }
53
      `;
54
  return {
55
    ...packProgramMetadata,
56
    hasMain: true,
57
    output: { dims: input.dims, type: input.type, textureType: TextureType.packed },
58
    shaderSource,
59
  };
60
};
61

62
export const createPackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => ({
63
  ...packProgramMetadata,
64
  get: () => createPackProgramInfo(handler, input),
65
});
66

67
/**
68
 * check output coordinate location and return false if it is outside input's width/height boundary
69
 */
70
function getOutOfBoundsCondition(rank: number, shape: readonly number[], dims: string[]): string {
71
  if (rank === 0) {
72
    return 'false';
73
  }
74
  if (rank === 1) {
75
    return `rc > ${shape[0]}`;
76
  }
77

78
  let cond = '';
79
  for (let i = rank - 2; i < rank; i++) {
80
    cond += `${dims[i]} >= ${shape[i - rank + 2]}`;
81
    if (i < rank - 1) {
82
      cond += '||';
83
    }
84
  }
85

86
  return cond;
87
}
88

89
/**
90
 * code snippet to sample input texture with output coordinates
91
 */
92
function getOutput(shape: readonly number[], dims: string[]): string {
93
  const rank = shape.length;
94

95
  if (rank === 0) {
96
    return 'getA(), 0, 0, 0';
97
  }
98

99
  if (rank === 1) {
100
    return `getA(rc),
101
            rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1),
102
            0, 0`;
103
  }
104

105
  const coord00 = 'r, c';
106
  const coord01 = 'r, cp1';
107
  const coord10 = 'rp1, c';
108
  const coord11 = 'rp1, cp1';
109
  let D = '';
110
  if (rank > 2) {
111
    for (let i = 0; i < rank - 2; ++i) {
112
      D = D + `${dims[i]},`;
113
    }
114
  }
115
  return `getA(${D}${coord00}),
116
          rEdge ? 0. : getA(${D}${coord10}),
117
          cEdge ? 0. : getA(${D}${coord01}),
118
          rEdge || cEdge ? 0. : getA(${D}${coord11})`;
119
}
120

121
/**
122
 * code snippet to setup 4 coordinates and edge conditions
123
 */
124
function getSetup(rank: number, dims: string[], rows: number, cols: number): string {
125
  if (rank === 0 || rank === 1) {
126
    return '';
127
  }
128
  // rank >= 2 for width+height pack.
129
  else {
130
    const setup = `
131
    int r = ${dims[rank - 2]};
132
    int c = ${dims[rank - 1]};
133
    int rp1 = ${dims[rank - 2]} + 1;
134
    int cp1 = ${dims[rank - 1]} + 1;
135
    bool rEdge = rp1 >= ${cols};
136
    bool cEdge = cp1 >= ${rows};
137
    `;
138
    return setup;
139
  }
140
}
141

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.