onnxruntime

Форк
0
67 строк · 2.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
8
import { getCoordsDataType } from '../utils';
9

10
import { getChannels, unpackFromChannel } from './packing-utils';
11

12
const unpackProgramMetadata = {
13
  name: 'unpack',
14
  inputNames: ['A'],
15
  inputTypes: [TextureType.packed],
16
};
17

18
export const createUnpackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => {
19
  const rank = input.dims.length;
20

21
  const channels = getChannels('rc', rank);
22
  const innerDims = channels.slice(-2);
23
  const coordsDataType = getCoordsDataType(rank);
24
  const unpackChannel = unpackFromChannel();
25
  const isScalar = input.dims.length === 0;
26
  const sourceCoords = isScalar ? '' : getSourceCoords(rank, channels);
27
  const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
28
  const glsl = getGlsl(handler.session.backend.glContext.version);
29
  const shaderSource = `
30
    ${unpackChannel}
31
    void main() {
32
      ${coordsDataType} rc = getOutputCoords();
33

34
       // Sample the texture with the coords to get the rgba channel value.
35
       vec4 packedInput = getA(${sourceCoords});
36

37
       ${glsl.output} = vec4(getChannel(packedInput, ${coords}), 0, 0, 0);
38
     }
39
   `;
40

41
  return {
42
    ...unpackProgramMetadata,
43
    hasMain: true,
44
    output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked },
45
    shaderSource,
46
  };
47
};
48

49
export const createUnpackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => ({
50
  ...unpackProgramMetadata,
51
  get: () => createUnpackProgramInfo(handler, input),
52
});
53

54
function getSourceCoords(rank: number, dims: string[]): string {
55
  if (rank === 1) {
56
    return 'rc';
57
  }
58

59
  let coords = '';
60
  for (let i = 0; i < rank; i++) {
61
    coords += dims[i];
62
    if (i < rank - 1) {
63
      coords += ',';
64
    }
65
  }
66
  return coords;
67
}
68

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.