llvm-project
239 строк · 9.1 Кб
//===- quant.c - Test of Quant dialect C API ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: mlir-capi-quant-test 2>&1 | FileCheck %s

#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"

#include <assert.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>

21// CHECK-LABEL: testTypeHierarchy
22static void testTypeHierarchy(MlirContext ctx) {23fprintf(stderr, "testTypeHierarchy\n");24
25MlirType i8 = mlirIntegerTypeGet(ctx, 8);26MlirType any = mlirTypeParseGet(27ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));28MlirType uniform =29mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(30"!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));31MlirType perAxis = mlirTypeParseGet(32ctx, mlirStringRefCreateFromCString(33"!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));34MlirType calibrated = mlirTypeParseGet(35ctx,36mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));37
38// The parser itself is checked in C++ dialect tests.39assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType");40assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType");41assert(!mlirTypeIsNull(perAxis) &&42"couldn't parse UniformQuantizedPerAxisType");43assert(!mlirTypeIsNull(calibrated) &&44"couldn't parse CalibratedQuantizedType");45
46// CHECK: i8 isa QuantizedType: 047fprintf(stderr, "i8 isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(i8));48// CHECK: any isa QuantizedType: 149fprintf(stderr, "any isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(any));50// CHECK: uniform isa QuantizedType: 151fprintf(stderr, "uniform isa QuantizedType: %d\n",52mlirTypeIsAQuantizedType(uniform));53// CHECK: perAxis isa QuantizedType: 154fprintf(stderr, "perAxis isa QuantizedType: %d\n",55mlirTypeIsAQuantizedType(perAxis));56// CHECK: calibrated isa QuantizedType: 157fprintf(stderr, "calibrated isa QuantizedType: %d\n",58mlirTypeIsAQuantizedType(calibrated));59
60// CHECK: any isa AnyQuantizedType: 161fprintf(stderr, "any isa AnyQuantizedType: %d\n",62mlirTypeIsAAnyQuantizedType(any));63// CHECK: uniform isa UniformQuantizedType: 164fprintf(stderr, "uniform isa UniformQuantizedType: %d\n",65mlirTypeIsAUniformQuantizedType(uniform));66// CHECK: perAxis isa UniformQuantizedPerAxisType: 167fprintf(stderr, "perAxis isa UniformQuantizedPerAxisType: %d\n",68mlirTypeIsAUniformQuantizedPerAxisType(perAxis));69// CHECK: calibrated isa CalibratedQuantizedType: 170fprintf(stderr, "calibrated isa CalibratedQuantizedType: %d\n",71mlirTypeIsACalibratedQuantizedType(calibrated));72
73// CHECK: perAxis isa UniformQuantizedType: 074fprintf(stderr, "perAxis isa UniformQuantizedType: %d\n",75mlirTypeIsAUniformQuantizedType(perAxis));76// CHECK: uniform isa CalibratedQuantizedType: 077fprintf(stderr, "uniform isa CalibratedQuantizedType: %d\n",78mlirTypeIsACalibratedQuantizedType(uniform));79fprintf(stderr, "\n");80}
81
82// CHECK-LABEL: testAnyQuantizedType
83void testAnyQuantizedType(MlirContext ctx) {84fprintf(stderr, "testAnyQuantizedType\n");85
86MlirType anyParsed = mlirTypeParseGet(87ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));88
89MlirType i8 = mlirIntegerTypeGet(ctx, 8);90MlirType f32 = mlirF32TypeGet(ctx);91MlirType any =92mlirAnyQuantizedTypeGet(mlirQuantizedTypeGetSignedFlag(), i8, f32, -8, 7);93
94// CHECK: flags: 195fprintf(stderr, "flags: %u\n", mlirQuantizedTypeGetFlags(any));96// CHECK: signed: 197fprintf(stderr, "signed: %u\n", mlirQuantizedTypeIsSigned(any));98// CHECK: storage type: i899fprintf(stderr, "storage type: ");100mlirTypeDump(mlirQuantizedTypeGetStorageType(any));101fprintf(stderr, "\n");102// CHECK: expressed type: f32103fprintf(stderr, "expressed type: ");104mlirTypeDump(mlirQuantizedTypeGetExpressedType(any));105fprintf(stderr, "\n");106// CHECK: storage min: -8107fprintf(stderr, "storage min: %" PRId64 "\n",108mlirQuantizedTypeGetStorageTypeMin(any));109// CHECK: storage max: 7110fprintf(stderr, "storage max: %" PRId64 "\n",111mlirQuantizedTypeGetStorageTypeMax(any));112// CHECK: storage width: 8113fprintf(stderr, "storage width: %u\n",114mlirQuantizedTypeGetStorageTypeIntegralWidth(any));115// CHECK: quantized element type: !quant.any<i8<-8:7>:f32>116fprintf(stderr, "quantized element type: ");117mlirTypeDump(mlirQuantizedTypeGetQuantizedElementType(any));118fprintf(stderr, "\n");119
120// CHECK: equal: 1121fprintf(stderr, "equal: %d\n", mlirTypeEqual(anyParsed, any));122// CHECK: !quant.any<i8<-8:7>:f32>123mlirTypeDump(any);124fprintf(stderr, "\n\n");125}
126
127// CHECK-LABEL: testUniformType
128void testUniformType(MlirContext ctx) {129fprintf(stderr, "testUniformType\n");130
131MlirType uniformParsed =132mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(133"!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));134
135MlirType i8 = mlirIntegerTypeGet(ctx, 8);136MlirType f32 = mlirF32TypeGet(ctx);137MlirType uniform = mlirUniformQuantizedTypeGet(138mlirQuantizedTypeGetSignedFlag(), i8, f32, 0.99872, 127, -8, 7);139
140// CHECK: scale: 0.998720141fprintf(stderr, "scale: %lf\n", mlirUniformQuantizedTypeGetScale(uniform));142// CHECK: zero point: 127143fprintf(stderr, "zero point: %" PRId64 "\n",144mlirUniformQuantizedTypeGetZeroPoint(uniform));145// CHECK: fixed point: 0146fprintf(stderr, "fixed point: %d\n",147mlirUniformQuantizedTypeIsFixedPoint(uniform));148
149// CHECK: equal: 1150fprintf(stderr, "equal: %d\n", mlirTypeEqual(uniform, uniformParsed));151// CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>152mlirTypeDump(uniform);153fprintf(stderr, "\n\n");154}
155
156// CHECK-LABEL: testUniformPerAxisType
157void testUniformPerAxisType(MlirContext ctx) {158fprintf(stderr, "testUniformPerAxisType\n");159
160MlirType perAxisParsed = mlirTypeParseGet(161ctx, mlirStringRefCreateFromCString(162"!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));163
164MlirType i8 = mlirIntegerTypeGet(ctx, 8);165MlirType f32 = mlirF32TypeGet(ctx);166double scales[] = {200.0, 0.99872};167int64_t zeroPoints[] = {0, 120};168MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet(169mlirQuantizedTypeGetSignedFlag(), i8, f32,170/*nDims=*/2, scales, zeroPoints,171/*quantizedDimension=*/1,172mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,173/*integralWidth=*/8),174mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,175/*integralWidth=*/8));176
177// CHECK: num dims: 2178fprintf(stderr, "num dims: %" PRIdPTR "\n",179mlirUniformQuantizedPerAxisTypeGetNumDims(perAxis));180// CHECK: scale 0: 200.000000181fprintf(stderr, "scale 0: %lf\n",182mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 0));183// CHECK: scale 1: 0.998720184fprintf(stderr, "scale 1: %lf\n",185mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 1));186// CHECK: zero point 0: 0187fprintf(stderr, "zero point 0: %" PRId64 "\n",188mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 0));189// CHECK: zero point 1: 120190fprintf(stderr, "zero point 1: %" PRId64 "\n",191mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 1));192// CHECK: quantized dim: 1193fprintf(stderr, "quantized dim: %" PRId32 "\n",194mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(perAxis));195// CHECK: fixed point: 0196fprintf(stderr, "fixed point: %d\n",197mlirUniformQuantizedPerAxisTypeIsFixedPoint(perAxis));198
199// CHECK: equal: 1200fprintf(stderr, "equal: %d\n", mlirTypeEqual(perAxis, perAxisParsed));201// CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>202mlirTypeDump(perAxis);203fprintf(stderr, "\n\n");204}
205
206// CHECK-LABEL: testCalibratedType
207void testCalibratedType(MlirContext ctx) {208fprintf(stderr, "testCalibratedType\n");209
210MlirType calibratedParsed = mlirTypeParseGet(211ctx,212mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));213
214MlirType f32 = mlirF32TypeGet(ctx);215MlirType calibrated = mlirCalibratedQuantizedTypeGet(f32, -0.998, 1.2321);216
217// CHECK: min: -0.998000218fprintf(stderr, "min: %lf\n", mlirCalibratedQuantizedTypeGetMin(calibrated));219// CHECK: max: 1.232100220fprintf(stderr, "max: %lf\n", mlirCalibratedQuantizedTypeGetMax(calibrated));221
222// CHECK: equal: 1223fprintf(stderr, "equal: %d\n", mlirTypeEqual(calibrated, calibratedParsed));224// CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>225mlirTypeDump(calibrated);226fprintf(stderr, "\n\n");227}
228
229int main(void) {230MlirContext ctx = mlirContextCreate();231mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx);232testTypeHierarchy(ctx);233testAnyQuantizedType(ctx);234testUniformType(ctx);235testUniformPerAxisType(ctx);236testCalibratedType(ctx);237mlirContextDestroy(ctx);238return EXIT_SUCCESS;239}
240