onnxruntime

Форк
0
/
batch-normalization.ts 
123 строки · 4.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { getGlsl } from '../glsl-source';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, TextureType } from '../types';
11

12
/**
 * Attributes of a BatchNormalization node, as produced by `parseBatchNormalizationAttributes`.
 * The inherited `cacheKey` is what `batchNormalization` passes as the program cache hint.
 */
export interface BatchNormalizationAttributes extends AttributeWithCacheKey {
  /** Added to the variance under the square root in the normalization shader. */
  epsilon: number;
  /** Read from the node; not referenced by the WebGL kernel in this file. */
  momentum: number;
  /** Read from the node; not referenced by the WebGL kernel in this file. */
  spatial: number;
}
17

18
/**
 * Static part of the BatchNormalization program description. It is spread into both the
 * program-info loader (`batchNormalization`) and the returned ProgramInfo
 * (`createBatchNormalizationProgramInfo`). The input names match the identifiers the shader
 * reads (Scale, B, Mean, Variance, and A through `_A`). Every input is an unpacked texture.
 */
const batchNormalizationProgramMetadata = (() => {
  const inputNames = ['A', 'Scale', 'B', 'Mean', 'Variance'];
  return {
    name: 'BatchNormalization',
    inputNames,
    // One texture type per input name, so the two lists cannot drift out of sync.
    inputTypes: inputNames.map(() => TextureType.unpacked),
  };
})();
29

30
/**
 * WebGL implementation of the BatchNormalization operator.
 *
 * Validates the five inputs, then runs the program through `inferenceHandler.run`, using
 * `attributes.cacheKey` as the cache hint. The ProgramInfo is produced by the `get` callback
 * (presumably only invoked when no cached program matches; confirm in WebGLInferenceHandler.run).
 *
 * @param inferenceHandler - handler that executes the program.
 * @param inputs - `[A, Scale, B, Mean, Variance]`.
 * @param attributes - parsed node attributes.
 * @returns a single-element array holding the normalized output tensor.
 * @throws Error when `validateInputs` rejects the inputs.
 */
export const batchNormalization: OperatorImplementation<BatchNormalizationAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: BatchNormalizationAttributes,
): Tensor[] => {
  validateInputs(inputs);

  const programInfoLoader = {
    ...batchNormalizationProgramMetadata,
    cacheHint: attributes.cacheKey,
    get: () => createBatchNormalizationProgramInfo(inferenceHandler, inputs, attributes),
  };

  return [inferenceHandler.run(programInfoLoader, inputs)];
};
46

47
/**
 * Reads the BatchNormalization attributes from a graph node.
 *
 * The fallback values 1e-5 (epsilon), 0.9 (momentum) and 1 (spatial) are passed as the second
 * argument of `getFloat` / `getInt`. These are presumably the values used when the node does not
 * set the attribute.
 *
 * @param node - the BatchNormalization graph node.
 * @returns the attributes wrapped by `createAttributeWithCacheKey`.
 */
export const parseBatchNormalizationAttributes: OperatorInitialization<BatchNormalizationAttributes> = (
  node: Graph.Node,
): BatchNormalizationAttributes =>
  createAttributeWithCacheKey({
    epsilon: node.attributes.getFloat('epsilon', 1e-5),
    momentum: node.attributes.getFloat('momentum', 0.9),
    spatial: node.attributes.getInt('spatial', 1),
  });
55

56
/**
 * Builds the WebGL ProgramInfo for BatchNormalization.
 *
 * The generated `process` function computes, for each output element at `indices`,
 *   scale * ((A - mean) / sqrt(variance + epsilon)) + b
 * Scale, Mean, Variance and B are all sampled at the texel addressed by `indices[1]`
 * (the second coordinate of the output index).
 *
 * @param inferenceHandler - supplies the GL context version (for the GLSL dialect) and the
 *   texture width/height of the Scale input.
 * @param inputs - `[A, Scale, B, Mean, Variance]`; expected to have passed `validateInputs`.
 * @param attributes - only `epsilon` is used here.
 * @returns program info whose output has A's dims and type, in an unpacked texture.
 */
const createBatchNormalizationProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: BatchNormalizationAttributes,
): ProgramInfo => {
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
  const rank = inputs[0].dims.length;
  // Texture size of Scale is used to address all four per-channel inputs.
  // NOTE(review): this relies on B, Mean and Variance sharing Scale's texture layout. validateInputs
  // enforces identical 1-D shapes; confirm the texture layout follows from that.
  const [scaleWidth, scaleHeight] = inferenceHandler.calculateTextureWidthAndHeight(
    inputs[1].dims,
    TextureType.unpacked,
  );
  // The template literal below is emitted verbatim as GLSL: it must contain only shader code.
  const shaderSource = `
  float process(int[${rank}] indices) {
    vec2 position = offsetToCoords(indices[1], ${scaleWidth}, ${scaleHeight});
    float scale = getColorAsFloat(${glsl.texture2D}(Scale, position));
    float mean = getColorAsFloat(${glsl.texture2D}(Mean, position));
    float variance = getColorAsFloat(${glsl.texture2D}(Variance, position));
    float b = getColorAsFloat(${glsl.texture2D}(B, position));

    return scale * ( (_A(indices) - mean) / sqrt(variance + float(${attributes.epsilon})) ) + b;
  }`;
  return {
    ...batchNormalizationProgramMetadata,
    output: { dims: inputs[0].dims, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
83

84
/**
 * Validates the five BatchNormalization inputs `[X, scale, B, mean, var]`.
 *
 * Checks, in this order:
 *  1. exactly five inputs are given;
 *  2. X has at least three dimensions (N, C, dim1, ...) and every other input is 1-D;
 *  3. every 1-D input has length `X.dims[1]` (C);
 *  4. every input has type 'float32' or 'float64'.
 *
 * @param inputs - the operator inputs in ONNX order.
 * @throws Error when any check fails.
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 5) {
    throw new Error('BatchNormalization requires 5 inputs.');
  }

  const [x, ...perChannel] = inputs;

  // Rank first: X needs N, C and at least one more dimension; the others must be vectors.
  const hasBadRank = x.dims.length < 3 || perChannel.some((t) => t.dims.length !== 1);
  // Then length: each vector must hold exactly one entry per C of X.
  if (hasBadRank || perChannel.some((t) => t.dims[0] !== x.dims[1])) {
    throw new Error('invalid input shape.');
  }

  const isFloatTensor = (t: Tensor): boolean => t.type === 'float32' || t.type === 'float64';
  if (!inputs.every(isFloatTensor)) {
    throw new Error('invalid input tensor types.');
  }
};
124

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.