onnxruntime

Форк
0
/
matmulnbits.ts 
275 строк · 11.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
9

10
import {
11
  createTensorShapeVariables,
12
  getMaxComponents,
13
  inputVariable,
14
  outputVariable,
15
  ShaderHelper,
16
  tensorTypeToWsglStorageType,
17
} from './common';
18

19
// TODO: support quantization bit widths other than 4.
/**
 * Attributes of the MatMulNBits operator as consumed by this kernel.
 */
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
  /** Reduction size; must equal the last dimension of input A. */
  k: number;
  /** Number of output columns, i.e. the leading dimension of the packed B input. */
  n: number;
  /** Bit width of each quantized B value; the shader in this file only unpacks 4-bit values. */
  bits: number;
  /** Number of consecutive K elements that share one scale (and one zero point). */
  blockSize: number;
  /** Requested accumulation accuracy; not read by the code in this file. */
  accuracyLevel: number;
}

/**
 * Validates the MatMulNBits inputs against the operator attributes.
 *
 * Checks, in order:
 * - there are 3 or 4 inputs (A, B, scales, optional zero points);
 * - `bits` is 4 and `blockSize` is a power of two not smaller than 16, the only configuration the
 *   WebGPU shader in this file unpacks correctly;
 * - A has rank >= 2 and its last dimension equals `k`;
 * - B has shape `[n, nBlocksPerCol, blobSize]`, with `nBlocksPerCol = ceil(k / blockSize)` and
 *   `blobSize = blockSize * bits / 8` bytes;
 * - scales has `n * nBlocksPerCol` elements;
 * - zero points, when present, has `n * ceil(nBlocksPerCol / 2)` elements (two 4-bit values per byte).
 *
 * @throws Error describing the first violated constraint.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => {
  if (inputs.length < 3 || inputs.length > 4) {
    throw new Error('MatMulNBits requires 3 or 4 inputs');
  }
  if (attributes.bits !== 4) {
    // The shader unpacks exactly eight 4-bit values per u32 word; any other width would silently
    // produce wrong results instead of failing.
    throw new Error(`MatMulNBits only supports bits = 4, got ${attributes.bits}`);
  }
  const { blockSize } = attributes;
  if (!Number.isInteger(blockSize) || blockSize < 16 || (blockSize & (blockSize - 1)) !== 0) {
    // Each block must fill whole u32 words of B and split evenly into A vectors.
    throw new Error(`MatMulNBits requires blockSize to be a power of 2 and not smaller than 16, got ${blockSize}`);
  }
  const a = inputs[0];
  const aRank = a.dims.length;
  if (aRank < 2) {
    // The program builder reads dims[rank - 2] as M, which does not exist for rank < 2.
    throw new Error(`MatMulNBits requires input A to have rank >= 2, got rank ${aRank}`);
  }
  if (a.dims[aRank - 1] !== attributes.k) {
    throw new Error('The last dim of input shape does not match the k value');
  }
  const nBlocksPerCol = Math.ceil(attributes.k / blockSize);
  const blobSize = (blockSize / 8) * attributes.bits;
  const b = inputs[1];
  if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) {
    throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize');
  }
  const scales = inputs[2];
  if (ShapeUtil.size(scales.dims) !== attributes.n * nBlocksPerCol) {
    throw new Error('scales input size error.');
  }
  if (inputs.length === 4) {
    const zeroPoints = inputs[3];
    // 4-bit zero points are packed two per byte, rounded up per column of B.
    const expectedZeroPointsSize = attributes.n * Math.ceil(nBlocksPerCol / 2);
    if (ShapeUtil.size(zeroPoints.dims) !== expectedZeroPointsSize) {
      throw new Error('zeroPoints input size error.');
    }
  }
};

/**
 * Builds the WebGPU program for MatMulNBits with 4-bit packed weights.
 *
 * Output element `[batch, m, n]` is the dot product of row `m` of A with row `n` of the
 * dequantized B. Each 4-bit value `q` is dequantized per block as `(q - zeroPoint) * scale`,
 * where zeroPoint is 8 when no zero-points input is given.
 *
 * Each workgroup of 64 threads produces `outputNumber` consecutive output vectors, each holding
 * `components` columns. Threads stride over the K blocks and accumulate partial sums into
 * `workgroup_shared`. After a barrier, the first `outputNumber` threads reduce them.
 *
 * @param inputs A, B (read as u32 words), scales and optional packed zero points. Callers are
 *   expected to validate them first (`matMulNBits` does).
 * @param attributes operator attributes; `k`, `n`, `bits` and `blockSize` are read here.
 * @returns program info whose single output has shape `[...batch, M, N]` and A's data type.
 */
