onnxruntime

Форк
0
/
gather-block-quantized.ts 
196 строк · 8.8 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
9

10
import {
11
  createTensorShapeVariables,
12
  inputVariable,
13
  outputVariable,
14
  ShaderHelper,
15
  tensorTypeToWsglValueType,
16
  UniformsArrayType,
17
} from './common';
18

19
/**
 * Attributes of the GatherBlockQuantized operator as consumed by this WebGPU kernel.
 * Both axes are normalized against the rank of the `data` input before use.
 */
export interface GatherBlockQuantizedAttributes extends AttributeWithCacheKey {
  /** Axis of `data` that is replaced by the dimensions of the `indices` input. */
  gatherAxis: number;
  /** Axis of `data` along which consecutive elements share one scale / zero point. */
  quantizeAxis: number;
  /** Number of consecutive elements along `quantizeAxis` that share one scale / zero point. */
  blockSize: number;
}
24

25
/**
 * Validates the inputs of a GatherBlockQuantized node before the WebGPU program is built.
 *
 * Inputs are `[data, indices, scales, zeroPoints?]`.
 *
 * @param inputs - the 3 or 4 input tensors of the node.
 * @param attributes - parsed node attributes; `quantizeAxis` and `blockSize` are checked here.
 * @throws Error if the input count is not 3 or 4; if `scales` does not have the rank of `data`, one entry
 *   per `blockSize` elements along `quantizeAxis` (rounded up), and the extent of `data` on every other
 *   axis; or if the zero points do not have the data type of `data` and exactly the shape of `scales`.
 */
export const validateInputs = (inputs: readonly TensorView[], attributes: GatherBlockQuantizedAttributes): void => {
  if (inputs.length < 3 || inputs.length > 4) {
    throw new Error('GatherBlockQuantized requires 3 or 4 inputs.');
  }
  const quantizeAxis = ShapeUtil.normalizeAxis(attributes.quantizeAxis, inputs[0].dims.length);
  const blockSize = attributes.blockSize;
  const data = inputs[0];
  const scales = inputs[2];
  const zeroPoint = inputs.length === 4 ? inputs[3] : undefined;
  // Along quantizeAxis there is one scale per block of `blockSize` elements (the last block may be partial,
  // hence the ceil); every other axis must match the data tensor exactly.
  if (
    scales.dims.length !== data.dims.length ||
    !data.dims.every((d, i) => (i === quantizeAxis ? Math.ceil(d / blockSize) : d) === scales.dims[i])
  ) {
    throw new Error(
      'Scales must have the same rank as the input tensor and the dims should match except on quantizeAxis.',
    );
  }
  // TODO Uncomment the following check once the test case creation code is fixed to create data correctly aligned.
  // const indices = inputs[1];
  // const validIndex = (index: number) => index >= 0 && index < data.dims[attributes.gatherAxis];
  // if (indices.dataType === DataType.int32 && indices.getInt32Array().some((v) => !validIndex(v)) ||
  //     indices.dataType === DataType.int64 && indices.getBigInt64Array().some((v) => !validIndex(Number(v)))) {
  //   throw new Error('Indices must be within the bounds of the gatherAxis.');
  // }
  if (zeroPoint) {
    if (zeroPoint.dataType !== data.dataType) {
      throw new Error('Zero point must have the same data type as the input tensor.');
    }
    // Zero points are laid out exactly like the scales (one per block).
    if (zeroPoint.dims.length !== scales.dims.length || !zeroPoint.dims.every((d, i) => d === scales.dims[i])) {
      throw new Error(
        'Zero point must have the same rank as the input tensor and the dims should match except on quantizeAxis.',
      );
    }
  }
};
65

