onnxruntime

Форк
0
153 строки · 6.5 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { env } from 'onnxruntime-common';
5

6
import { DataType } from '../../../wasm-common';
7
import { TensorView } from '../../tensor-view';
8
import { ShapeUtil } from '../../util';
9
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
10
import { ComputeContext, ProgramInfo } from '../types';
11

12
import { createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper } from './common';
13

14
/**
 * Attributes of the BatchNormalization operator as consumed by this WebGPU kernel.
 */
export interface BatchNormAttributes extends AttributeWithCacheKey {
  /** Layout of input X: channels on axis 1 ('NCHW') or on the last axis ('NHWC'). */
  readonly format: 'NHWC' | 'NCHW';
  /**
   * When true, scale/B/mean/var hold one value per channel. When false they also span the
   * non-batch spatial axes (the `spatial = false` form predates opset 9).
   */
  readonly spatial: boolean;
  /** Added to the variance before taking its inverse square root. */
  readonly epsilon: number;
  /** Running-statistics momentum; not read by the inference-only path in this module. */
  readonly momentum: number;
  /** Training mode is not implemented; `batchNorm` throws when this is set. */
  readonly trainingMode: boolean;
  /** Number of outputs requested by the node, copied from `ComputeContext.outputCount`. */
  readonly outputCount: number;
}
22

23
/**
 * Checks that BatchNormalization received exactly five inputs (X, scale, B, input_mean, input_var)
 * and that scale/B/mean/var all have the shape implied by X, `format` and `spatial`:
 * - X of rank <= 1: `[1]`
 * - NHWC, spatial: `[C]` (last axis of X)
 * - NHWC, non-spatial: `[C, D1, ..., Dn]` (last axis followed by the middle axes of X)
 * - NCHW, spatial: `[C]` (axis 1 of X)
 * - NCHW, non-spatial: X's shape without the batch axis
 *
 * @throws Error if the input count is wrong or any parameter tensor has a mismatching shape.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttributes): void => {
  if (!inputs || inputs.length !== 5) {
    throw new Error('BatchNormalization requires 5 inputs');
  }

  // Throws with `label` as prefix when `actual` differs from `expected` in rank or in any dimension.
  const assertSameShape = (actual: readonly number[], expected: readonly number[], label: string): void => {
    const expectedRank = expected.length;
    if (actual.length !== expectedRank) {
      throw new Error(`${label}: num dimensions != ${expectedRank}`);
    }
    for (let axis = 0; axis < expectedRank; axis++) {
      if (expected[axis] !== actual[axis]) {
        throw new Error(`${label}: dim[${axis}] do not match`);
      }
    }
  };

  const xDims = inputs[0].dims;
  let expectedShape: readonly number[];
  if (xDims.length <= 1) {
    expectedShape = [1];
  } else if (attributes.format === 'NHWC') {
    const channels = xDims.slice(-1);
    expectedShape = attributes.spatial ? channels : channels.concat(xDims.slice(1, xDims.length - 1));
  } else {
    expectedShape = xDims.slice(1, attributes.spatial ? 2 : undefined);
  }

  const parameterChecks: ReadonlyArray<[number, string]> = [
    [1, 'Invalid input scale'],
    [2, 'Invalid input B'],
    [3, 'Invalid input mean'],
    [4, 'Invalid input var'],
  ];
  for (const [inputIndex, label] of parameterChecks) {
    assertSameShape(inputs[inputIndex].dims, expectedShape, label);
  }
};
58

59
/**
 * Builds the WebGPU program that evaluates BatchNormalization in inference mode:
 * `y = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias`.
 *
 * When `spatial` is true, the innermost axis of X is vectorized with `getMaxComponents`, and the
 * parameters are indexed by the channel of each output element. When `spatial` is false, the parameters
 * cover the channel axis plus the remaining non-batch axes, and every element is processed as a scalar.
 *
 * @param inputs - `[X, scale, B, input_mean, input_var]`; all are assumed to share X's data type (see TODO below).
 * @param attributes - parsed BatchNormalization attributes.
 * @returns the program description that `batchNorm` hands to `ComputeContext.compute`.
 */
