onnxruntime

Форк
0
/
instance-norm.ts 
327 строк · 13.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
8

9
import {
10
  createTensorShapeVariables,
11
  fillVector,
12
  getMaxComponents,
13
  inputVariable,
14
  outputVariable,
15
  ShaderHelper,
16
  sumVector,
17
  tensorTypeToWsglStorageType,
18
  UniformsArrayType,
19
} from './common';
20

21
/**
 * Attributes of the InstanceNormalization operator as consumed by the WebGPU kernels in this file.
 */
export interface InstanceNormAttributes {
  /**
   * Added to the variance before taking the inverse square root. It is baked into the generated
   * WGSL source (and therefore into the shader cache hint) rather than passed as a uniform.
   */
  epsilon: number;
  /** Input layout: 'NHWC' selects the channels-last multi-pass path, 'NCHW' the single-pass path. */
  format: 'NHWC' | 'NCHW';
}
25

26
/**
 * Builds the single-pass InstanceNormalization program used for the NCHW layout.
 *
 * One workgroup of 64 invocations is dispatched per (batch, channel) pair. Inside a workgroup:
 *   1. each invocation sums a strided subset of the channel's spatial values, the partial sums
 *      are tree-reduced in workgroup memory and divided by `normSize` to get the mean;
 *   2. the same is done for the squared deviations from that mean;
 *   3. every value is rewritten as `x * channelScale + channelShift`, where
 *      `channelScale = scale[c] / sqrt(variance + epsilon)` and
 *      `channelShift = bias[c] - mean * channelScale`.
 *
 * @param inputs `[x, scale, bias]`. x is viewed as `[x.dims[0], x.dims[1], normSize]`, where
 *   `normSize` is the product of all dims from axis 2 onward; scale/bias are indexed by channel.
 * @param attributes `epsilon` is inlined into the shader source (and the cache hint).
 * @returns program info to be passed to `context.compute`; the output has x's dims and dtype.
 */
