llvm-project

Форк
0
239 строк · 9.1 Кб
1
//===- quant.c - Test of Quant dialect C API ------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4
// Exceptions.
5
// See https://llvm.org/LICENSE.txt for license information.
6
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7
//
8
//===----------------------------------------------------------------------===//
9

10
// RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
11

12
#include "mlir-c/Dialect/Quant.h"
13
#include "mlir-c/BuiltinTypes.h"
14
#include "mlir-c/IR.h"
15

16
#include <assert.h>
17
#include <inttypes.h>
18
#include <stdio.h>
19
#include <stdlib.h>
20

21
// CHECK-LABEL: testTypeHierarchy
22
static void testTypeHierarchy(MlirContext ctx) {
23
  fprintf(stderr, "testTypeHierarchy\n");
24

25
  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
26
  MlirType any = mlirTypeParseGet(
27
      ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
28
  MlirType uniform =
29
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
30
                                "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
31
  MlirType perAxis = mlirTypeParseGet(
32
      ctx, mlirStringRefCreateFromCString(
33
               "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
34
  MlirType calibrated = mlirTypeParseGet(
35
      ctx,
36
      mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
37

38
  // The parser itself is checked in C++ dialect tests.
39
  assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType");
40
  assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType");
41
  assert(!mlirTypeIsNull(perAxis) &&
42
         "couldn't parse UniformQuantizedPerAxisType");
43
  assert(!mlirTypeIsNull(calibrated) &&
44
         "couldn't parse CalibratedQuantizedType");
45

46
  // CHECK: i8 isa QuantizedType: 0
47
  fprintf(stderr, "i8 isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(i8));
48
  // CHECK: any isa QuantizedType: 1
49
  fprintf(stderr, "any isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(any));
50
  // CHECK: uniform isa QuantizedType: 1
51
  fprintf(stderr, "uniform isa QuantizedType: %d\n",
52
          mlirTypeIsAQuantizedType(uniform));
53
  // CHECK: perAxis isa QuantizedType: 1
54
  fprintf(stderr, "perAxis isa QuantizedType: %d\n",
55
          mlirTypeIsAQuantizedType(perAxis));
56
  // CHECK: calibrated isa QuantizedType: 1
57
  fprintf(stderr, "calibrated isa QuantizedType: %d\n",
58
          mlirTypeIsAQuantizedType(calibrated));
59

60
  // CHECK: any isa AnyQuantizedType: 1
61
  fprintf(stderr, "any isa AnyQuantizedType: %d\n",
62
          mlirTypeIsAAnyQuantizedType(any));
63
  // CHECK: uniform isa UniformQuantizedType: 1
64
  fprintf(stderr, "uniform isa UniformQuantizedType: %d\n",
65
          mlirTypeIsAUniformQuantizedType(uniform));
66
  // CHECK: perAxis isa UniformQuantizedPerAxisType: 1
67
  fprintf(stderr, "perAxis isa UniformQuantizedPerAxisType: %d\n",
68
          mlirTypeIsAUniformQuantizedPerAxisType(perAxis));
69
  // CHECK: calibrated isa CalibratedQuantizedType: 1
70
  fprintf(stderr, "calibrated isa CalibratedQuantizedType: %d\n",
71
          mlirTypeIsACalibratedQuantizedType(calibrated));
72

73
  // CHECK: perAxis isa UniformQuantizedType: 0
74
  fprintf(stderr, "perAxis isa UniformQuantizedType: %d\n",
75
          mlirTypeIsAUniformQuantizedType(perAxis));
76
  // CHECK: uniform isa CalibratedQuantizedType: 0
77
  fprintf(stderr, "uniform isa CalibratedQuantizedType: %d\n",
78
          mlirTypeIsACalibratedQuantizedType(uniform));
79
  fprintf(stderr, "\n");
80
}
81

82
// CHECK-LABEL: testAnyQuantizedType
83
void testAnyQuantizedType(MlirContext ctx) {
84
  fprintf(stderr, "testAnyQuantizedType\n");
85

86
  MlirType anyParsed = mlirTypeParseGet(
87
      ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
88

89
  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
90
  MlirType f32 = mlirF32TypeGet(ctx);
91
  MlirType any =
92
      mlirAnyQuantizedTypeGet(mlirQuantizedTypeGetSignedFlag(), i8, f32, -8, 7);
93

94
  // CHECK: flags: 1
95
  fprintf(stderr, "flags: %u\n", mlirQuantizedTypeGetFlags(any));
96
  // CHECK: signed: 1
97
  fprintf(stderr, "signed: %u\n", mlirQuantizedTypeIsSigned(any));
98
  // CHECK: storage type: i8
99
  fprintf(stderr, "storage type: ");
100
  mlirTypeDump(mlirQuantizedTypeGetStorageType(any));
101
  fprintf(stderr, "\n");
102
  // CHECK: expressed type: f32
103
  fprintf(stderr, "expressed type: ");
104
  mlirTypeDump(mlirQuantizedTypeGetExpressedType(any));
105
  fprintf(stderr, "\n");
106
  // CHECK: storage min: -8
107
  fprintf(stderr, "storage min: %" PRId64 "\n",
108
          mlirQuantizedTypeGetStorageTypeMin(any));
109
  // CHECK: storage max: 7
110
  fprintf(stderr, "storage max: %" PRId64 "\n",
111
          mlirQuantizedTypeGetStorageTypeMax(any));
112
  // CHECK: storage width: 8
113
  fprintf(stderr, "storage width: %u\n",
114
          mlirQuantizedTypeGetStorageTypeIntegralWidth(any));
115
  // CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
116
  fprintf(stderr, "quantized element type: ");
117
  mlirTypeDump(mlirQuantizedTypeGetQuantizedElementType(any));
118
  fprintf(stderr, "\n");
119

120
  // CHECK: equal: 1
121
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(anyParsed, any));
122
  // CHECK: !quant.any<i8<-8:7>:f32>
123
  mlirTypeDump(any);
124
  fprintf(stderr, "\n\n");
125
}
126

127
// CHECK-LABEL: testUniformType
128
void testUniformType(MlirContext ctx) {
129
  fprintf(stderr, "testUniformType\n");
130

131
  MlirType uniformParsed =
132
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
133
                                "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
134

135
  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
136
  MlirType f32 = mlirF32TypeGet(ctx);
137
  MlirType uniform = mlirUniformQuantizedTypeGet(
138
      mlirQuantizedTypeGetSignedFlag(), i8, f32, 0.99872, 127, -8, 7);
139

140
  // CHECK: scale: 0.998720
141
  fprintf(stderr, "scale: %lf\n", mlirUniformQuantizedTypeGetScale(uniform));
142
  // CHECK: zero point: 127
143
  fprintf(stderr, "zero point: %" PRId64 "\n",
144
          mlirUniformQuantizedTypeGetZeroPoint(uniform));
145
  // CHECK: fixed point: 0
146
  fprintf(stderr, "fixed point: %d\n",
147
          mlirUniformQuantizedTypeIsFixedPoint(uniform));
148

149
  // CHECK: equal: 1
150
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(uniform, uniformParsed));
151
  // CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
152
  mlirTypeDump(uniform);
153
  fprintf(stderr, "\n\n");
154
}
155

156
// CHECK-LABEL: testUniformPerAxisType
157
void testUniformPerAxisType(MlirContext ctx) {
158
  fprintf(stderr, "testUniformPerAxisType\n");
159

160
  MlirType perAxisParsed = mlirTypeParseGet(
161
      ctx, mlirStringRefCreateFromCString(
162
               "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
163

164
  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
165
  MlirType f32 = mlirF32TypeGet(ctx);
166
  double scales[] = {200.0, 0.99872};
167
  int64_t zeroPoints[] = {0, 120};
168
  MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet(
169
      mlirQuantizedTypeGetSignedFlag(), i8, f32,
170
      /*nDims=*/2, scales, zeroPoints,
171
      /*quantizedDimension=*/1,
172
      mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
173
                                                   /*integralWidth=*/8),
174
      mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
175
                                                   /*integralWidth=*/8));
176

177
  // CHECK: num dims: 2
178
  fprintf(stderr, "num dims: %" PRIdPTR "\n",
179
          mlirUniformQuantizedPerAxisTypeGetNumDims(perAxis));
180
  // CHECK: scale 0: 200.000000
181
  fprintf(stderr, "scale 0: %lf\n",
182
          mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 0));
183
  // CHECK: scale 1: 0.998720
184
  fprintf(stderr, "scale 1: %lf\n",
185
          mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 1));
186
  // CHECK: zero point 0: 0
187
  fprintf(stderr, "zero point 0: %" PRId64 "\n",
188
          mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 0));
189
  // CHECK: zero point 1: 120
190
  fprintf(stderr, "zero point 1: %" PRId64 "\n",
191
          mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 1));
192
  // CHECK: quantized dim: 1
193
  fprintf(stderr, "quantized dim: %" PRId32 "\n",
194
          mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(perAxis));
195
  // CHECK: fixed point: 0
196
  fprintf(stderr, "fixed point: %d\n",
197
          mlirUniformQuantizedPerAxisTypeIsFixedPoint(perAxis));
198

199
  // CHECK: equal: 1
200
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(perAxis, perAxisParsed));
201
  // CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
202
  mlirTypeDump(perAxis);
203
  fprintf(stderr, "\n\n");
204
}
205

