onnxruntime

Форк
0
/
instance-normalization.ts 
172 строки · 5.7 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Graph } from '../../../graph';
5
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6
import { Tensor } from '../../../tensor';
7
import { getGlsl } from '../glsl-source';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
10

11
/**
 * WebGL implementation of the InstanceNormalization operator.
 *
 * Runs two programs through the inference handler:
 *  1. a mean/variance pass over the inputs, and
 *  2. an output pass that combines X, the mean/variance result, scale (`inputs[1]`) and B (`inputs[2]`).
 *
 * @param inferenceHandler handler whose `run` executes each program
 * @param inputs `[X, scale, B]`; checked by `validateInputs` before anything is run
 * @param epsilon forwarded to the output program, where it is added to the variance under the square root
 * @returns a one-element array containing the result of the output pass
 * @throws Error whatever `validateInputs` throws for unsupported inputs
 */
export const instanceNormalization: OperatorImplementation<number> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  epsilon: number,
): Tensor[] => {
  validateInputs(inputs);

  // The mean/variance program declares a single input name ('X'); the whole input list is still passed along.
  const meanAndVariance = inferenceHandler.run(createMeanAndVarianceProgramInfoLoader(inputs[0]), inputs);

  // Input order must line up with the output program's inputNames: X, MeanAndVariance, Scale, B.
  const output = inferenceHandler.run(
    createComputeOutputProgramInfoLoader(inferenceHandler, inputs[0], epsilon, meanAndVariance.dims),
    [inputs[0], meanAndVariance, inputs[1], inputs[2]],
  );
  return [output];
};
25

26
/**
 * Extracts the operator's single attribute, `epsilon`, from a graph node.
 *
 * @param node graph node whose `attributes.getFloat` is queried
 * @returns the result of `getFloat('epsilon', 1e-5)`; the second argument is presumably the value used when
 *   the attribute is absent (confirm against the attribute container's implementation)
 */
export const parseInstanceNormalizationAttributes: OperatorInitialization<number> = (node: Graph.Node): number =>
  node.attributes.getFloat('epsilon', 1e-5);
28

29
/**
 * Static metadata for the first pass: the program's name and its single input 'X',
 * which is declared as an unpacked texture.
 */
const meanAndVarianceProgramMetadata = {
  name: 'InstanceNormalization_MeanAndVariance',
  inputNames: ['X'],
  inputTypes: [TextureType.unpacked],
};
34

35
/**
 * Builds the program that computes mean and variance for every (dims[0], dims[1]) pair of a 4-D input.
 *
 * The generated shader walks dims[2] x dims[3] twice: once to accumulate the sum (giving the mean) and once to
 * accumulate squared deviations. Both are divided by dims[2] * dims[3], so the variance is the population
 * variance. The mean is written to the `r` channel and the variance to the `g` channel of the result.
 *
 * The loop bounds and the element count are baked into the shader text as literals, so the source depends on
 * the input's last two dimensions.
 *
 * @param metadata program metadata spread into the returned object
 * @param input the X tensor; only `dims` and `type` are read. The code indexes dims[0..3], i.e. it assumes a
 *   4-D shape
 * @returns program info with output dims `[dims[0], dims[1]]`, the input's type, and a
 *   `packedLastDimension` texture type
 */
const createMeanAndVarianceProgramInfo = (metadata: ProgramMetadata, input: Tensor): ProgramInfo => {
  const [batch, channel, height, width] = input.dims;
  const channelSize = height * width;
  const outputShape = [batch, channel];

  // `_X(a)` reads one element of X at the 4-element index `a` (accessor assumed to be supplied by the
  // surrounding shader-generation machinery — not visible in this file).
  const shaderSource = `
      vec4 process(int[2] indices) {
        vec4 v = vec4(0.0);
        int a[4];
        a[0] = indices[0];
        a[1] = indices[1];
        float temp = 0.0;
        for(int a2=0; a2<${height}; a2++) {
          a[2] = a2;
          for(int a3=0; a3<${width}; a3++) {
            a[3] = a3;
            float x = _X(a);
            temp += x;
          }
        }
        float mean = temp / float(${channelSize});
        temp = 0.0;
        for(int a2=0; a2<${height}; a2++) {
          a[2] = a2;
          for(int a3=0; a3<${width}; a3++) {
            a[3] = a3;
            float x = _X(a);
            temp += (x - mean) * (x - mean);
          }
        }
        v.r = mean;
        v.g = temp / float(${channelSize});

        return v;
      }`;
  return {
    ...metadata,
    output: { dims: outputShape, type: input.type, textureType: TextureType.packedLastDimension },
    shaderSource,
  };
};
77

78
/**
 * Wraps the mean/variance program in a loader: the static metadata plus a `get` callback that builds the
 * full program info only when invoked.
 *
 * @param input the X tensor captured by the `get` callback
 * @returns loader carrying `meanAndVarianceProgramMetadata`'s fields and a lazy `get`
 */
const createMeanAndVarianceProgramInfoLoader = (input: Tensor): ProgramInfoLoader => ({
  ...meanAndVarianceProgramMetadata,
  get: () => createMeanAndVarianceProgramInfo(meanAndVarianceProgramMetadata, input),
});
82

83
/**
 * Static metadata for the second pass. Input order is X, MeanAndVariance, Scale, B; only MeanAndVariance
 * is declared as a `packedLastDimension` texture, the other three are unpacked.
 */