const createInstanceNormProgramInfo = (
  inputs: readonly TensorView[],
  attributes: InstanceNormAttributes,
): ProgramInfo => {
  const xShape = inputs[0].dims;
  const outputShape = xShape;
  const axis = 2;
  // One workgroup per (batch, channel) pair.
  const normCount = ShapeUtil.sizeToDimension(xShape, axis);
  // Number of spatial elements normalized together.
  const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
  const components = getMaxComponents(normSize);
  const normPackedSize = normSize / components;
  // All spatial dims collapsed into one (possibly vectorized) axis.
  const inputShape = [xShape[0], xShape[1], normPackedSize];
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
  // NOTE: kept in the same order as the uniforms registered and the variables declared in the
  // shader below (normSize, normPackedSize, then the x/output shape variables).
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: normSize },
    { type: DataType.uint32, data: normPackedSize },
  ];
  programUniforms.push(...createTensorShapeVariables(inputShape, inputShape));

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const x = inputVariable('x', inputs[0].dataType, inputShape.length, components);
    const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
    const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
    const output = outputVariable('output', inputs[0].dataType, inputShape.length, components);
    const variables = [x, scale, bias, output];
    const dataType = x.type.value;
    // Accumulation always happens in f32, vectorized like the input when components > 1.
    const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`;
    const workgroupSize = 64;

    const uniforms: UniformsArrayType = [
      { name: 'normSize', type: 'u32' },
      { name: 'normPackedSize', type: 'u32' },
    ];

    // Tree reduction of workgroupShared into workgroupShared[0]. Relies on workgroupSize being a
    // power of two; the barrier sits in uniform control flow because currSize is uniform.
    const reduceWorkgroupShared = `
    for (var currSize = workgroupSize >> 1;  currSize > 0; currSize = currSize >> 1) {
      if (localIndex < currSize) {
        workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
      }
      workgroupBarrier();
    }`;

    return `
  var<workgroup> meanShared : f32;
  var<workgroup> squaredNormShared : f32;
  var<workgroup> workgroupShared : array<${f32Type}, ${workgroupSize}>;
  const workgroupSize = ${workgroupSize}u;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
  ${shaderHelper.mainStart(workgroupSize)}
    let norm = global_idx / workgroupSize;
    let batch = norm / uniforms.x_shape[1];
    let channel = norm % uniforms.x_shape[1];
    let localIndex = local_id.x;

    // initialize workgroup memory
    var initial = ${f32Type}(0);
    for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
      initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')});
    }
    workgroupShared[localIndex] = initial;
    workgroupBarrier();

    // Calculate the mean of current channel data.
    ${reduceWorkgroupShared}
    if (localIndex == 0) {
      meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize);
    }
    workgroupBarrier();

    // reinitialize workgroup memory.
    initial = ${f32Type}(0);
    for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
      let deviation =  ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared);
      initial = initial + deviation * deviation;
    }
    workgroupShared[localIndex] = initial;
    workgroupBarrier();

    // Calculate the sum of square of deviation of current channel data.
    ${reduceWorkgroupShared}
    if (localIndex == 0) {
      squaredNormShared = ${sumVector('workgroupShared[0]', components)};
    }
    workgroupBarrier();

    let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon}));
    let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
    let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
    for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
      let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${f32Type}(channelShift));
      ${output.set('batch', 'channel', 'h', 'value')};
    }
  }`;
  };
  return {
    name: 'InstanceNormalization',
    // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
    shaderCache: { hint: `${attributes.epsilon};${components}`, inputDependencies },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
      dispatchGroup: { x: normCount },
      programUniforms,
    }),
    getShaderSource,
  };
};
135

136
/**
 * Computes per-(image, channel) normalization coefficients for the channels-last path.
 *
 * Despite its name, the returned tensor does not hold the mean. It is a float tensor of dims
 * `[n, c, 2]` holding, for every image and channel:
 * - `channelScale = scale / sqrt(variance + epsilon)`, and
 * - `channelShift = bias - mean * channelScale`.
 *
 * Each pair is packed as `vec2f`, or as `mat2x{components}f` when channels are vectorized.
 *
 * Pass 1 ("InstanceNormComputeMean") dispatches one workgroup of `WG` invocations per
 * (image, channel group). Invocation k accumulates the sum and the sum of squares over the
 * spatial slice `[k * ceil(h / WG), min((k + 1) * ceil(h / WG), h))`. It writes one partial
 * result, so each channel group gets `WG` partial results.
 *
 * Pass 2 ("InstanceNormComputeChannelScaleShift") runs one invocation per (image, channel
 * group). It folds the partial results into mean and variance, then computes the scale/shift.
 *
 * @param context compute context used to dispatch both passes.
 * @param input channels-last input; addressed as `[n, h, c]` with the channel dim innermost.
 * @param scale per-channel scale, indexed by channel (vectorized like the input channels).
 * @param bias per-channel bias, indexed by channel (vectorized like the input channels).
 * @param n number of images.
 * @param h number of spatial positions per image.
 * @param c number of channels.
 * @param epsilon added to the variance; inlined into the pass-2 shader and its cache hint.
 * @returns the `[n, c, 2]` float scale/shift tensor described above.
 */
const computeMean = (
  context: ComputeContext,
  input: TensorView,
  scale: TensorView,
  bias: TensorView,
  n: number,
  h: number,
  c: number,
  epsilon: number,
) => {
  const components = getMaxComponents(c);
  const WG = 64;
  // we will store channel scale and channel shift in [2, components] matrix
  // or in vec2 when components == 1
  const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`;
  const sumCastType = components === 1 ? 'f32' : `vec${components}f`;
  const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`;
  const unitsOfWork = (n * c) / components;
  // Spatial positions handled by each pass-1 invocation.
  const wgSize = Math.ceil(h / WG);

  const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
  const meanProgramUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: wgSize },
    { type: DataType.uint32, data: h },
    { type: DataType.uint32, data: Math.floor(c / components) },
    { type: DataType.uint32, data: Math.floor((h * c) / components) },
  ];

  const getMeanShaderSource = (shaderHelper: ShaderHelper) => {
    const inputHelper = inputVariable('input', input.dataType, input.dims, components);
    // Every invocation must write its partial-result slot. Pass 2 reads min(WG, H) slots per
    // channel group, and when H > WG only ceil(H / wg_size) slices are non-empty. For example,
    // H = 100 gives wg_size = 2, so only 50 slices have data. The remaining invocations therefore
    // skip the loop (wgMax <= wgOffset) and store zeros instead of leaving the slot uninitialized.
    return `
  ${shaderHelper.declareVariables(inputHelper)}
  @group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32};
  @group(0) @binding(2) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart(WG)}
    let currentImageNumber = global_idx / ${WG} / uniforms.C;
    let currentChannelNumber = (global_idx / ${WG}) % uniforms.C;
    let wgOffset = local_id.x * uniforms.wg_size;
    let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H);

    let offset = currentImageNumber * uniforms.image_size + currentChannelNumber;
    var sum = ${fillVector('f32', components)};
    var squaredSum = ${fillVector('f32', components)};
    for (var i: u32 = wgOffset; i < wgMax; i++) {
        let value = ${sumCastType}(input[offset + i * uniforms.C]);
        sum += value;
        squaredSum += value * value;
    }
    output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
  }`;
  };

  const meanValues = context.compute(
    {
      name: 'InstanceNormComputeMean',
      shaderCache: { hint: `${components}`, inputDependencies: meanInputDependencies },
      getRunData: () => ({
        outputs: [{ dims: [n, c, WG, 2], dataType: DataType.float }],
        dispatchGroup: { x: (n * c) / components },
        programUniforms: meanProgramUniforms,
      }),
      getShaderSource: getMeanShaderSource,
    },
    { inputs: [input], outputs: [-1] },
  )[0];

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: unitsOfWork },
    { type: DataType.uint32, data: h },
    { type: DataType.uint32, data: Math.floor(c / components) },
    { type: DataType.uint32, data: Math.floor((WG * c) / components) },
  ];
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
    const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);
    // The variance is computed as E[x^2] - E[x]^2, which can come out slightly negative through
    // float cancellation. It is clamped at zero so that inverseSqrt never receives a negative
    // argument when epsilon is small.
    return `
  @group(0) @binding(0) var<storage, read> input : array<${outputType}>;
  @group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
  @group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
  @group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32};
  @group(0) @binding(4) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')}
    let currentImageNumber = global_idx / uniforms.C;
    let currentChannelNumber = global_idx % uniforms.C;

    let offset = currentImageNumber * uniforms.image_size;
    var sum = ${fillVector('f32', components)};
    var squaredSum = ${fillVector('f32', components)};
    for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) {
        let value = input[offset + i + currentChannelNumber * ${WG}];
        sum += value[0];
        squaredSum += value[1];
    }
    sum = sum / f32(uniforms.H);
    squaredSum = squaredSum / f32(uniforms.H);
    let variance = max(squaredSum - sum * sum, ${sumCastType}(0));
    let invStdDev = inverseSqrt(variance + f32(${epsilon}));
    let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
    let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;

    output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
  }`;
  };
  return context.compute(
    {
      name: 'InstanceNormComputeChannelScaleShift',
      // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
      shaderCache: { hint: `${components};${epsilon}`, inputDependencies },
      getRunData: () => ({
        outputs: [{ dims: [n, c, 2], dataType: DataType.float }],
        dispatchGroup: { x: Math.ceil(unitsOfWork / 64 /* workgroup size */) },
        programUniforms,
      }),
      getShaderSource,
    },
    { inputs: [meanValues, scale, bias], outputs: [-1] },
  )[0];
};
262

263
/**
 * Runs InstanceNormalization for channels-last (NHWC) input.
 *
 * Despite its name, this does not return a ProgramInfo; it dispatches all work itself:
 *   1. `computeMean` produces a per-(image, channel) `[scale, shift]` pair;
 *   2. "InstanceNormalizationNHWC" writes `output = fma(input, scale, shift)` element-wise.
 *
 * @param context compute context used to dispatch the programs.
 * @param inputs `[x, scale, bias]`. x's first dim is N, its last dim is the channel count C,
 *   and all dims in between are flattened into H spatial positions.
 * @param attributes only `epsilon` is read here (forwarded to `computeMean`).
 */
const createInstanceNormNHWCProgramInfo = (
  context: ComputeContext,
  inputs: readonly TensorView[],
  attributes: InstanceNormAttributes,
) => {
  const xShape = inputs[0].dims;
  const outputShape = xShape;
  const N = xShape[0];
  const C = xShape[xShape.length - 1];
  const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
  const components = getMaxComponents(C);
  // Number of (possibly vectorized) output elements, i.e. invocations that have real work.
  const outputSize = ShapeUtil.size(outputShape) / components;
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: H },
    { type: DataType.uint32, data: Math.floor(C / components) },
    { type: DataType.uint32, data: outputSize },
  ];
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
  // first compute mean
  const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
    const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`;
    const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;

    const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
    const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);

    // The dispatch is rounded up to whole workgroups of 64. The guard stops the tail invocations
    // from indexing input/output past their end; WGSL leaves such out-of-bounds stores
    // implementation-defined.
    return `
  @group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
  @group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
  @group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;
  struct Uniforms {H: u32, C : u32, output_size : u32};
  @group(0) @binding(3) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let currentImageNumber = global_idx / (uniforms.C * uniforms.H);
    let currentChannelNumber = global_idx % uniforms.C;

    let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber;
    let scale = scaleInput[scaleOffset];
    output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
  }`;
  };
  context.compute(
    {
      name: 'InstanceNormalizationNHWC',
      shaderCache: { hint: `${components}`, inputDependencies },
      getRunData: () => ({
        outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
        dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
        programUniforms,
      }),
      getShaderSource,
    },
    { inputs: [inputs[0], channelScaleShift] },
  );
};
320

321
/**
 * WebGPU kernel entry point for InstanceNormalization.
 *
 * Any layout other than 'NHWC' is handled by the single-pass program, dispatched here.
 * 'NHWC' input is delegated to the channels-last path, which dispatches its own programs.
 */
export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
  const isChannelsLast = attributes.format === 'NHWC';
  if (!isChannelsLast) {
    context.compute(createInstanceNormProgramInfo(context.inputs, attributes));
    return;
  }
  createInstanceNormNHWCProgramInfo(context, context.inputs, attributes);
};
328

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.