13
// Silence compiler warnings 4244 and 4267 (under MSVC's numbering these are the
// "conversion, possible loss of data" warnings).
// NOTE(review): any platform guard around this pragma is not visible in this excerpt.
#pragma warning(disable: 4244 4267)
16
// Pass limit for reference_quantization_error(): main() fails a type unless its error is below this.
constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
17
// Default pass limit for total_quantization_error(), used for every type without an override below.
constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
18
// Total-error limit that main() applies to GGML_TYPE_TQ1_0 and GGML_TYPE_TQ2_0.
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f;
19
// Total-error limit that main() applies to GGML_TYPE_Q2_K and GGML_TYPE_IQ2_S.
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
20
// Total-error limit that main() applies to GGML_TYPE_Q3_K and GGML_TYPE_IQ3_S.
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
21
// Total-error limit that main() applies to GGML_TYPE_IQ3_XXS.
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
22
// Default pass limit for dot_product_error().
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
23
// Dot-product limit that main() applies to Q2_K, IQ2_XS, IQ2_XXS, IQ3_XXS, IQ3_S and IQ2_S.
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
24
// Dot-product limit that main() applies to GGML_TYPE_TQ1_0 and GGML_TYPE_TQ2_0.
constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f;
26
// Label printed next to each check; indexed with the boolean `failed` flag
// (false -> "ok", true -> "FAILED").
static const char* RESULT_STR[] = {"ok", "FAILED"};
30
// Fill dst[0..n) with a deterministic synthetic signal.
//
// Element i receives 0.1 + 2*cos(i + offset), so every value lies in [-1.9, 2.1].
//
//   offset  phase shift added to the element index before taking the cosine
//   n       number of elements to write; 0 writes nothing
//   dst     output buffer with room for at least n floats
static void generate_data(float offset, size_t n, float * dst) {
    for (size_t i = 0; i < n; i++) {
        // The arithmetic is unchanged: the index is converted to float, the cosine is
        // evaluated in float, the sum with 0.1 is done in double and narrowed on store.
        // The casts only spell out conversions that used to be implicit.
        dst[i] = static_cast<float>(0.1 + 2*cosf(static_cast<float>(i) + offset));
    }
}
37
// Error metric between two float arrays of length n.
// NOTE(review): the declaration of `sum` and the statement that accumulates into it are
// not visible in this excerpt; presumably `sum` collects the squared `diff` values - confirm.
static float array_rmse(const float * a1, const float * a2, size_t n) {
39
for (size_t i = 0; i < n; i++) {
40
// Element-wise difference: subtracted as float, then stored in a double.
double diff = a1[i] - a2[i];
43
// The result is sqrtf(sum) divided by n (not sqrt(sum / n)); n == 0 makes the divisor zero.
return sqrtf(sum) / n;
47
// Round-trip test_data through a type's quantizer and report how much it changed.
//
// The input is quantized with qfns.from_float, expanded again with qfns.to_float,
// and the result of array_rmse(input, round-tripped) is returned.
//
//   qfns       traits of the type under test; from_float and to_float are called unchecked
//   test_size  number of floats in test_data
//   test_data  input values
static float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
    // Scratch space for the quantized form: two bytes per input element.
    std::vector<uint8_t> packed(2*test_size);
    std::vector<float>   restored(test_size);

    qfns.from_float(test_data, packed.data(), test_size);
    qfns.to_float(packed.data(), restored.data(), test_size);

    return array_rmse(test_data, restored.data(), test_size);
}
57
// Compare a type's from_float against its from_float_ref.
//
// test_data is quantized once with each function, both results are expanded with
// qfns.to_float, and array_rmse() of the two expanded arrays is returned.
//
//   qfns       traits of the type under test; from_float, from_float_ref and
//              to_float are called unchecked
//   test_size  number of floats in test_data
//   test_data  input values
static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
    // One scratch buffer (two bytes per element) is shared by both passes, so the
    // first pass must be expanded before the second one overwrites it.
    std::vector<uint8_t> packed(2*test_size);
    std::vector<float>   from_main(test_size);
    std::vector<float>   from_ref(test_size);

    // Pass 1: from_float.
    qfns.from_float(test_data, packed.data(), test_size);
    qfns.to_float(packed.data(), from_main.data(), test_size);

    // Pass 2: from_float_ref, into the same scratch buffer.
    qfns.from_float_ref(test_data, packed.data(), test_size);
    qfns.to_float(packed.data(), from_ref.data(), test_size);

    return array_rmse(from_main.data(), from_ref.data(), test_size);
}
71
// Float reference used by dot_product_error() as the expected value.
// NOTE(review): only the signature and the loop header are visible in this excerpt;
// presumably the loop accumulates a1[i] * a2[i] and the sum is returned - confirm.
static float dot_product(const float * a1, const float * a2, size_t test_size) {
73
// Visits every index in [0, test_size).
for (size_t i = 0; i < test_size; i++) {
80
static float dot_product_error(
81
ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2
83
std::vector<uint8_t> tmp_q1(2*test_size);
84
std::vector<uint8_t> tmp_q2(2*test_size);
86
auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
88
qfns.from_float(test_data1, tmp_q1.data(), test_size);
89
vdot.from_float(test_data2, tmp_q2.data(), test_size);
91
float result = INFINITY;
92
qfns.vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
94
const float dot_ref = dot_product(test_data1, test_data2, test_size);
96
return fabsf(result - dot_ref) / test_size;
99
// Test driver: for every ggml type index below GGML_TYPE_COUNT, runs the total-error,
// reference-error and dot-product checks and counts the failures.
// Returns 1 when at least one check failed and 0 otherwise.
// NOTE(review): several statements of this function are not visible in this excerpt
// (the declarations of `arg`, `failed` and `num_failed`, the ggml_params initializer
// body, and some branch bodies); the comments describe only what is shown.
int main(int argc, char * argv[]) {
100
// When set, every result line is printed, not only the failing ones.
bool verbose = false;
101
// Number of floats in each synthetic input array (32 * 128 = 4096).
const size_t test_size = 32 * 128;
104
// Walk the command-line arguments, skipping argv[0].
for (int i = 1; i < argc; i++) {
110
// NOTE(review): the condition that leads to this message is not visible here.
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
115
std::vector<float> test_data(test_size);
116
std::vector<float> test_data2(test_size);
118
// Two inputs from the same generator with offsets 0.0 and 1.0.
generate_data(0.0, test_data.size(), test_data.data());
119
generate_data(1.0, test_data2.size(), test_data2.data());
122
struct ggml_init_params ggml_params = {
127
struct ggml_context * ctx = ggml_init(ggml_params);
132
// Try every type index in [0, GGML_TYPE_COUNT).
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
133
ggml_type type = (ggml_type) i;
134
ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
137
// NOTE(review): the body of this branch is not visible; presumably types that
// report a zero block size are skipped - confirm.
if (qfns.blck_size == 0) {
141
const ggml_type ei = (ggml_type)i;
143
printf("Testing %s\n", ggml_type_name((ggml_type) i));
144
// Called once per type before its checks run.
ggml_quantize_init(ei);
146
// The checks need both conversion functions to be present.
if (qfns.from_float && qfns.to_float) {
147
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
148
// Pick the total-error limit for this type; types not listed use MAX_QUANTIZATION_TOTAL_ERROR.
const float max_quantization_error =
149
type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
150
type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
151
type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
152
type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
153
type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
154
type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
155
type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR;
156
// Written as !(error < limit) so that a NaN error also counts as a failure.
failed = !(total_error < max_quantization_error);
157
num_failed += failed;
158
if (failed || verbose) {
159
printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
162
const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
163
// Same negated comparison, against the single reference-error limit.
failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
164
num_failed += failed;
165
if (failed || verbose) {
166
printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
169
const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
170
// Dot-product limit: LOWBIT for the six listed types, TERNARY for TQ1_0 and TQ2_0,
// MAX_DOT_PRODUCT_ERROR for everything else.
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
171
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
172
? MAX_DOT_PRODUCT_ERROR_LOWBIT
173
: type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
174
? MAX_DOT_PRODUCT_ERROR_TERNARY
175
: MAX_DOT_PRODUCT_ERROR;
176
failed = !(vec_dot_error < max_allowed_error);
177
num_failed += failed;
178
if (failed || verbose) {
179
printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
184
// Print the summary when something failed or when running verbosely.
if (num_failed || verbose) {
185
printf("%d tests failed\n", num_failed);
190
// Process exit status: 1 if any check failed, 0 otherwise.
return num_failed > 0;