test-quantize-fns.cpp

// Unit tests for quantization-specific functions: quantize, dequantize and dot product

#include "ggml.h"

#undef NDEBUG
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <string>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f;

static const char* RESULT_STR[] = {"ok", "FAILED"};


// Generate synthetic data
static void generate_data(float offset, size_t n, float * dst) {
    for (size_t i = 0; i < n; i++) {
        dst[i] = 0.1 + 2*cosf(i + offset);
    }
}
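
// Note: the generated data is deterministic and bounded (0.1 + 2*cos(...) always
// lies in [-1.9, 2.1]), so every run exercises the quantizers on identical inputs.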

// Distance between two float arrays: sqrt of the summed squared differences,
// divided by n (i.e. RMSE/sqrt(n)); the error thresholds above assume this scaling
static float array_rmse(const float * a1, const float * a2, size_t n) {
    double sum = 0;
    for (size_t i = 0; i < n; i++) {
        double diff = a1[i] - a2[i];
        sum += diff * diff;
    }
    return sqrtf(sum) / n;
}
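
// Worked example of the scaling: for a1 = {1, 2} and a2 = {0, 0} the summed
// squared difference is 1 + 4 = 5, so array_rmse returns sqrt(5)/2 ≈ 1.118,
// whereas the textbook RMSE would be sqrt(5/2) ≈ 1.581.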

// Total quantization error on test data
static float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
    std::vector<uint8_t> tmp_q(2*test_size);
    std::vector<float> tmp_out(test_size);

    qfns.from_float(test_data, tmp_q.data(), test_size);
    qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
    return array_rmse(test_data, tmp_out.data(), test_size);
}
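
// The round trip above (quantize, then dequantize) measures representation
// error only. tmp_q allots 2 bytes per element, which covers every type this
// test round-trips: f16/bf16 take exactly 2 bytes, and the block-quantized
// formats take well under that even with per-block scales.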

// Error of the vectorized quantization relative to the reference implementation
static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
    std::vector<uint8_t> tmp_q(2*test_size);
    std::vector<float> tmp_out(test_size);
    std::vector<float> tmp_out_ref(test_size);

    qfns.from_float(test_data, tmp_q.data(), test_size);
    qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);

    qfns.from_float_ref(test_data, tmp_q.data(), test_size);
    qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);

    return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
}

static float dot_product(const float * a1, const float * a2, size_t test_size) {
    double sum = 0;
    for (size_t i = 0; i < test_size; i++) {
        sum += a1[i] * a2[i];
    }
    return sum;
}

// Total dot product error
static float dot_product_error(
    ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2
) {
    std::vector<uint8_t> tmp_q1(2*test_size);
    std::vector<uint8_t> tmp_q2(2*test_size);

    auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);

    qfns.from_float(test_data1, tmp_q1.data(), test_size);
    vdot.from_float(test_data2, tmp_q2.data(), test_size);

    float result = INFINITY;
    qfns.vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);

    const float dot_ref = dot_product(test_data1, test_data2, test_size);

    return fabsf(result - dot_ref) / test_size;
}
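
// In the vec_dot call above, the three zeros are the result/row strides
// (bs, bx, by in ggml's vec_dot signature) and the trailing 1 is nrc, the
// number of rows to compute. result is pre-set to INFINITY so a vec_dot that
// never writes its output cannot silently pass the error check.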

int main(int argc, char * argv[]) {
    bool verbose = false;
    const size_t test_size = 32 * 128;

    std::string arg;
    for (int i = 1; i < argc; i++) {
        arg = argv[i];

        if (arg == "-v") {
            verbose = true;
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            return 1;
        }
    }

    std::vector<float> test_data(test_size);
    std::vector<float> test_data2(test_size);

    generate_data(0.0, test_data.size(), test_data.data());
    generate_data(1.0, test_data2.size(), test_data2.data());

    // Initialize GGML, ensures float conversion tables are initialized
    struct ggml_init_params ggml_params = {
        /* .mem_size   = */ 1*1024,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,
    };
    struct ggml_context * ctx = ggml_init(ggml_params);

    int num_failed = 0;
    bool failed = false;

    for (int i = 0; i < GGML_TYPE_COUNT; i++) {
        ggml_type type = (ggml_type) i;
        ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);

        // deprecated - skip
        if (qfns.blck_size == 0) {
            continue;
        }

        printf("Testing %s\n", ggml_type_name(type));
        ggml_quantize_init(type);

        if (qfns.from_float && qfns.to_float) {
            const float total_error = total_quantization_error(qfns, test_size, test_data.data());
            const float max_quantization_error =
                type == GGML_TYPE_TQ1_0   ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
                type == GGML_TYPE_TQ2_0   ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
                type == GGML_TYPE_Q2_K    ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
                type == GGML_TYPE_IQ2_S   ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
                type == GGML_TYPE_Q3_K    ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
                type == GGML_TYPE_IQ3_S   ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
                type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR;
            failed = !(total_error < max_quantization_error);
            num_failed += failed;
            if (failed || verbose) {
                printf("%5s absolute quantization error:    %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
            }

            const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
            failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
            num_failed += failed;
            if (failed || verbose) {
                printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
            }

            const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
            const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
                                            type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
                                          ? MAX_DOT_PRODUCT_ERROR_LOWBIT
                                          : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
                                          ? MAX_DOT_PRODUCT_ERROR_TERNARY
                                          : MAX_DOT_PRODUCT_ERROR;
            failed = !(vec_dot_error < max_allowed_error);
            num_failed += failed;
            if (failed || verbose) {
                printf("%5s dot product error:              %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
            }
        }
    }

    if (num_failed || verbose) {
        printf("%d tests failed\n", num_failed);
    }

    ggml_free(ctx);

    return num_failed > 0;
}
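
// Build/run sketch (an assumption, not verified here: paths and target name
// follow the usual llama.cpp CMake layout, where this file builds as the
// ctest target "test-quantize-fns"):
//   cmake -B build && cmake --build build --target test-quantize-fns
//   ./build/bin/test-quantize-fns -v    (-v also prints results for passing types)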