onnxruntime

Форк
0
71 строка · 2.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { getGlsl } from '../glsl-source';
5
import { WebGLInferenceHandler } from '../inference-handler';
6
import { TextureData, TextureType } from '../types';
7

8
export const encodeAsUint8 = (inferenceHandler: WebGLInferenceHandler, input: TextureData): TextureData => {
9
  const outputShape = input.shape;
10
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
11
  /**
12
   * https://github.com/tensorflow/tfjs-core/blob/master/src/kernels/webgl/encode_float_gpu.ts
13
   */
14
  const shaderSource = `
15
    const float FLOAT_MAX = 1.70141184e38;
16
    const float FLOAT_MIN = 1.17549435e-38;
17

18
    bool isNaN(float val) {
19
      return (val < 1.0 || 0.0 < val || val == 0.0) ? false : true;
20
    }
21

22
    highp vec4 encodeAsUint8(highp float v) {
23
      if (isNaN(v)) {
24
        return vec4(255, 255, 255, 255);
25
      }
26

27
      highp float av = abs(v);
28

29
      if(av < FLOAT_MIN) {
30
        return vec4(0.0, 0.0, 0.0, 0.0);
31
      } else if(v > FLOAT_MAX) {
32
        return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
33
      } else if(v < -FLOAT_MAX) {
34
        return vec4(0.0, 0.0,  128.0, 255.0) / 255.0;
35
      }
36

37
      highp vec4 c = vec4(0,0,0,0);
38

39
      highp float e = floor(log2(av));
40
      highp float m = exp2(fract(log2(av))) - 1.0;
41

42
      c[2] = floor(128.0 * m);
43
      m -= c[2] / 128.0;
44
      c[1] = floor(32768.0 * m);
45
      m -= c[1] / 32768.0;
46
      c[0] = floor(8388608.0 * m);
47

48
      highp float ebias = e + 127.0;
49
      c[3] = floor(ebias / 2.0);
50
      ebias -= c[3] * 2.0;
51
      c[2] += floor(ebias) * 128.0;
52

53
      c[3] += 128.0 * step(0.0, -v);
54

55
      return c / 255.0;
56
    }
57

58
    void main() {
59
      float value = ${glsl.texture2D}(X,TexCoords).r;
60
      ${glsl.output} = encodeAsUint8(value);
61
    }`;
62
  const programInfo = {
63
    name: 'Uint8Encode',
64
    inputTypes: [TextureType.unpacked],
65
    inputNames: ['X'],
66
    output: { dims: outputShape, type: input.tensor.type, textureType: TextureType.downloadUint8AsFloat },
67
    shaderSource,
68
    hasMain: true,
69
  };
70
  return inferenceHandler.executeProgram(programInfo, [input.tensor]);
71
};
72

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.