export const createMatMulNBitsProgramInfo = (
  inputs: readonly TensorView[],
  attributes: MatMulNBitsAttributes,
): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const aRank = inputShape.length;
  const dimAOuter = inputShape[aRank - 2];
  const dimInner = attributes.k;
  const dimBOuter = attributes.n;
  const batchDims = inputShape.slice(0, aRank - 2);
  const batchSize = ShapeUtil.size(batchDims);
  const blobSize = inputs[1].dims[2];
  // B is read as u32 words: 4 bytes, i.e. eight 4-bit values, per word.
  const blobSizeInWords = blobSize / 4;
  const outputDataType = inputs[0].dataType;
  const aComponents = getMaxComponents(attributes.k);
  const bComponents = getMaxComponents(blobSizeInWords);
  const components = getMaxComponents(dimBOuter);
  const outputShape = batchDims.concat([dimAOuter, dimBOuter]);
  // One workgroup computes two adjacent output vectors when a row holds an even number of them.
  const outputNumber = dimAOuter > 1 && (dimBOuter / components) % 2 === 0 ? 2 : 1;
  const dispatchSize = ShapeUtil.size(outputShape) / components / outputNumber;

  const workgroupSize = 64;

  // Shapes are flattened to rank 3 ([batch, rows, cols]) and expressed in vector units.
  const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
  const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
  bShape.splice(-1, 1, blobSizeInWords / bComponents);
  const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];

  const programUniforms: ProgramUniform[] = [];
  programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
  programUniforms.push(...createTensorShapeVariables(bShape));
  programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
  if (inputs.length === 4) {
    programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
  }
  programUniforms.push(...createTensorShapeVariables(outputShapeTemp));

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const inputRank = inputShapeTemp.length;
    const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents);
    const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
    const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
    const inputVariables = [a, b, scales];
    const zeroPoints =
      inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
    if (zeroPoints) {
      inputVariables.push(zeroPoints);
    }
    const outputRank = outputShapeTemp.length;
    const output = outputVariable('output', inputs[0].dataType, outputRank, components);
    const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);

    // WGSL type holding the 8 values of one u32 word of B, laid out so that element j lines up
    // with one (possibly vectorized) element of A.
    const qDqDataType = (() => {
      switch (aComponents) {
        case 1:
          return `array<${dataType}, 8>`;
        case 2:
          return `mat4x2<${dataType}>`;
        case 4:
          return `mat2x4<${dataType}>`;
        default:
          throw new Error(`${aComponents}-component is not supported.`);
      }
    })();

    // Emits WGSL that loads the 8 A values matching one u32 word of B, dequantizes that word for
    // every output column handled by this workgroup, and accumulates the products into shared memory.
    const processOneWord = (): string => {
      let calcStr = `
            // Reuse the A data across all output columns.
            var input_offset = ${a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`)};
            var a_data: ${qDqDataType};
            for (var j: u32 = 0; j < ${8 / aComponents}; j++) {
              // Positions at or beyond K (the tail of a partially filled last block) keep their
              // zero value instead of reading the beginning of the next row of A.
              if (word_offset + j < uniforms.a_shape[2]) {
                a_data[j] = ${a.getByOffset('input_offset')};
              }
              input_offset++;
            }
          `;
      for (let c = 0; c < components * outputNumber; c++) {
        const zeroPoint = zeroPoints ? `zero_point${c}` : 'zero_point';
        const quantizedValues = Array.from(
          { length: 4 },
          (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
        ).join(', ');
        const dequantizedValues =
          aComponents === 1
            ? `${qDqDataType}(${Array.from(
                { length: 8 },
                (_, i) => `(b_quantized_values[${i}] - ${zeroPoint}) * scale${c}`,
              ).join(', ')})`
            : `(b_quantized_values - ${qDqDataType}(${Array(8).fill(zeroPoint).join(',')})) * scale${c}`;
        const dotProducts = Array.from({ length: 8 / aComponents }, (_, i) =>
          aComponents === 1
            ? `a_data[${i}] * b_dequantized_values[${i}]`
            : `dot(a_data[${i}], b_dequantized_values[${i}])`,
        ).join(' + ');
        const sharedIndex = `local_id.x * ${outputNumber} + ${Math.floor(c / components)}`;
        const componentSuffix = components > 1 ? `[${c % components}]` : '';
        calcStr += `
            b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
            b_value_lower = unpack4xU8(b_value & b_mask);
            b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
            b_quantized_values = ${qDqDataType}(${quantizedValues});
            b_dequantized_values = ${dequantizedValues};
            workgroup_shared[${sharedIndex}]${componentSuffix} += ${dotProducts};
          `;
      }
      return calcStr;
    };

    // Emits WGSL that loads the scale, and the zero point or the default of 8, of the current
    // block for every output column handled by this workgroup.
    const prepareScaleAndZeroPoint = (): string => {
      let calcStr = `
            var col_index = col * ${components};
            ${
              zeroPoints
                ? `
            let zero_point_bytes_per_col = (nBlocksPerCol + 1) / 2;
            var zero_point_byte_count: u32;
            var zero_point_word_index: u32;
            var zero_point_byte_offset: u32;
            let zero_point_nibble_offset: u32 = block & 0x1u;
            var zero_point_bits_offset: u32;
            var zero_point_word: u32;`
                : `
            // The default zero point is 8 for unsigned 4-bit quantization.
            let zero_point = ${dataType}(8);`
            }
            `;
      for (let c = 0; c < components * outputNumber; c++) {
        calcStr += `
            let scale${c} = ${scales.getByOffset(`col_index * nBlocksPerCol + block`)};
            ${
              zeroPoints
                ? `
            zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);
            zero_point_word_index = zero_point_byte_count >> 0x2u;
            zero_point_byte_offset = zero_point_byte_count & 0x3u;
            zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
            zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
            let zero_point${c} = ${dataType}((zero_point_word) & 0xFu);`
                : ''
            }
            col_index += 1;`;
      }
      return calcStr;
    };

    // Emits WGSL that loads the current B word(s) for every output column and declares the
    // scratch variables used by processOneWord.
    const prepareBData = (): string => {
      let calcStr = `col_index = col * ${components};`;
      for (let c = 0; c < components * outputNumber; c++) {
        calcStr += `
            let b${c}_data = ${b.getByIndices(`${b.type.indices}(col_index, block, word)`)};
            col_index += 1;`;
      }
      calcStr += `
            var b_value: u32;
            let b_mask: u32 = 0x0F0F0F0Fu;
            var b_value_lower: vec4<u32>;
            var b_value_upper: vec4<u32>;
            var b_quantized_values: ${qDqDataType};
            var b_dequantized_values: ${qDqDataType};`;
      return calcStr;
    };

    return `
        var<workgroup> workgroup_shared: array<${output.type.value}, ${outputNumber * workgroupSize}>;
        ${shaderHelper.declareVariables(...inputVariables, output)}
        ${shaderHelper.mainStart([workgroupSize, 1, 1])}
          let output_indices = ${output.offsetToIndices(`(global_idx / ${workgroupSize}) * ${outputNumber}`)};
          let col = output_indices[2];
          let row = output_indices[1];
          let batch = output_indices[0];
          let nBlocksPerCol = uniforms.b_shape[1];

          for (var block = local_id.x; block < nBlocksPerCol; block += ${workgroupSize}) {
            // Process one quantization block.
            var word_offset: u32 = block * ${attributes.blockSize / aComponents};
            ${prepareScaleAndZeroPoint()}
            for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
              ${prepareBData()}
              for (var i: u32 = 0; i < ${bComponents}; i++) {
                ${processOneWord()}
                word_offset += ${8 / aComponents};
              }
            }
          }
          workgroupBarrier();

          if (local_id.x < ${outputNumber}) {
            var output_value: ${output.type.value} = ${output.type.value}(0);
            var workgroup_shared_offset: u32 = local_id.x;
            for (var idx: u32 = 0u; idx < ${workgroupSize}u; idx++) {
              output_value += workgroup_shared[workgroup_shared_offset];
              workgroup_shared_offset += ${outputNumber};
            }
            ${output.setByIndices(`${output.type.indices}(batch, row, col + local_id.x)`, 'output_value')};
          }
        }`;
  };
  return {
    name: 'MatMulNBits',
    shaderCache: {
      hint: `${attributes.blockSize};${attributes.bits};${aComponents};${bComponents};${components};${outputNumber};${workgroupSize}`,
      inputDependencies: Array(inputs.length).fill('rank'),
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: outputDataType }],
      dispatchGroup: { x: dispatchSize },
      programUniforms,
    }),
    getShaderSource,
  };
};

/**
 * Operator entry point for MatMulNBits.
 *
 * Validates the context inputs against `attributes`, then builds the program with
 * `createMatMulNBitsProgramInfo` and submits it through `context.compute`.
 *
 * @throws Error if `validateInputs` rejects the inputs.
 */
export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
  const { inputs } = context;
  validateInputs(inputs, attributes);
  const programInfo = createMatMulNBitsProgramInfo(inputs, attributes);
  context.compute(programInfo);
};

/**
 * Converts a raw attribute record into `MatMulNBitsAttributes` via `createAttributeWithCacheKey`.
 *
 * The record is cast without runtime validation. Its keys are assumed to match the interface
 * fields (`k`, `n`, `accuracyLevel`, `bits`, `blockSize`); TODO confirm against the caller that
 * builds it.
 */
export const parseMatMulNBitsAttributes = (attributes: Record<string, unknown>): MatMulNBitsAttributes => {
  const fields = attributes as Omit<MatMulNBitsAttributes, keyof AttributeWithCacheKey>;
  return createAttributeWithCacheKey(fields);
};

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.