onnxruntime

Форк
0
/
quantize-linear.ts 
237 строк · 11.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
9

10
import {
11
  createTensorShapeVariables,
12
  getMaxComponents,
13
  inputVariable,
14
  outputVariable,
15
  ShaderHelper,
16
  UniformsArrayType,
17
} from './common';
18

19
/**
 * Attributes of the DequantizeLinear operator.
 *
 * NOTE: the exported name keeps its historical spelling ("Liner") so existing importers keep working.
 */
export interface DequantizeLinerAttributes extends AttributeWithCacheKey {
  /** Axis used for per-axis / blocked dequantization; may be negative and is normalized against the input rank. */
  axis: number;
  /** Block length along `axis`; a value > 0 requests blocked dequantization. */
  blockSize: number;
}
23

24
/**
 * Validates the inputs and attributes of DequantizeLinear before a shader is generated.
 *
 * Inputs are x (inputs[0]), x_scale (inputs[1]) and an optional x_zero_point (inputs[2]).
 *
 * @param inputs the operator inputs.
 * @param attributes `axis` (may be negative) and `blockSize` (> 0 requests blocked dequantization).
 * @throws Error when the input count, data types, shapes, axis-dependent sizes or blockSize are inconsistent.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: DequantizeLinerAttributes): void => {
  if (inputs.length < 2 || inputs.length > 3) {
    throw new Error('DequantizeLinear requires 2 or 3 inputs.');
  }
  const x = inputs[0];
  const scale = inputs[1];
  const zeroPoint = inputs.length > 2 ? inputs[2] : undefined;

  // zero-point input type should be the same as input data type.
  if (zeroPoint && zeroPoint.dataType !== x.dataType) {
    throw new Error('x and x-zero-point must have the same data type.');
  }
  if (x.dataType === DataType.int32 && zeroPoint) {
    throw new Error('In the case of dequantizing int32 there is no zero point.');
  }
  const scaleRank = scale.dims.length;
  if (scaleRank !== 0 && scaleRank !== 1 && scaleRank !== x.dims.length) {
    throw new Error('scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor.');
  }
  // Scale and zero-point inputs must have the same shape. The dims are compared element-wise;
  // comparing the dims arrays with `===` would only compare references.
  if (zeroPoint) {
    if (scaleRank !== zeroPoint.dims.length) {
      throw new Error('scale and zero-point inputs must have the same rank.');
    }
    if (!scale.dims.every((d, i) => d === zeroPoint.dims[i])) {
      throw new Error('scale and zero-point inputs must have the same shape.');
    }
  }

  const isPerLayerScale = scaleRank === 0 || (scaleRank === 1 && scale.dims[0] === 1);
  if (attributes.blockSize > 0) {
    // Block quantization
    if (isPerLayerScale) {
      throw new Error('blockSize must be set only for block quantization.');
    }
    // Check the rank first so the per-dimension comparison below compares corresponding dimensions.
    if (scaleRank !== x.dims.length) {
      throw new Error('For block qunatization the scale input rank must be the same as the x rank.');
    }
    // `axis` may be negative; normalize it before comparing it with dimension indices.
    const axis = ShapeUtil.normalizeAxis(attributes.axis, x.dims.length);
    if (!scale.dims.every((d, i) => i === axis || d === x.dims[i])) {
      throw new Error('For block qunatization, scale input shape to match the input shape except for the axis');
    }
    const dI = x.dims[axis];
    const si = scale.dims[axis];
    if (attributes.blockSize < Math.ceil(dI / si) || attributes.blockSize > Math.ceil(dI / (si - 1) - 1)) {
      throw new Error('blockSize must be with in the range [ceil(dI / Si), ceil(dI / (Si - 1) - 1)].');
    }
  } else if (scaleRank === 1 && !isPerLayerScale) {
    // Per-axis quantization: the shader indexes the scale (and zero point) by the output coordinate
    // along `axis`, so the scale length must equal that dimension to avoid out-of-range reads.
    const axis = ShapeUtil.normalizeAxis(attributes.axis, x.dims.length);
    if (scale.dims[0] !== x.dims[axis]) {
      throw new Error('For per-axis quantization, the scale input length must match the input dimension along axis.');
    }
  }
};
76

77
/**
 * Builds the WebGPU program computing y = (x - x_zero_point) * x_scale for DequantizeLinear.
 *
 * Scale layouts:
 *  - per-layer (per-tensor): scale is a scalar or a 1-element 1-D tensor;
 *  - per-axis: scale is a 1-D tensor indexed by the output coordinate along `axis`;
 *  - blocked: scale has the same rank as x, and each scale element covers `blockSize`
 *    consecutive elements of x along `axis`.
 *
 * int8/uint8 data (x and zero point) is bound as u32 words holding 4 elements each and unpacked in the
 * shader. Unpacked lanes are i32 so that `x - zero_point` never wraps around for uint8 data.
 *
 * @param inputs x, x_scale and optional x_zero_point; `dequantizeLinear` validates them with `validateInputs` first.
 * @param attributes `axis` (may be negative) and `blockSize`.
 * @returns the program description. The output has x's shape and x_scale's data type.
 */