206
// CHECK-LABEL: testCalibratedType
207
void testCalibratedType(MlirContext ctx) {
208
  fprintf(stderr, "testCalibratedType\n");
209

210
  MlirType calibratedParsed = mlirTypeParseGet(
211
      ctx,
212
      mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
213

214
  MlirType f32 = mlirF32TypeGet(ctx);
215
  MlirType calibrated = mlirCalibratedQuantizedTypeGet(f32, -0.998, 1.2321);
216

217
  // CHECK: min: -0.998000
218
  fprintf(stderr, "min: %lf\n", mlirCalibratedQuantizedTypeGetMin(calibrated));
219
  // CHECK: max: 1.232100
220
  fprintf(stderr, "max: %lf\n", mlirCalibratedQuantizedTypeGetMax(calibrated));
221

222
  // CHECK: equal: 1
223
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(calibrated, calibratedParsed));
224
  // CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
225
  mlirTypeDump(calibrated);
226
  fprintf(stderr, "\n\n");
227
}
228

229
int main(void) {
230
  MlirContext ctx = mlirContextCreate();
231
  mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx);
232
  testTypeHierarchy(ctx);
233
  testAnyQuantizedType(ctx);
234
  testUniformType(ctx);
235
  testUniformPerAxisType(ctx);
236
  testCalibratedType(ctx);
237
  mlirContextDestroy(ctx);
238
  return EXIT_SUCCESS;
239
}
240

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.