onnxruntime

Форк
0
95 строк · 3.2 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
10

11
export interface LrnAttributes extends AttributeWithCacheKey {
12
  alpha: number;
13
  beta: number;
14
  bias: number;
15
  size: number;
16
}
17

18
export const lrn: OperatorImplementation<LrnAttributes> = (
19
  inferenceHandler: WebGLInferenceHandler,
20
  inputs: Tensor[],
21
  attributes: LrnAttributes,
22
): Tensor[] => {
23
  validateInputs(inputs);
24

25
  // if (inferenceHandler.session.pack) {
26
  //   return [inferenceHandler.run(createPackedLrnProgramInfoLoader(inferenceHandler, inputs, attributes),
27
  //   inputs)];
28
  // } else {
29
  return [inferenceHandler.run(createLrnProgramInfoLoader(inputs, attributes), inputs)];
30
  //}
31
};
32

33
export const parseLrnAttributes: OperatorInitialization<LrnAttributes> = (node: Graph.Node): LrnAttributes => {
34
  const alpha = node.attributes.getFloat('alpha', 0.0001);
35
  const beta = node.attributes.getFloat('beta', 0.75);
36
  const bias = node.attributes.getFloat('bias', 1.0);
37
  const size = node.attributes.getInt('size');
38

39
  return createAttributeWithCacheKey({ alpha, beta, bias, size });
40
};
41

42
const lrnProgramMetadata = {
43
  name: 'LRN',
44
  inputNames: ['X'],
45
  inputTypes: [TextureType.unpacked],
46
};
47

48
function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): ProgramInfo {
49
  const C = inputs[0].dims[1];
50
  const rank = inputs[0].dims.length;
51
  const from = -Math.floor((attributes.size - 1) / 2);
52
  const to = Math.ceil((attributes.size - 1) / 2);
53
  const alpha = `float(${attributes.alpha}) / float(${attributes.size})`;
54
  const bias = `float(${attributes.bias})`;
55
  const beta = `float(${attributes.beta})`;
56

57
  const shaderSource = `
58
    float process(int indices[${rank}]) {
59
        int c = indices[1];
60
        float x = _X(indices);
61
        float square_sum = 0.0;
62

63
        for (int i = ${from}; i <= ${to}; i++) {
64
          int idx = c + i;
65
          if (c >= 0 && c < ${C}) {
66
            indices[1] = idx;
67
            float j = _X(indices);
68
            square_sum += j * j;
69
          }
70
        }
71
        return x / pow(${bias} + ${alpha} * square_sum, ${beta});
72
    }`;
73
  return {
74
    ...lrnProgramMetadata,
75
    cacheHint: attributes.cacheKey,
76
    output: { dims: inputs[0].dims, type: inputs[0].type, textureType: TextureType.unpacked },
77
    shaderSource,
78
  };
79
}
80

81
export function createLrnProgramInfoLoader(inputs: Tensor[], attributes: LrnAttributes): ProgramInfoLoader {
82
  return { ...lrnProgramMetadata, cacheHint: attributes.cacheKey, get: () => createLrnProgramInfo(inputs, attributes) };
83
}
84

85
const validateInputs = (inputs: Tensor[]): void => {
86
  if (!inputs || inputs.length !== 1) {
87
    throw new Error('LRN requires 1 input.');
88
  }
89
  if (inputs[0].dims.length !== 4) {
90
    throw new Error('currently only support LRN for input with "NCHW" format');
91
  }
92
  if (inputs[0].type !== 'float32') {
93
    throw new Error('input should be float type');
94
  }
95
};
96

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.