const createDequantizeLinearProgramInfo = (
  inputs: readonly TensorView[],
  attributes: DequantizeLinerAttributes,
): ProgramInfo => {
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
  const inputType = inputs[0].dataType;
  const isSigned = inputType === DataType.int8;
  const outputShape = inputs[0].dims; // output shape is same as the input shape
  const dataType = inputs[1].dataType; // output type is same as the scale input type
  const outputSize = ShapeUtil.size(outputShape);
  // int8/uint8 data is bound as u32 words with 4 elements per word.
  const isPacked = inputType === DataType.int8 || inputType === DataType.uint8;
  const inputShape = isPacked ? [Math.ceil(ShapeUtil.size(inputs[0].dims) / 4)] : inputs[0].dims;
  const scaleShape = inputs[1].dims;
  const zeroPointInput = inputs.length > 2 ? inputs[2] : undefined;
  const zeroPointShape = zeroPointInput
    ? isPacked
      ? [Math.ceil(ShapeUtil.size(zeroPointInput.dims) / 4)]
      : zeroPointInput.dims
    : undefined;

  // Scale is a scalar (or 1-element tensor) for per-tensor/per-layer quantization, a 1-D tensor for
  // per-axis quantization, or a tensor with the same rank as the input for blocked quantization.
  // A 1-D scale combined with blockSize > 0 (which requires a 1-D input) is blocked, not per-axis.
  const isBlockSizeSet = attributes.blockSize > 0;
  const perLayerQuantization = scaleShape.length === 0 || (scaleShape.length === 1 && scaleShape[0] === 1);
  const perAxisQuantization = !perLayerQuantization && scaleShape.length === 1 && !isBlockSizeSet;
  const quantizationMode = perLayerQuantization ? 'layer' : perAxisQuantization ? 'axis' : 'block';

  const maxComponents = getMaxComponents(outputSize);
  // Only per-layer quantization is vectorized. A packed word holds exactly 4 elements, so packed input
  // can only be vectorized with 4 components (then each invocation consumes one whole word).
  const useComponents = perLayerQuantization && (!isPacked || maxComponents === 4);
  const components = useComponents ? maxComponents : 1;
  const inputComponent = useComponents && !isPacked ? maxComponents : 1;
  const input = inputVariable('input', isPacked ? DataType.uint32 : inputType, inputShape.length, inputComponent);
  const scale = inputVariable('scale', dataType, scaleShape.length);
  const zeroPoint = zeroPointShape
    ? inputVariable('zero_point', isPacked ? DataType.uint32 : inputType, zeroPointShape.length)
    : undefined;
  const output = outputVariable('output', dataType, outputShape.length, components);
  const inputVariables = zeroPoint ? [input, scale, zeroPoint] : [input, scale];
  const inputShapes = zeroPointShape ? [inputShape, scaleShape, zeroPointShape] : [inputShape, scaleShape];

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize / components },
    { type: DataType.uint32, data: axis },
    { type: DataType.uint32, data: attributes.blockSize },
    ...createTensorShapeVariables(...inputShapes, outputShape),
  ];

  // WGSL expression unpacking a u32 word into 4 i32 lanes. uint8 lanes are widened to i32 so that
  // `x - zero_point` uses signed arithmetic instead of wrapping around as u32.
  const unpack = (word: string): string => (isSigned ? `unpack4xI8(${word})` : `vec4<i32>(unpack4xU8(${word}))`);

  // Declares `x_value`.
  const getInputValue = (): string => {
    if (!isPacked) {
      return `let x_value = ${input.getByOffset('global_idx')};`;
    }
    if (components === 4) {
      // Invocation `global_idx` writes outputs [4 * global_idx, 4 * global_idx + 3], i.e. word `global_idx`.
      return `let x_value = ${unpack(input.getByOffset('global_idx'))};`;
    }
    return `
          let x_word = ${input.getByOffset('global_idx / 4')};
          let x_vec = ${unpack('x_word')};
          let x_value = x_vec[global_idx % 4];`;
  };

  // Declares `scale_value` (and `scale_indices` for blocked quantization).
  const getScaleValue = (): string => {
    if (perLayerQuantization) {
      // scale input is a scalar
      return `let scale_value = ${scale.getByOffset('0')};`;
    }
    if (perAxisQuantization) {
      // scale input is a 1D tensor
      return `
          let scale_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
          let scale_value = ${scale.getByOffset('scale_index')};`;
    }
    // Block quantization. Scale input rank is same as input/output rank.
    return `
          var scale_indices: ${scale.type.indices} = output_indices;
          let index = ${scale.indicesGet('scale_indices', 'uniforms.axis')} / uniforms.block_size;
          ${scale.indicesSet('scale_indices', 'uniforms.axis', 'index')};
          let scale_value = ${scale.getByIndices('scale_indices')};`;
  };

  // Declares `zero_point_value` (0 when there is no zero-point input).
  const getZeroPointValue = (): string => {
    if (!zeroPoint) {
      return `let zero_point_value = ${isPacked ? 'i32' : input.type.value}(0);`;
    }
    if (perLayerQuantization) {
      // zero-point input is a scalar
      return isPacked
        ? `
          let zero_point_vec = ${unpack(zeroPoint.getByOffset('0'))};
          let zero_point_value = zero_point_vec[0];`
        : `let zero_point_value = ${zeroPoint.getByOffset('0')};`;
    }
    if (perAxisQuantization) {
      // zero-point input is a 1D tensor
      return isPacked
        ? `
          let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
          let zero_point_vec = ${unpack(zeroPoint.getByOffset('zero_point_index / 4'))};
          let zero_point_value = zero_point_vec[zero_point_index % 4];`
        : `
          let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
          let zero_point_value = ${zeroPoint.getByOffset('zero_point_index')};`;
    }
    // Blocked quantization: the zero point has the same shape as the scale, so `scale_indices` is reused.
    return isPacked
      ? `
          let zero_point_offset = ${scale.indicesToOffset('scale_indices')};
          let zero_point_vec = ${unpack(zeroPoint.getByOffset('zero_point_offset / 4'))};
          let zero_point_value = zero_point_vec[zero_point_offset % 4];`
      : `let zero_point_value = ${zeroPoint.getByIndices('scale_indices')};`;
  };

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'axis', type: 'u32' },
      { name: 'block_size', type: 'u32' },
    ];
    return `
      ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
      ${shaderHelper.mainStart()}
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
          let output_indices = ${output.offsetToIndices('global_idx')};

          // Set input x
          ${getInputValue()}

          // Set scale input
          ${getScaleValue()}

          // Set zero-point input
          ${getZeroPointValue()}

          // Compute and write output
          ${output.setByOffset('global_idx', `${output.type.value}(x_value - zero_point_value) * scale_value`)};
      }`;
  };

  return {
    name: 'DequantizeLinear',
    shaderCache: {
      // The shader also depends on the quantization mode (a rank-1 scale can be per-layer or per-axis)
      // and on the vector width (derived from the output size). NOTE(review): 'rank' input dependencies
      // are assumed to key only on data type and rank, so both are added to the hint explicitly.
      hint: `${attributes.cacheKey};${quantizationMode};${components}`,
      inputDependencies: zeroPoint ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
    },
    getShaderSource,
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / components / 64), y: 1, z: 1 },
      programUniforms,
    }),
  };
};
230

231
/**
 * Entry point of the DequantizeLinear WebGPU kernel: validates the inputs, then runs the program.
 *
 * @param context compute context providing the operator inputs.
 * @param attributes parsed operator attributes.
 * @throws Error if the inputs or attributes fail validation.
 */
export const dequantizeLinear = (context: ComputeContext, attributes: DequantizeLinerAttributes): void => {
  const operatorInputs = context.inputs;
  validateInputs(operatorInputs, attributes);
  const programInfo = createDequantizeLinearProgramInfo(operatorInputs, attributes);
  context.compute(programInfo);
};
235

236
/**
 * Converts the raw attribute record into `DequantizeLinerAttributes` with a cache key.
 * The `axis`, `blockSize` key order is kept so the generated cache key stays the same.
 */
export const parseDequantizeLinearAttributes = (attributes: Record<string, unknown>): DequantizeLinerAttributes => {
  const axis = attributes.axis as number;
  const blockSize = attributes.blockSize as number;
  return createAttributeWithCacheKey({ axis, blockSize });
};
238

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.