4
import { Graph } from '../../../graph';
5
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6
import { Tensor } from '../../../tensor';
7
import { getGlsl } from '../glsl-source';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
11
export const instanceNormalization: OperatorImplementation<number> = (
12
inferenceHandler: WebGLInferenceHandler,
16
validateInputs(inputs);
18
const meanAndVariance = inferenceHandler.run(createMeanAndVarianceProgramInfoLoader(inputs[0]), inputs);
19
const output = inferenceHandler.run(
20
createComputeOutputProgramInfoLoader(inferenceHandler, inputs[0], epsilon, meanAndVariance.dims),
21
[inputs[0], meanAndVariance, inputs[1], inputs[2]],
26
/**
 * Extracts the InstanceNormalization `epsilon` attribute from a graph node.
 *
 * The result is whatever `node.attributes.getFloat('epsilon', 1e-5)` yields;
 * `1e-5` is passed as the second argument (presumably the fallback used when
 * the attribute is absent — confirm against the attribute container's API).
 *
 * @param node - Graph node whose attributes are queried.
 * @returns The epsilon value later supplied to the normalization shader.
 */
export const parseInstanceNormalizationAttributes: OperatorInitialization<number> = (node: Graph.Node): number =>
  node.attributes.getFloat('epsilon', 1e-5);
29
const meanAndVarianceProgramMetadata = {
30
name: 'InstanceNormalization_MeanAndVariance',
32
inputTypes: [TextureType.unpacked],
35
const createMeanAndVarianceProgramInfo = (metadata: ProgramMetadata, input: Tensor): ProgramInfo => {
36
const xDims = input.dims.slice();
37
const channel = xDims[1];
38
const channelSize = xDims[2] * xDims[3];
39
const outputShape = [xDims[0], channel];
41
const shaderSource = `
42
vec4 process(int[2] indices) {
48
for(int a2=0; a2<${xDims[2]}; a2++) {
50
for(int a3=0; a3<${xDims[3]}; a3++) {
56
float mean = temp / float(${channelSize});
58
for(int a2=0; a2<${xDims[2]}; a2++) {
60
for(int a3=0; a3<${xDims[3]}; a3++) {
63
temp += (x - mean) * (x - mean);
67
v.g = temp / float(${channelSize});
73
output: { dims: outputShape, type: input.type, textureType: TextureType.packedLastDimension },
78
const createMeanAndVarianceProgramInfoLoader = (input: Tensor): ProgramInfoLoader => ({
79
...meanAndVarianceProgramMetadata,
80
get: () => createMeanAndVarianceProgramInfo(meanAndVarianceProgramMetadata, input),
83
const computeOutputProgramMetadata = {
84
name: 'InstanceNormalization_ComputeOutput',
85
inputNames: ['X', 'MeanAndVariance', 'Scale', 'B'],
86
inputTypes: [TextureType.unpacked, TextureType.packedLastDimension, TextureType.unpacked, TextureType.unpacked],
89
const createComputeOutputProgramInfo = (
90
inferenceHandler: WebGLInferenceHandler,
91
metadata: ProgramMetadata,
94
meanAndVarianceShape: readonly number[],
96
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
97
const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight(
99
TextureType.packedLastDimension,
101
const [meanAndVarianceWidth, meanAndVarianceHeight] = [textureWidth / 4, textureHeight];
102
const shaderSource = `
103
vec4 get_MeanAndVariance(int[2] mv) {
104
int offset = indicesToOffset_MeanAndVariance(mv);
105
vec2 coords = offsetToCoords(offset, ${meanAndVarianceWidth}, ${meanAndVarianceHeight});
106
return ${glsl.texture2D}(MeanAndVariance, coords);
109
float process(int[4] indices) {
113
vec4 mean_and_variance = get_MeanAndVariance(mv);
114
float mean = mean_and_variance.r;
115
float variance = mean_and_variance.g;
119
float scale = _Scale(sb);
122
return scale * (_X(indices) - mean) / sqrt(variance + epsilon) + b;
126
output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked },
127
variables: [{ name: 'epsilon', type: 'float', data: epsilon }],
132
const createComputeOutputProgramInfoLoader = (
133
inferenceHandler: WebGLInferenceHandler,
136
meanAndVarianceShape: readonly number[],
137
): ProgramInfoLoader => {
138
const metadata = { ...computeOutputProgramMetadata, cacheHint: `${epsilon}` };
141
get: () => createComputeOutputProgramInfo(inferenceHandler, metadata, input, epsilon, meanAndVarianceShape),
145
const validateInputs = (inputs: Tensor[]): void => {
146
if (!inputs || inputs.length !== 3) {
147
throw new Error('InstanceNormalization requires 3 inputs.');
151
const scale = inputs[1];
156
if (X.dims.length < 3 || scale.dims.length !== 1 || B.dims.length !== 1) {
157
throw new Error('Invalid input shape.');
159
if (scale.dims[0] !== X.dims[1] || B.dims[0] !== X.dims[1]) {
160
throw new Error('Input shapes are mismatched.');
163
(X.type !== 'float32' && X.type !== 'float64') ||
164
(scale.type !== 'float32' && scale.type !== 'float64') ||
165
(B.type !== 'float32' && B.type !== 'float64')
167
throw new Error('Invalid input type.');
169
if (inputs[0].dims.length !== 4) {
170
throw new Error('Only support 4-D input shape.');