66
/**
 * Builds the WebGPU program for GatherBlockQuantized: gathers slices of a block-quantized 4-bit tensor
 * along `gatherAxis` and dequantizes them.
 *
 * The output shape is `data.dims` with the `gatherAxis` dimension replaced by `indices.dims`. Each shader
 * invocation produces one output element. It maps its output coordinate back to a `data` coordinate, reading
 * the gathered position from `indices` and adding the gathered dimension once if that position is negative.
 * It then unpacks the 4-bit element and writes `(q - zero_point) * scale`. The scale and zero point are taken
 * from the block that contains the coordinate along `quantizeAxis`.
 *
 * @param inputs - `[data, indices, scales, zeroPoints?]`; expected to have passed `validateInputs`.
 * @param attributes - parsed node attributes (axes are normalized here).
 * @returns the program description (shader source, uniforms and dispatch size).
 */
const createGatherBlockQuantizedProgramInfo = (
  inputs: readonly TensorView[],
  attributes: GatherBlockQuantizedAttributes,
): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const indicesShape = inputs[1].dims;
  const inputRank = inputShape.length;
  const indicesRank = indicesShape.length;
  const gatherAxis = ShapeUtil.normalizeAxis(attributes.gatherAxis, inputRank);
  const quantizeAxis = ShapeUtil.normalizeAxis(attributes.quantizeAxis, inputRank);
  const outputShape = inputShape.slice(0);
  outputShape.splice(gatherAxis, 1, ...indicesShape);
  const outputSize = ShapeUtil.size(outputShape);
  const outputType = inputs[2].dataType;
  const inputType = inputs[0].dataType;
  const isSigned = inputType === DataType.int4; // input data type is either int4 or uint4.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: quantizeAxis },
    { type: DataType.uint32, data: gatherAxis },
    { type: DataType.uint32, data: attributes.blockSize },
    ...createTensorShapeVariables(...inputs.map((input) => input.dims), outputShape),
  ];

  // Emits WGSL that extracts the 4-bit element at linear offset `offset` from the 32-bit word `packedWord`
  // (eight elements per word, two per byte, even offsets in the low nibble) into a `let` named `name`.
  // unpack4xI8 only sign-extends 8-bit lanes, so for int4 the masked nibble (0..15) is mapped to -8..7 here.
  const unpack4Bit = (name: string, packedWord: string, offset: string): string => `
        let ${name}_nibble_index = ${offset} % 8;
        let ${name}_packed_8bit = (${packedWord} >> (4 * (${name}_nibble_index % 2))) & 0x0f0f0f0f;
        let ${name}_raw = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(${name}_packed_8bit))[${name}_nibble_index / 2];
        let ${name} = ${isSigned ? `select(${name}_raw, ${name}_raw - 16, ${name}_raw >= 8)` : `${name}_raw`};`;

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length);
    const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length);
    const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
    const zeroPoint =
      inputs.length > 3 ? inputVariable('zeroPoint', inputs[3].dataType, inputs[3].dims.length) : undefined;
    const output = outputVariable('output', outputType, outputShape.length);
    const inputVariables = [data, indices, scales];
    if (zeroPoint) {
      inputVariables.push(zeroPoint);
    }
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'quantize_axis', type: 'u32' },
      { name: 'gather_axis', type: 'u32' },
      { name: 'block_size', type: 'u32' },
    ];
    const outputValueType = tensorTypeToWsglValueType(outputType);

    // Copy the output coordinates that belong to the indices tensor into `indices_indices`.
    // A scalar (rank-0) indices tensor has a single element, so its coordinate stays 0.
    let setIndicesIndices = '';
    if (indicesRank > 1) {
      setIndicesIndices = `
        for (var i: u32 = 0; i < ${indicesRank}; i++) {
          let index = ${output.indicesGet('output_indices', 'uniforms.gather_axis + i')};
          ${indices.indicesSet('indices_indices', 'i', 'index')};
        }`;
    } else if (indicesRank === 1) {
      setIndicesIndices = `indices_indices = ${output.indicesGet('output_indices', 'uniforms.gather_axis')};`;
    }

    // Zero points share the layout of the scales, so the scale coordinate addresses them directly.
    const zeroPointSource = zeroPoint
      ? `let zero_point_offset = ${zeroPoint.indicesToOffset('scale_indices')};
        ${unpack4Bit('zero_point', zeroPoint.getByOffset('zero_point_offset / 8'), 'zero_point_offset')}`
      : 'let zero_point = 0;';

    return `
        ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
        ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
        let output_indices = ${output.offsetToIndices('global_idx')};
        var indices_indices = ${indices.type.indices}(0);
        ${setIndicesIndices}
        var data_indices = ${data.type.indices}(0);
        for (var i: u32 = 0; i < uniforms.gather_axis; i++) {
          let index = ${output.indicesGet('output_indices', 'i')};
          ${data.indicesSet('data_indices', 'i', 'index')};
        }
        var index_from_indices = ${indices.getByIndices('indices_indices')};
        if (index_from_indices < 0) {
          index_from_indices += ${inputShape[gatherAxis]};
        }
        ${data.indicesSet('data_indices', 'uniforms.gather_axis', 'u32(index_from_indices)')};
        // Data axes after gather_axis map to output axes shifted by (indices rank - 1). Iterate over the
        // data rank, not the output rank, so no component beyond data_indices is touched.
        for (var i = uniforms.gather_axis + 1; i < ${inputRank}; i++) {
          let index = ${output.indicesGet('output_indices', `i + ${indicesRank} - 1`)};
          ${data.indicesSet('data_indices', 'i', 'index')};
        }
        let data_offset = ${data.indicesToOffset('data_indices')};
        ${unpack4Bit('quantized_data', data.getByOffset('data_offset / 8'), 'data_offset')}
        var scale_indices = data_indices;
        let quantize_axis_index = ${scales.indicesGet('data_indices', 'uniforms.quantize_axis')} / uniforms.block_size;
        ${scales.indicesSet('scale_indices', 'uniforms.quantize_axis', 'quantize_axis_index')};
        let scale = ${scales.getByIndices('scale_indices')};
        ${zeroPointSource}
        // Convert before subtracting: in u32 (uint4) the difference would wrap when quantized_data < zero_point.
        let dequantized_data = (${outputValueType}(quantized_data) - ${outputValueType}(zero_point)) * scale;
        ${output.setByOffset('global_idx', 'dequantized_data')};
    }`;
  };
  return {
    name: 'GatherBlockQuantized',
    shaderCache: {
      hint: `${attributes.cacheKey};${inputs
        .filter((_, i) => i !== 1)
        .map((input) => input.dims.join('_'))
        .join(';')}`,
      inputDependencies: Array.from({ length: inputs.length }, (_v, _i) => 'rank'),
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: outputType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
182

183
/**
 * Kernel entry point for GatherBlockQuantized.
 * Validates the node inputs, then hands the built program to the compute context.
 */
export const gatherBlockQuantized = (context: ComputeContext, attributes: GatherBlockQuantizedAttributes): void => {
  const { inputs } = context;
  validateInputs(inputs, attributes);
  const programInfo = createGatherBlockQuantizedProgramInfo(inputs, attributes);
  context.compute(programInfo);
};
188

189
/**
 * Converts the raw attribute record of a GatherBlockQuantized node into typed attributes with a cache key.
 *
 * The values are cast rather than validated; the record is expected (not checked) to carry numeric
 * `blockSize`, `gatherAxis` and `quantizeAxis` entries. Any other entries are ignored.
 */
export const parseGatherBlockQuantizedAttributes = (
  attributes: Record<string, unknown>,
): GatherBlockQuantizedAttributes => {
  const { blockSize, gatherAxis, quantizeAxis } = attributes as {
    blockSize: number;
    gatherAxis: number;
    quantizeAxis: number;
  };
  return createAttributeWithCacheKey({ blockSize, gatherAxis, quantizeAxis });
};
197

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.