const computeOutputProgramMetadata = {
  name: 'InstanceNormalization_ComputeOutput',
  inputNames: ['X', 'MeanAndVariance', 'Scale', 'B'],
  inputTypes: [TextureType.unpacked, TextureType.packedLastDimension, TextureType.unpacked, TextureType.unpacked],
};
88

89
/**
 * Builds the program that produces the normalized output:
 * `scale * (X - mean) / sqrt(variance + epsilon) + b`, evaluated per element.
 *
 * Mean and variance are read from the `r` and `g` channels of the MeanAndVariance texture at index
 * (indices[0], indices[1]); scale and b are read at index indices[1].
 *
 * @param inferenceHandler used for the GL context version (to pick the GLSL dialect) and to compute the
 *   MeanAndVariance texture size
 * @param metadata program metadata spread into the returned object
 * @param input the X tensor; its `dims` and `type` become the output's dims and type
 * @param epsilon passed to the shader as a float variable named `epsilon`
 * @param meanAndVarianceShape dims of the tensor produced by the mean/variance pass
 * @returns program info with an unpacked output shaped like `input`
 */
const createComputeOutputProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  input: Tensor,
  epsilon: number,
  meanAndVarianceShape: readonly number[],
): ProgramInfo => {
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
  const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight(
    meanAndVarianceShape,
    TextureType.packedLastDimension,
  );
  // NOTE(review): the width is divided by 4 presumably because a packedLastDimension texel holds four
  // values — confirm against calculateTextureWidthAndHeight.
  const [meanAndVarianceWidth, meanAndVarianceHeight] = [textureWidth / 4, textureHeight];
  const shaderSource = `
      vec4 get_MeanAndVariance(int[2] mv) {
        int offset = indicesToOffset_MeanAndVariance(mv);
        vec2 coords = offsetToCoords(offset, ${meanAndVarianceWidth}, ${meanAndVarianceHeight});
        return ${glsl.texture2D}(MeanAndVariance, coords);
      }

      float process(int[4] indices) {
        int mv[2];
        mv[0] = indices[0];
        mv[1] = indices[1];
        vec4 mean_and_variance = get_MeanAndVariance(mv);
        float mean = mean_and_variance.r;
        float variance = mean_and_variance.g;

        int sb[1];
        sb[0] = indices[1];
        float scale = _Scale(sb);
        float b = _B(sb);

        return scale * (_X(indices) - mean) / sqrt(variance + epsilon) + b;
      }`;
  return {
    ...metadata,
    output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked },
    variables: [{ name: 'epsilon', type: 'float', data: epsilon }],
    shaderSource,
  };
};
131

132
/**
 * Wraps the output program in a loader: the static metadata, a `cacheHint` derived from epsilon, and a
 * `get` callback that builds the full program info only when invoked.
 *
 * @param inferenceHandler forwarded to `createComputeOutputProgramInfo`
 * @param input the X tensor
 * @param epsilon stringified into `cacheHint` and forwarded to the program builder
 * @param meanAndVarianceShape dims of the mean/variance tensor, forwarded to the program builder
 * @returns loader carrying the metadata (including `cacheHint`) and a lazy `get`
 */
const createComputeOutputProgramInfoLoader = (
  inferenceHandler: WebGLInferenceHandler,
  input: Tensor,
  epsilon: number,
  meanAndVarianceShape: readonly number[],
): ProgramInfoLoader => {
  // The same metadata object (with cacheHint) is used for both the loader and the built program info.
  const metadata = { ...computeOutputProgramMetadata, cacheHint: `${epsilon}` };
  return {
    ...metadata,
    get: () => createComputeOutputProgramInfo(inferenceHandler, metadata, input, epsilon, meanAndVarianceShape),
  };
};
144

145
/**
 * Rejects inputs this implementation cannot handle. Checks run in this order, each with its own message:
 *  1. exactly three inputs (X, scale, B);
 *  2. X has at least 3 dimensions, scale and B have exactly 1;
 *  3. scale and B both have length X.dims[1];
 *  4. all three tensors are 'float32' or 'float64';
 *  5. X is exactly 4-D.
 *
 * @param inputs candidate `[X, scale, B]` list; only `dims` and `type` of each tensor are read
 * @throws Error with the message of the first failed check
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 3) {
    throw new Error('InstanceNormalization requires 3 inputs.');
  }

  const X = inputs[0];
  const scale = inputs[1];
  const B = inputs[2];

  // input should at least have three dimensions - N,C,dim1,...,dimn
  // other inputs can have only one dimensions
  if (X.dims.length < 3 || scale.dims.length !== 1 || B.dims.length !== 1) {
    throw new Error('Invalid input shape.');
  }
  if (scale.dims[0] !== X.dims[1] || B.dims[0] !== X.dims[1]) {
    throw new Error('Input shapes are mismatched.');
  }
  if (
    (X.type !== 'float32' && X.type !== 'float64') ||
    (scale.type !== 'float32' && scale.type !== 'float64') ||
    (B.type !== 'float32' && B.type !== 'float64')
  ) {
    throw new Error('Invalid input type.');
  }
  // The shaders index exactly four dimensions, so 3-D and 5-D+ inputs that passed the checks above
  // are still refused here.
  if (X.dims.length !== 4) {
    throw new Error('Only support 4-D input shape.');
  }
};
173

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.