1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { getGlsl } from '../glsl-source';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, TextureType } from '../types';
12
export interface BatchNormalizationAttributes extends AttributeWithCacheKey {
18
const batchNormalizationProgramMetadata = {
19
name: 'BatchNormalization',
20
inputNames: ['A', 'Scale', 'B', 'Mean', 'Variance'],
30
export const batchNormalization: OperatorImplementation<BatchNormalizationAttributes> = (
31
inferenceHandler: WebGLInferenceHandler,
33
attributes: BatchNormalizationAttributes,
35
validateInputs(inputs);
36
const output = inferenceHandler.run(
38
...batchNormalizationProgramMetadata,
39
cacheHint: attributes.cacheKey,
40
get: () => createBatchNormalizationProgramInfo(inferenceHandler, inputs, attributes),
47
export const parseBatchNormalizationAttributes: OperatorInitialization<BatchNormalizationAttributes> = (
49
): BatchNormalizationAttributes => {
50
const epsilon = node.attributes.getFloat('epsilon', 1e-5);
51
const momentum = node.attributes.getFloat('momentum', 0.9);
52
const spatial = node.attributes.getInt('spatial', 1);
53
return createAttributeWithCacheKey({ epsilon, momentum, spatial });
56
const createBatchNormalizationProgramInfo = (
57
inferenceHandler: WebGLInferenceHandler,
59
attributes: BatchNormalizationAttributes,
61
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
62
const rank = inputs[0].dims.length;
63
const [scaleWidth, scaleHeight] = inferenceHandler.calculateTextureWidthAndHeight(
67
const shaderSource = `
68
float process(int[${rank}] indices) {
69
vec2 position = offsetToCoords(indices[1], ${scaleWidth}, ${scaleHeight});
70
float scale = getColorAsFloat(${glsl.texture2D}(Scale, position));
71
float mean = getColorAsFloat(${glsl.texture2D}(Mean, position));
72
float variance = getColorAsFloat(${glsl.texture2D}(Variance, position));
73
float b = getColorAsFloat(${glsl.texture2D}(B, position));
75
return scale * ( (_A(indices) - mean) / sqrt(variance + float(${attributes.epsilon})) ) + b;
78
...batchNormalizationProgramMetadata,
79
output: { dims: inputs[0].dims, type: inputs[0].type, textureType: TextureType.unpacked },
84
const validateInputs = (inputs: Tensor[]): void => {
85
if (!inputs || inputs.length !== 5) {
86
throw new Error('BatchNormalization requires 5 inputs.');
90
const scale = inputs[1];
92
const mean = inputs[3];
93
const var_ = inputs[4];
95
// input should have at least three dimensions - N,C,dim1,...,dimn
96
// other inputs can have only one dimension
99
scale.dims.length !== 1 ||
100
B.dims.length !== 1 ||
101
mean.dims.length !== 1 ||
102
var_.dims.length !== 1
104
throw new Error('invalid input shape.');
107
scale.dims[0] !== X.dims[1] ||
108
B.dims[0] !== X.dims[1] ||
109
mean.dims[0] !== X.dims[1] ||
110
var_.dims[0] !== X.dims[1]
112
throw new Error('invalid input shape.');
115
(X.type !== 'float32' && X.type !== 'float64') ||
116
(scale.type !== 'float32' && scale.type !== 'float64') ||
117
(B.type !== 'float32' && B.type !== 'float64') ||
118
(mean.type !== 'float32' && mean.type !== 'float64') ||
119
(var_.type !== 'float32' && var_.type !== 'float64')
121
throw new Error('invalid input tensor types.');