onnxruntime

Форк
0
172 строки · 5.5 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { ShapeUtil } from '../../../util';
6
import { getGlsl } from '../glsl-source';
7
import { WebGLInferenceHandler } from '../inference-handler';
8
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
9

10
import { unpackFromChannel } from './packing-utils';
11

12
/**
 * Builds the program metadata for the packed 3-D reshape.
 *
 * The program takes a single packed input named `A`. The output shape is the only
 * value folded into the cache hint.
 *
 * @param outputShape3D - the 3-D output shape of the reshape.
 */
const createPackedReshape3DProgramMetadata = (outputShape3D: readonly number[]) => ({
  name: 'Reshape (packed)',
  inputTypes: [TextureType.packed],
  inputNames: ['A'],
  // Comma-separated dims, e.g. "2,3,4" (the same text as the array's default string form).
  cacheHint: outputShape3D.join(','),
});

/**
 * Builds the WebGL program that re-arranges the texels of a packed 3-D input into the
 * packed layout of `outputShape3D`.
 *
 * Each invocation fills one output texel with four elements: the element at `rc` and the
 * elements offset from it by +1 in `y`, by +1 in `z`, and by +1 in both. For each element
 * the shader flattens the output coordinate, maps that flat index back to input
 * coordinates, and reads the matching channel of the packed input texel. The three offset
 * elements are only written when they fall inside the output bounds.
 *
 * @param handler - inference handler; used here only to select the GLSL version.
 * @param input3D - input tensor; its `dims` are cast to a 3-tuple without a runtime check.
 * @param metadata - program metadata, spread into the returned ProgramInfo.
 * @param outputShape3D - output shape; cast to a 3-tuple without a runtime check.
 * @returns program info producing a packed output of shape `outputShape3D` and the input's type.
 */
const createPackedReshape3DProgramInfo = (
  handler: WebGLInferenceHandler,
  input3D: Tensor,
  metadata: ProgramMetadata,
  outputShape3D: readonly number[],
): ProgramInfo => {
  const inputShape3D = input3D.dims as [number, number, number];
  const squeezedOutputShape = outputShape3D as [number, number, number];

  // GLSL expressions for the output coordinates of the four elements held by one packed
  // texel, in the order they are stored into result[0..3].
  const texelCoords: readonly string[] = [
    'rc',
    'ivec3(rc.x, rc.y+1, rc.z)',
    'ivec3(rc.x, rc.y, rc.z+1)',
    'ivec3(rc.x, rc.y+1, rc.z+1)',
  ];

  // result[0] is always written; the other three are bounds-checked against rows/cols.
  const mainLoop = texelCoords
    .map(
      (coords, i) => `
        outputCoords = ${coords};
        ${i > 0 ? 'if(outputCoords.y < rows && outputCoords.z < cols){' : ''}
          int flattenedIndex = getFlattenedIndex(outputCoords);

          ivec3 inputRC = inputCoordsFromReshapedOutCoords(flattenedIndex);
          vec2 innerDims = vec2(float(inputRC.y),float(inputRC.z));

          result[${i}] = getChannel(getA(inputRC.x, inputRC.y, inputRC.z), innerDims);

        ${i > 0 ? '}' : ''}
      `,
    )
    .join('');

  const glsl = getGlsl(handler.session.backend.glContext.version);

  // NOTE(review): `rows` is bound to dim 2 and `cols` to dim 1, and they are compared against
  // outputCoords.y and outputCoords.z respectively. This matches the y/z swap in
  // getFlattenedIndex ("reverse y, z order"). Confirm against getOutputCoords for packed
  // outputs before changing either side.
  const shaderSource = `
      ${getReshapedInputCoords(inputShape3D)}
      ${getFlattenedIndexFrom3D(squeezedOutputShape)}
      ${unpackFromChannel()}

      void main() {
        ivec3 rc = getOutputCoords();

        vec4 result = vec4(0.0);

        ivec3 outputCoords;
        int rows = ${squeezedOutputShape[2]};
        int cols = ${squeezedOutputShape[1]};

        ${mainLoop}
        ${glsl.output} = result;
      }
    `;

  return {
    ...metadata,
    output: { dims: squeezedOutputShape, type: input3D.type, textureType: TextureType.packed },
    shaderSource,
    hasMain: true,
  };
};

/**
 * Creates a program-info loader for a packed 3-D reshape.
 *
 * The metadata (whose cache hint is derived from `outputShape3D`) is computed immediately.
 * The shader program itself is only built when `get()` is called.
 *
 * @param handler - inference handler passed through to the program builder.
 * @param input3D - input tensor with 3-D dims.
 * @param outputShape3D - the 3-D output shape.
 */
export const createPackedReshape3DProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  input3D: Tensor,
  outputShape3D: readonly number[],
): ProgramInfoLoader => {
  const metadata = createPackedReshape3DProgramMetadata(outputShape3D);
  return {
    ...metadata,
    get: () => createPackedReshape3DProgramInfo(handler, input3D, metadata, outputShape3D),
  };
};

