1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
10
createTensorShapeVariables,
17
tensorTypeToWsglStorageType,
21
/**
 * Attributes consumed by the instance-normalization kernels in this file.
 *
 * NOTE(review): only part of this declaration is visible in this excerpt. Code further down
 * reads `attributes.epsilon`, so an `epsilon` member is presumably declared here too — confirm.
 */
export interface InstanceNormAttributes {
23
// Layout selector: `instanceNorm` takes the dedicated NHWC path when this equals 'NHWC'.
format: 'NHWC' | 'NCHW';
26
/**
 * Builds the program description (name, shader cache key, outputs, dispatch size and shader
 * source generator) for the 'InstanceNormalization' kernel. `instanceNorm` passes the result to
 * `context.compute` when the format is not 'NHWC'.
 *
 * The generated shader works in three steps per norm: a workgroup reduction for the mean, a
 * second workgroup reduction for the sum of squared deviations, and a final loop that writes
 * `x * channelScale + channelShift`.
 *
 * NOTE(review): several lines of this function are missing from the excerpt (for example the
 * definition of `axis` and the closers of the parameter list and literals). The comments below
 * describe only what is visible.
 *
 * @param inputs The visible code reads `inputs[0]` as x, `inputs[1]` as scale and `inputs[2]` as bias.
 * @param attributes `attributes.epsilon` is interpolated into the shader text.
 */
const createInstanceNormProgramInfo = (
27
inputs: readonly TensorView[],
28
attributes: InstanceNormAttributes,
30
const xShape = inputs[0].dims;
31
const outputShape = xShape;
33
// normCount becomes the dispatch size (see dispatchGroup below); normSize is uploaded as a
// uniform and is the divisor for both the mean and the variance in the shader.
// NOTE(review): `axis` is defined on a line that is not visible in this excerpt.
const normCount = ShapeUtil.sizeToDimension(xShape, axis);
34
const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
35
// `components` is forwarded to the x/output variable helpers and selects scalar vs. vecN
// accumulation below; normPackedSize is the per-norm loop bound expressed in those units.
const components = getMaxComponents(normSize);
36
const normPackedSize = normSize / components;
37
// x and output are addressed in the shader through a 3-D view: x.get('batch', 'channel', 'h').
const inputShape = [xShape[0], xShape[1], normPackedSize];
38
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
39
// Order matches the `uniforms` declaration inside getShaderSource: normSize, then normPackedSize.
const programUniforms: ProgramUniform[] = [
40
{ type: DataType.uint32, data: normSize },
41
{ type: DataType.uint32, data: normPackedSize },
43
// Shape data for the two variables declared with a rank (x and output); presumably this is what
// backs `uniforms.x_shape` in the shader — confirm against createTensorShapeVariables.
programUniforms.push(...createTensorShapeVariables(inputShape, inputShape));
45
const getShaderSource = (shaderHelper: ShaderHelper) => {
46
const x = inputVariable('x', inputs[0].dataType, inputShape.length, components);
47
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
48
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
49
const output = outputVariable('output', inputs[0].dataType, inputShape.length, components);
50
const variables = [x, scale, bias, output];
51
const dataType = x.type.value;
52
// Accumulation type: the shader casts every loaded x value to this f32 scalar/vector before
// summing, and only casts back to `dataType` when producing the final value.
const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`;
53
// Workgroup size; also the length of the shared reduction array and the stride of the
// per-invocation loops in the shader.
const workgroupSize = 64;
55
const uniforms: UniformsArrayType = [
56
{ name: 'normSize', type: 'u32' },
57
{ name: 'normPackedSize', type: 'u32' },
60
var<workgroup> meanShared : f32;
61
var<workgroup> squaredNormShared : f32;
62
var<workgroup> workgroupShared : array<${f32Type}, ${workgroupSize}>;
63
const workgroupSize = ${workgroupSize}u;
64
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
65
${shaderHelper.mainStart(workgroupSize)}
66
let norm = global_idx / workgroupSize;
67
let batch = norm / uniforms.x_shape[1];
68
let channel = norm % uniforms.x_shape[1];
69
let localIndex = local_id.x;
71
// initialize workgroup memory
72
var initial = ${f32Type}(0);
73
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
74
initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')});
76
workgroupShared[localIndex] = initial;
79
// Calculate the mean of current channel data.
80
for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) {
81
if (localIndex < currSize) {
82
workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
86
if (localIndex == 0) {
87
meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize);
91
// reinitialize workgroup memory.
92
initial = ${f32Type}(0);
93
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
94
let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared);
95
initial = initial + deviation * deviation;
97
workgroupShared[localIndex] = initial;
100
// Calculate the sum of square of deviation of current channel data.
101
for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) {
102
if (localIndex < currSize) {
103
workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
107
if (localIndex == 0) {
108
squaredNormShared = ${sumVector('workgroupShared[0]', components)};
112
let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon}));
113
let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
114
let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
115
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
116
let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${
119
${output.set('batch', 'channel', 'h', 'value')};
124
...{ name: 'InstanceNormalization' },
125
// epsilon is interpolated into the shader text above, so it has to be part of the cache hint.
// TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
126
shaderCache: { hint: `${attributes.epsilon};${components}`, inputDependencies },
128
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
129
// normCount workgroups are dispatched; the shader derives its norm index as
// global_idx / workgroupSize.
dispatchGroup: { x: normCount },
137
// NOTE(review): the opening line of this definition is not visible in this excerpt. The call
// site in createInstanceNormNHWCProgramInfo names it `computeMean` and passes
// (context, input, scale, bias, N, H, C, epsilon).
//
// Runs two GPU programs through `context.compute`:
//   1. 'InstanceNormComputeMean' writes, per invocation, a partial sum and partial sum of squares
//      of the input into an intermediate tensor.
//   2. 'InstanceNormComputeChannelScaleShift' combines those partials with scale/bias into a
//      (channelScale, channelShift) pair per output element.
// The result of the second `context.compute` call is returned to the caller.
context: ComputeContext,
146
const components = getMaxComponents(c);
148
// we will store channel scale and channel shift in [2, components] matrix
149
// or in vec2 when components == 1
150
const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`;
151
const sumCastType = components === 1 ? 'f32' : `vec${components}f`;
152
const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`;
153
const unitsOfWork = (n * c) / components;
154
// Number of consecutive positions along H handled by one invocation in the first pass; the
// shader uses it as `local_id.x * uniforms.wg_size`.
// NOTE(review): `WG` is defined outside the visible excerpt.
const wgSize = Math.ceil(h / WG);
156
const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
157
// Order matches the hand-written WGSL struct in the first shader:
// Uniforms {wg_size, H, C, image_size}.
const meanProgramUniforms: ProgramUniform[] = [
158
{ type: DataType.uint32, data: wgSize },
159
{ type: DataType.uint32, data: h },
160
{ type: DataType.uint32, data: Math.floor(c / components) },
161
{ type: DataType.uint32, data: Math.floor((h * c) / components) },
164
const getMeanShaderSource = (shaderHelper: ShaderHelper) => {
165
const inputHelper = inputVariable('input', input.dataType, input.dims, components);
167
${shaderHelper.declareVariables(inputHelper)}
168
@group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;
169
struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32};
170
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
172
${shaderHelper.mainStart(WG)}
173
let currentImageNumber = global_idx / ${WG} / uniforms.C;
174
let currentChannelNumber = (global_idx / ${WG}) % uniforms.C;
175
let wgOffset = local_id.x * uniforms.wg_size;
176
if (wgOffset >= uniforms.H) {
179
let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H);
181
let offset = currentImageNumber * uniforms.image_size + currentChannelNumber;
182
var sum = ${fillVector('f32', components)};
183
var squaredSum = ${fillVector('f32', components)};
184
for (var i: u32 = wgOffset; i < wgMax; i++) {
185
let value = ${sumCastType}(input[offset + i * uniforms.C]);
187
squaredSum += value * value;
189
output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
193
// First pass: per-invocation partial sums.
// NOTE(review): the meaning of `outputs: [-1]` is not visible here; presumably it requests an
// intermediate tensor rather than a kernel output — confirm against ComputeContext.compute.
const meanValues = context.compute(
195
name: 'InstanceNormComputeMean',
196
shaderCache: { hint: `${components}`, inputDependencies: meanInputDependencies },
198
outputs: [{ dims: [n, c, WG, 2], dataType: DataType.float }],
199
dispatchGroup: { x: (n * c) / components },
200
programUniforms: meanProgramUniforms,
202
getShaderSource: getMeanShaderSource,
204
{ inputs: [input], outputs: [-1] },
207
// Order matches the hand-written WGSL struct in the second shader:
// Uniforms {units_of_work, H, C, image_size}.
const programUniforms: ProgramUniform[] = [
208
{ type: DataType.uint32, data: unitsOfWork },
209
{ type: DataType.uint32, data: h },
210
{ type: DataType.uint32, data: Math.floor(c / components) },
211
{ type: DataType.uint32, data: Math.floor((WG * c) / components) },
213
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];
214
const getShaderSource = (shaderHelper: ShaderHelper) => {
215
// These helpers are used only to obtain the storage element types for the manually declared
// `scale` and `bias` bindings in the shader text below.
const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
216
const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);
218
@group(0) @binding(0) var<storage, read> input : array<${outputType}>;
219
@group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
220
@group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
221
@group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;
222
struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32};
223
@group(0) @binding(4) var<uniform> uniforms: Uniforms;
225
${shaderHelper.mainStart()}
226
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')}
227
let currentImageNumber = global_idx / uniforms.C;
228
let currentChannelNumber = global_idx % uniforms.C;
230
let offset = currentImageNumber * uniforms.image_size;
231
var sum = ${fillVector('f32', components)};
232
var squaredSum = ${fillVector('f32', components)};
233
for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) {
234
let value = input[offset + i + currentChannelNumber * ${WG}];
236
squaredSum += value[1];
238
sum = sum / f32(uniforms.H);
239
squaredSum = squaredSum / f32(uniforms.H);
240
let invStdDev = inverseSqrt(squaredSum - sum * sum + f32(${epsilon}));
241
let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
242
let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;
244
output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
247
// Second pass: consumes the partial sums from the first pass together with scale and bias.
return context.compute(
249
name: 'InstanceNormComputeChannelScaleShift',
250
// TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
251
shaderCache: { hint: `${components};${epsilon}`, inputDependencies },
253
outputs: [{ dims: [n, c, 2], dataType: DataType.float }],
254
dispatchGroup: { x: Math.ceil(unitsOfWork / 64 /* workgroup size */) },
259
{ inputs: [meanValues, scale, bias], outputs: [-1] },
263
/**
 * NHWC path of instance normalization. It first calls `computeMean` to obtain a per-channel
 * (scale, shift) tensor, then sets up the 'InstanceNormalizationNHWC' program, whose shader
 * applies `fma(input, scale, shift)` to each output element.
 *
 * Unlike `createInstanceNormProgramInfo`, this function receives the compute context itself and
 * is called directly by `instanceNorm` rather than having its result passed to `context.compute`.
 *
 * NOTE(review): several lines are missing from this excerpt, including the definition of `N`
 * used in the `computeMean` call.
 *
 * @param context Compute context forwarded to `computeMean`.
 * @param inputs The visible code reads `inputs[0]` as x, `inputs[1]` as scale and `inputs[2]` as bias.
 * @param attributes Only `attributes.epsilon` is read in the visible code.
 */
const createInstanceNormNHWCProgramInfo = (
264
context: ComputeContext,
265
inputs: readonly TensorView[],
266
attributes: InstanceNormAttributes,
268
const xShape = inputs[0].dims;
269
const outputShape = xShape;
271
// C is the last dimension; H is the size of everything from dimension 1 onwards divided by C.
const C = xShape[xShape.length - 1];
272
const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
273
const components = getMaxComponents(C);
274
const outputSize = ShapeUtil.size(outputShape) / components;
275
// Order matches the hand-written WGSL struct in the shader: Uniforms {H, C}.
const programUniforms: ProgramUniform[] = [
276
{ type: DataType.uint32, data: H },
277
{ type: DataType.uint32, data: Math.floor(C / components) },
279
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
280
// first compute mean
281
const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);
282
const getShaderSource = (shaderHelper: ShaderHelper) => {
283
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
284
// Element type of the channelScaleShift buffer; same vec2f / mat2xNf selection rule as the
// `outputType` expression in computeMean.
const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`;
285
// Type used to cast the scale and shift values to the input's element type before the fma.
const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;
287
const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
288
const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);
291
@group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
292
@group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
293
@group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;
294
struct Uniforms {H: u32, C : u32};
295
@group(0) @binding(3) var<uniform> uniforms: Uniforms;
297
${shaderHelper.mainStart()}
298
let currentImageNumber = global_idx / (uniforms.C * uniforms.H);
299
let currentChannelNumber = global_idx % uniforms.C;
301
let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber;
302
let scale = scaleInput[scaleOffset];
303
output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
308
name: 'InstanceNormalizationNHWC',
309
shaderCache: { hint: `${components}`, inputDependencies },
311
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
312
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
317
{ inputs: [inputs[0], channelScaleShift] },
321
/**
 * Operator entry point for instance normalization.
 *
 * When `attributes.format` is 'NHWC', the work is delegated to
 * `createInstanceNormNHWCProgramInfo`, which is handed the context directly. The other visible
 * call submits the program built by `createInstanceNormProgramInfo` through `context.compute`.
 *
 * NOTE(review): the line joining the two branches is not visible in this excerpt; the second
 * call is presumably the else branch — confirm.
 *
 * @param context Supplies `context.inputs` and `context.compute`.
 * @param attributes Operator attributes; `format` selects the code path.
 */
export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
322
if (attributes.format === 'NHWC') {
323
createInstanceNormNHWCProgramInfo(context, context.inputs, attributes);
325
context.compute(createInstanceNormProgramInfo(context.inputs, attributes));