const createBatchNormInferenceProgramInfo = (
  inputs: readonly TensorView[],
  attributes: BatchNormAttributes,
): ProgramInfo => {
  const { epsilon, spatial, format } = attributes;
  const yShape = inputs[0].dims;
  const rank = yShape.length;
  // The channel axis is the innermost axis for NHWC. It is also innermost for a rank-2 NCHW input ([N, C]),
  // where both layouts coincide. In both cases, vectorizing the innermost axis packs several *channels*
  // into one vector, so the per-channel parameters must be read with the same number of components.
  // (Previously rank-2 NCHW read a single scalar parameter for a whole vector of channels.)
  const channelsLast = rank > 1 && (format === 'NHWC' || rank === 2);
  const components = spatial ? getMaxComponents(yShape[rank - 1]) : 1;
  const cComponents = channelsLast ? components : 1;
  const outputSize = ShapeUtil.size(yShape) / components;
  // Only support uniforms for opset version >= 9 (spatial = true).
  const useShapesUniforms = spatial;
  const shapeOrRank = useShapesUniforms ? rank : yShape;
  const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components);
  const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents);
  const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents);
  const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents);
  const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents);
  const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components);
  // TODO: support inputs with different data types. Currently all inputs must have the same data type;
  // otherwise, shader compilation fails.

  // Emits WGSL that declares `cOffset`: the offset into scale/bias/inputMean/inputVar (in units of
  // `cComponents`-wide elements) that applies to the current output element.
  const calcCOffset = (): string => {
    if (spatial) {
      let channelOffset: string;
      if (rank === 1) {
        // A 1-D input shares one set of parameters (validated as shape [1]).
        channelOffset = '0u';
      } else if (channelsLast) {
        channelOffset = `outputIndices[${rank - 1}] / ${components}`;
      } else {
        channelOffset = 'outputIndices[1]';
      }
      return `let cOffset = ${channelOffset};`;
    }
    if (format === 'NCHW') {
      // Parameters have Y's shape without the batch axis. Zeroing the batch index therefore makes
      // Y's offset equal to the parameter offset.
      return `
            ${y.indicesSet('outputIndices', '0', '0')}
            let cOffset = ${y.indicesToOffset('outputIndices')};`;
    }
    // NHWC: parameters have shape [C, D1, ..., Dn] while Y is [N, D1, ..., Dn, C].
    // The indices helpers are used so that rank-1 indices (a plain `u32`, which cannot be indexed
    // with `[i]`) are handled correctly.
    let code = `var cIndices = ${scale.type.indices}(0);
            ${scale.indicesSet('cIndices', 0, y.indicesGet('outputIndices', rank - 1))}`;
    // Copy D1 x ... x Dn.
    for (let i = 1; i < scale.rank; i++) {
      code += `${scale.indicesSet('cIndices', i, y.indicesGet('outputIndices', i))}`;
    }
    code += `let cOffset = ${scale.indicesToOffset('cIndices')};`;
    return code;
  };
  const getInferenceModeShaderSource = (helper: ShaderHelper) => `
  const epsilon = ${epsilon};
  ${helper.registerUniform('outputSize', 'u32').declareVariables(x, scale, bias, inputMean, inputVar, y)}
  ${helper.mainStart()}
  ${helper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
    var outputIndices = ${y.offsetToIndices(`global_idx * ${components}`)};
    ${calcCOffset()}
    let scale = ${scale.getByOffset('cOffset')};
    let bias = ${bias.getByOffset('cOffset')};
    let inputMean = ${inputMean.getByOffset('cOffset')};
    let inputVar = ${inputVar.getByOffset('cOffset')};
    let x = ${x.getByOffset('global_idx')};
    let value = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias;
    ${y.setByOffset('global_idx', 'value')}
  }`;
  return {
    name: 'BatchNormalization',
    shaderCache: {
      hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`,
      inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined,
    },
    getShaderSource: getInferenceModeShaderSource,
    getRunData: () => ({
      outputs: [{ dims: inputs[0].dims, dataType: inputs[0].dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms: useShapesUniforms
        ? [{ type: DataType.uint32, data: outputSize }, ...createTensorShapeVariables(yShape)]
        : [{ type: DataType.uint32, data: outputSize }],
    }),
  };
};
138

139
/**
 * Wraps a raw attribute record as `BatchNormAttributes` with a cache key attached.
 * No field-level validation is done here; the record is trusted to carry the expected fields.
 */
export const parseBatchNormAttributes = (attributes: Record<string, unknown>): BatchNormAttributes => {
  const rawAttributes = attributes as Omit<BatchNormAttributes, keyof AttributeWithCacheKey>;
  return createAttributeWithCacheKey(rawAttributes);
};
141

142
/**
 * Kernel entry point for BatchNormalization.
 *
 * Parses the attributes (adding the node's `outputCount`) and validates the inputs when
 * `env.webgpu.validateInputContent` is enabled. It then dispatches the inference-mode program.
 *
 * @throws Error when `trainingMode` is set, since only inference is implemented.
 */
export const batchNorm = (context: ComputeContext, attributes: Record<string, unknown>): void => {
  const { inputs, outputCount } = context;
  const batchNormAttributes = parseBatchNormAttributes({ ...attributes, outputCount });

  if (env.webgpu.validateInputContent) {
    validateInputs(inputs, batchNormAttributes);
  }

  // Guard clause: reject training mode before building any program.
  if (attributes.trainingMode) {
    throw new Error('BatchNormalization trainingMode is not supported yet.');
  }

  context.compute(createBatchNormInferenceProgramInfo(inputs, batchNormAttributes));
};
154

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.