/**
 * Collapses a shape of any rank into a 3-D `[batch, rows, cols]` view.
 *
 * Every dimension before the last two is multiplied into `batch`. An empty (scalar) shape
 * maps to `[1, 1, 1]`, and a 1-D shape `[n]` maps to `[1, 1, n]`.
 *
 * @param shape - tensor dimensions.
 * @returns the 3-D view of `shape`.
 */
export function processDims3D(shape: ArrayLike<number>): [number, number, number] {
  const rank = shape.length;
  if (rank === 0) {
    return [1, 1, 1];
  }
  // TODO: squeeze other shapes to 2D case
  let batch = 1;
  for (let i = 0; i < rank - 2; ++i) {
    batch *= shape[i];
  }
  const rows = rank > 1 ? shape[rank - 2] : 1;
  return [batch, rows, shape[rank - 1]];
}

// For packed reshape, we need to re-arrange texel data for the output shape.
// Our packing stores a 2x2 tile of the last two (h and w) dimensions, so for the
// reshaped tensor we only need to re-arrange the last h and w dimensions. Any shape
// that is not 3-D, i.e. [batch, H, W], is first converted to 3-D by collapsing the
// other dimensions into the batch dimension, and then processed using the last two
// dimensions.
// Note: we only need the shape tensor to calculate the output shape, so the content
// of the shape tensor is never uploaded to the GPU. It is always kept on the CPU.
// TODO: optimize the algorithm -- in some cases, if the last two dims are the same
// between the input shape and the output shape, the packed reshape can be treated
// as a no-op.

/**
 * Checks whether the innermost (up to two) dimensions of `dims` and `reshapedDims` match.
 * Per the note above, that is the case in which a packed reshape does not need to
 * re-arrange the last h and w dimensions.
 *
 * - If either shape is a scalar (rank 0), the reshape is considered cheap.
 * - If either shape is 1-D, only the last dimensions are compared.
 * - Otherwise both of the last two dimensions must match.
 *
 * @param dims - the original shape.
 * @param reshapedDims - the target shape.
 * @returns `true` when the reshape is cheap.
 */
export function isReshapeCheap(dims: readonly number[], reshapedDims: readonly number[]): boolean {
  const rank = dims.length;
  const reshapedRank = reshapedDims.length;
  if (rank === 0 || reshapedRank === 0) {
    // scalar
    return true;
  }

  const lastDimMatches = dims[rank - 1] === reshapedDims[reshapedRank - 1];
  if (rank < 2 || reshapedRank < 2) {
    // 1D: at least one side has no second-to-last dimension to compare
    return lastDimMatches;
  }

  // 2D+: both of the last two dimensions must match
  return lastDimMatches && dims[rank - 2] === reshapedDims[reshapedRank - 2];
}

/**
 * Emits the GLSL helper `ivec3 inputCoordsFromReshapedOutCoords(int index)`. The helper
 * decodes a flat index into `(b, r, c)` coordinates of the 3-D `shape`, using the strides
 * returned by `ShapeUtil.computeStrides`.
 *
 * @param shape - 3-D input shape `[batch, rows, cols]`.
 * @returns GLSL source for the helper function.
 */
function getReshapedInputCoords(shape: [number, number, number]): string {
  // Only the batch and row strides are needed: once b and r are removed, the remainder is c.
  // Decoding is spelled out for exactly b, r and c so that the generated code does not
  // depend on how many strides computeStrides returns. (Mapping over every stride and
  // reading coords[i + 1] on the last one would run past ['b', 'r', 'c'] if a stride is
  // returned for all three dimensions.)
  const [batchStride, rowStride] = ShapeUtil.computeStrides(shape);
  const coordsFromIndexSnippet =
    `int b = index / ${batchStride}; index -= b * ${batchStride}; ` +
    `int r = index / ${rowStride}; int c = index - r * ${rowStride};`;

  return `
    ivec3 inputCoordsFromReshapedOutCoords(int index) {
      ${coordsFromIndexSnippet}
      return ivec3(b, r, c);
    }
  `;
}

/**
 * Emits the GLSL helper `int getFlattenedIndex(ivec3 coords)` for a 3-D output shape.
 *
 * The helper multiplies `coords.x` by the batch stride and `coords.z` by the row stride,
 * and adds `coords.y` unscaled. In other words, it treats y and z as swapped relative to
 * `shape` order ("reverse y, z order").
 * NOTE(review): the swap presumably mirrors how packed output coordinates are laid out;
 * confirm against getOutputCoords before changing it.
 *
 * @param shape - 3-D output shape `[batch, rows, cols]`.
 * @returns GLSL source for the helper function.
 */
function getFlattenedIndexFrom3D(shape: [number, number, number]): string {
  // Only the first two strides are used by the helper.
  const [batchStride, rowStride] = ShapeUtil.computeStrides(shape);

  return `
  int getFlattenedIndex(ivec3 coords) {
    // reverse y, z order
    return coords.x * ${batchStride} + coords.z * ${rowStride} + coords.y;
  }
`;
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.