1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
11
createTensorShapeVariables,
19
/**
 * Attributes of the DequantizeLinear kernel (the "Liner" spelling is part of the existing exported name).
 *
 * `axis` selects the quantization axis; the program builder normalizes it against the input rank with
 * `ShapeUtil.normalizeAxis`. A positive `blockSize` requests block quantization. The inherited `cacheKey`
 * is used as the program cache hint.
 */
export interface DequantizeLinerAttributes extends AttributeWithCacheKey {
24
/**
 * Validates the inputs and attributes of a DequantizeLinear node before its program is built.
 *
 * @param inputs `x`, `x_scale` and an optional `x_zero_point`.
 * @param attributes the parsed `axis` / `blockSize` attributes.
 * @throws Error when the input count, data types, scale/zero-point shapes or `blockSize` are inconsistent.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: DequantizeLinerAttributes): void => {
  if (inputs.length < 2 || inputs.length > 3) {
    throw new Error('DequantizeLinear requires 2 or 3 inputs.');
  }
  // Zero-point input type must match the input data type.
  if (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType) {
    throw new Error('x and x-zero-point must have the same data type.');
  }
  if (inputs[0].dataType === DataType.int32 && inputs.length > 2) {
    throw new Error('In the case of dequantizing int32 there is no zero point.');
  }
  const xDims = inputs[0].dims;
  const scaleDims = inputs[1].dims;
  if (scaleDims.length !== 0 && scaleDims.length !== 1 && scaleDims.length !== xDims.length) {
    throw new Error('scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor.');
  }
  // Scale and zero-point inputs must have the same shape. Compare element-wise: `dims === dims` would only
  // compare array references and never detect a shape mismatch.
  if (inputs.length > 2) {
    const zeroPointDims = inputs[2].dims;
    if (scaleDims.length !== zeroPointDims.length) {
      throw new Error('scale and zero-point inputs must have the same rank.');
    }
    if (!scaleDims.every((d, i) => d === zeroPointDims[i])) {
      throw new Error('scale and zero-point inputs must have the same shape.');
    }
  }
  // Validate blockSize
  if (attributes.blockSize > 0) {
    // Block quantization needs a non-scalar scale.
    if (scaleDims.length === 0 || (scaleDims.length === 1 && scaleDims[0] === 1)) {
      throw new Error('blockSize must be set only for block quantization.');
    }
    // Normalize the axis the same way the program builder does, so negative axes index the right dimension.
    const axis = ShapeUtil.normalizeAxis(attributes.axis, xDims.length);
    if (!scaleDims.every((d, i) => i === axis || d === xDims[i])) {
      throw new Error('For block qunatization, scale input shape to match the input shape except for the axis');
    }
    // Scale input rank should be same as the input rank
    if (scaleDims.length !== xDims.length) {
      throw new Error('For block qunatization the scale input rank must be the same as the x rank.');
    }
    const dI = xDims[axis];
    const si = scaleDims[axis];
    if (attributes.blockSize < Math.ceil(dI / si) || attributes.blockSize > Math.ceil(dI / (si - 1) - 1)) {
      throw new Error('blockSize must be with in the range [ceil(dI / Si), ceil(dI / (Si - 1) - 1)].');
    }
  }
};
77
/**
 * Builds the WebGPU program that dequantizes `x` as `output = (x - zero_point) * scale`.
 *
 * The quantization mode is inferred from the scale shape: a scalar or `[1]` is per-layer, any other 1-D
 * shape is per-axis, and anything else is treated as blocked along `axis` with `attributes.blockSize`.
 * int8/uint8 inputs and zero points are read as u32 words that each hold 4 values. The output has the
 * input's shape and the scale's data type. Without a zero-point input, a zero point of 0 is used.
 *
 * @param inputs x, scale and an optional zero point (presumably already checked by validateInputs — TODO confirm).
 * @param attributes the DequantizeLinear attributes; `axis` is normalized against the input rank.
 * @returns the program description: name, cache hint, input dependencies, output shape/type and dispatch size.
 */
const createDequantizeLinearProgramInfo = (
78
inputs: readonly TensorView[],
79
attributes: DequantizeLinerAttributes,
81
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
82
const inputType = inputs[0].dataType;
83
// Only int8 is treated as signed; this picks unpack4xI8 rather than unpack4xU8 in the shader.
const isSigned = inputType === DataType.int8;
84
const outputShape = inputs[0].dims; // output shape is same as the input shape
85
const dataType = inputs[1].dataType; // output type is same as the scale input type
86
const outputSize = ShapeUtil.size(outputShape);
87
// 8-bit inputs are bound as u32 words holding 4 values each, hence the ceil(size / 4) shapes below.
const isPacked = inputType === DataType.int8 || inputType === DataType.uint8;
88
const inputShape = isPacked ? [Math.ceil(ShapeUtil.size(inputs[0].dims) / 4)] : inputs[0].dims;
89
const scaleShape = inputs[1].dims;
90
const zeroPointInput = inputs.length > 2 ? inputs[2] : undefined;
91
const zeroPointShape = zeroPointInput
93
? [Math.ceil(ShapeUtil.size(zeroPointInput.dims) / 4)]
96
// The scale input is a scalar for per-tensor/per-layer quantization, a 1-D tensor for per-axis quantization,
97
// or a tensor with the same rank as the input for blocked quantization.
98
const perLayerQuantization = scaleShape.length === 0 || (scaleShape.length === 1 && scaleShape[0] === 1);
99
const perAxisQuantization = perLayerQuantization === false && scaleShape.length === 1;
100
// Blocked quantization is the remaining case (neither per-layer nor per-axis). The shader generators handle it
101
// in their final branches, so no flag is computed for it:
// const blockQuantization = perLayerQuantization === false && perAxisQuantization === false;
102
const maxComponents = getMaxComponents(outputSize);
103
// Vectorized output is used only for per-layer quantization (a single scale/zero point). For packed inputs it is
// used only when the vector width is 4, matching the 4 values unpacked from each u32 word.
const useComponents = perLayerQuantization && (!isPacked || maxComponents === 4);
104
const components = useComponents ? maxComponents : 1;
105
// Packed inputs stay scalar u32 words; unpacking produces the 4-wide vector in the shader.
const inputComponent = useComponents && !isPacked ? maxComponents : 1;
106
const input = inputVariable('input', isPacked ? DataType.uint32 : inputType, inputShape.length, inputComponent);
107
const scale = inputVariable('scale', dataType, scaleShape.length);
108
const zeroPoint = zeroPointInput
109
? inputVariable('zero_point', isPacked ? DataType.uint32 : inputType, zeroPointShape!.length)
111
const output = outputVariable('output', dataType, outputShape.length, components);
112
const inputVariables = [input, scale];
114
inputVariables.push(zeroPoint);
116
const inputShapes = [inputShape, scaleShape];
117
if (zeroPointInput) {
118
inputShapes.push(zeroPointShape!);
120
// Listed in the same order as the `uniforms` declared in getShaderSource: output_size, axis, block_size.
const programUniforms: ProgramUniform[] = [
121
{ type: DataType.uint32, data: outputSize / components },
122
{ type: DataType.uint32, data: axis },
123
{ type: DataType.uint32, data: attributes.blockSize },
124
...createTensorShapeVariables(...inputShapes, outputShape),
126
const getShaderSource = (shaderHelper: ShaderHelper) => {
127
const uniforms: UniformsArrayType = [
128
{ name: 'output_size', type: 'u32' },
129
{ name: 'axis', type: 'u32' },
130
{ name: 'block_size', type: 'u32' },
133
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
134
${shaderHelper.mainStart()}
135
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
136
let output_indices = ${output.offsetToIndices('global_idx')};
142
let input = ${input.getByOffset('global_idx / 4')};
143
let x_vec = ${isSigned ? 'unpack4xI8(input)' : 'unpack4xU8(input)'};
144
let x_value = ${components === 1 ? 'x_vec[global_idx % 4]' : 'x_vec'};`;
146
return `let x_value = ${input.getByOffset('global_idx')};`;
152
if (perLayerQuantization) {
153
// scale input is a scalar ()
154
return `let scale_value= ${scale.getByOffset('0')}`;
155
} else if (perAxisQuantization) {
156
// scale input is a 1D tensor
158
let scale_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
159
let scale_value= ${scale.getByOffset('scale_index')};`;
161
// Block quantization. Scale input rank is same as input/output rank.
163
var scale_indices: ${scale.type.indices} = output_indices;
164
let index = ${scale.indicesGet('scale_indices', 'uniforms.axis')} / uniforms.block_size;
165
${scale.indicesSet('scale_indices', 'uniforms.axis', 'index')};
166
let scale_value= ${scale.getByIndices('scale_indices')};`;
170
// Set zero-point input
173
if (perLayerQuantization) {
174
// zero-point input is a scalar
177
let zero_point_input = ${zeroPoint.getByOffset('0')};
178
let zero_point_vec = ${isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'};
179
let zero_point_value= zero_point_vec[0]`;
181
return `let zero_point_value = ${zeroPoint.getByOffset('0')}`;
183
} else if (perAxisQuantization) {
184
// zero-point input is a 1D tensor
187
let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
188
let zero_point_input = ${zeroPoint.getByOffset('zero_point_index / 4')};
189
let zero_point_vec = ${isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'};
190
let zero_point_value = zero_point_vec[zero_point_index % 4]`;
193
let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
194
let zero_point_value = ${zeroPoint.getByOffset('zero_point_index')};`;
197
// BlockedQuantization. The zero-point input shape is same as the input shape except along axis.
200
let zero_point_offset = ${scale.indicesToOffset('scale_indices')};
201
let zero_point_input = ${zeroPoint.getByOffset('zero_point_offset / 4')};
202
let zero_point_vec = ${isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'};
203
let zero_point_value = zero_point_vec[zero_point_offset % 4];`;
205
return `let zero_point_value = ${zeroPoint.getByIndices('scale_indices')};`;
209
return `let zero_point_value = ${isPacked ? (isSigned ? 'i32' : 'u32') : input.type.value}(0);`;
212
// Compute and write output
213
${output.setByOffset('global_idx', `${output.type.value}(x_value - zero_point_value) * scale_value`)};
217
name: 'DequantizeLinear',
219
hint: attributes.cacheKey,
220
inputDependencies: zeroPoint ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
224
outputs: [{ dims: outputShape, dataType }],
225
dispatchGroup: { x: Math.ceil(outputSize / components / 64), y: 1, z: 1 },
231
/**
 * Kernel entry point for DequantizeLinear. It checks the inputs against the attributes, then builds the
 * program and submits it on the given compute context.
 *
 * @param context compute context that supplies the kernel inputs and runs the program.
 * @param attributes the parsed DequantizeLinear attributes.
 * @throws Error when validateInputs rejects the inputs.
 */
export const dequantizeLinear = (context: ComputeContext, attributes: DequantizeLinerAttributes): void => {
  validateInputs(context.inputs, attributes);
  const programInfo = createDequantizeLinearProgramInfo(context.inputs, attributes);
  context.compute(programInfo);
};
236
/**
 * Converts the raw attribute record into typed DequantizeLinear attributes with a cache key.
 * Only `axis` and `blockSize` are read; both are taken as numbers without further checks.
 *
 * @param attributes raw attribute record (keys `axis` and `blockSize`).
 * @returns the attributes wrapped by createAttributeWithCacheKey.
 */
export const parseDequantizeLinearAttributes = (attributes: Record<string, unknown>): DequantizeLinerAttributes => {
  const axis = attributes.axis as number;
  const blockSize = attributes.blockSize as number;
  return createAttributeWithCacheKey({ axis, blockSize });
};