onnxruntime

Форк
0
94 строки · 3.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { MAX_CLIP, MIN_CLIP } from '../../util';
6
import { ProgramUniform } from '../types';
7

8
import { UniformsArrayType } from './common';
9

10
export interface InternalActivationAttributes {
11
  readonly activation: string;
12
  readonly clipMin?: number;
13
  readonly clipMax?: number;
14
  readonly alpha?: number;
15
  readonly beta?: number;
16
}
17

18
export const getActivationSnippet = (
19
  attributes: InternalActivationAttributes,
20
  valueType: string,
21
  baseType = 'f32',
22
): string => {
23
  switch (attributes.activation) {
24
    case 'Relu':
25
      return `value = max(value, ${valueType}(0.0));`;
26
    case 'Sigmoid':
27
      return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
28
    case 'Clip':
29
      return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${
30
        baseType
31
      }(uniforms.clip_max)));`;
32
    case 'HardSigmoid':
33
      return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${
34
        baseType
35
      }(uniforms.beta)));`;
36
    case 'LeakyRelu':
37
      return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`;
38
    case 'Tanh':
39
      return `let e2x = exp(-2.0 * abs(value));
40
              value = sign(value) * (1.0 - e2x) / (1.0 + e2x);
41
        `;
42
    case '':
43
      return '';
44
    // TODO: adding other activations that can be fused.
45
    default:
46
      throw new Error(`Unsupported activation ${attributes.activation}`);
47
  }
48
};
49

50
export const appendActivationUniformsData = (
51
  attributes: InternalActivationAttributes,
52
  programUniform: ProgramUniform[],
53
) => {
54
  if (attributes.activation === 'Clip') {
55
    programUniform.push(
56
      { type: DataType.float, data: attributes.clipMax! },
57
      { type: DataType.float, data: attributes.clipMin! },
58
    );
59
  } else if (attributes.activation === 'HardSigmoid') {
60
    programUniform.push(
61
      { type: DataType.float, data: attributes.alpha! },
62
      { type: DataType.float, data: attributes.beta! },
63
    );
64
  } else if (attributes.activation === 'LeakyRelu') {
65
    programUniform.push({ type: DataType.float, data: attributes.alpha! });
66
  }
67
};
68

69
export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => {
70
  if (attributes.activation === 'Clip') {
71
    uniforms.push({ name: 'clip_max', type: 'f32' }, { name: 'clip_min', type: 'f32' });
72
  } else if (attributes.activation === 'HardSigmoid') {
73
    uniforms.push({ name: 'alpha', type: 'f32' }, { name: 'beta', type: 'f32' });
74
  } else if (attributes.activation === 'LeakyRelu') {
75
    uniforms.push({ name: 'alpha', type: 'f32' });
76
  }
77
};
78

79
export const parseInternalActivationAttributes = (
80
  attributes: Record<string, unknown> | undefined,
81
): InternalActivationAttributes => {
82
  const activation = (attributes?.activation as string) || '';
83
  if (activation === 'HardSigmoid') {
84
    const [alpha, beta] = (attributes?.activation_params as [number, number]) || [0.2, 0.5];
85
    return { activation, alpha, beta };
86
  } else if (activation === 'Clip') {
87
    const [clipMin, clipMax] = (attributes?.activation_params as [number, number]) || [MIN_CLIP, MAX_CLIP];
88
    return { activation, clipMax, clipMin };
89
  } else if (activation === 'LeakyRelu') {
90
    const [alpha] = (attributes?.activation_params as [number]) || [0.01];
91
    return { activation, alpha };
92
  }
93
  return { activation };
94
};
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.