gpt4all

Форк
0
/
embeddings.cpp 
193 строки · 4.9 Кб
1
#include "embeddings.h"

#include <algorithm>
#include <memory>
#include <queue>

#include <QFile>
#include <QFileInfo>
#include <QDebug>

#include "mysettings.h"
#include "hnswlib/hnswlib.h"
9

10
// Version tag that the constructor embeds in the index file name
// ("embeddings_v<N>.dat"). Typed constant instead of a preprocessor macro.
constexpr int EMBEDDINGS_VERSION = 0;

constexpr int s_dim = 384;              // Dimension of the elements
constexpr int s_ef_construction = 200;  // Controls index search speed/build speed tradeoff
constexpr int s_M = 16;                 // Tightly connected with internal dimensionality of the data
                                        // strongly affects the memory consumption
16

17
// Creates an object with no index loaded and computes the index file path:
// the configured model path followed by a file name carrying EMBEDDINGS_VERSION.
Embeddings::Embeddings(QObject *parent)
    : QObject(parent)
    , m_space(nullptr)
    , m_hnsw(nullptr)
{
    const QString fileName = QString("embeddings_v%1.dat").arg(EMBEDDINGS_VERSION);
    m_filePath = MySettings::globalInstance()->modelPath() + fileName;
}
25

26
// Releases the index and the space it was built on.
Embeddings::~Embeddings()
{
    // clear() performs exactly this teardown: delete the index, then the space,
    // and reset both pointers to nullptr.
    clear();
}
33

34
bool Embeddings::load()
35
{
36
    QFileInfo info(m_filePath);
37
    if (!info.exists()) {
38
        qWarning() << "ERROR: loading embeddings file does not exist" << m_filePath;
39
        return false;
40
    }
41

42
    if (!info.isReadable()) {
43
        qWarning() << "ERROR: loading embeddings file is not readable" << m_filePath;
44
        return false;
45
    }
46

47
    if (!info.isWritable()) {
48
        qWarning() << "ERROR: loading embeddings file is not writeable" << m_filePath;
49
        return false;
50
    }
51

52
    try {
53
        m_space = new hnswlib::InnerProductSpace(s_dim);
54
        m_hnsw = new hnswlib::HierarchicalNSW<float>(m_space, m_filePath.toStdString(), s_M, s_ef_construction);
55
    } catch (const std::exception &e) {
56
        qWarning() << "ERROR: could not load hnswlib index:" << e.what();
57
        return false;
58
    }
59
    return isLoaded();
60
}
61

62
// Creates a fresh, empty in-memory index with room for maxElements entries.
//
// Returns true when an index is available afterwards. Returns false, after logging
// a warning, when maxElements is negative or when hnswlib throws during construction.
//
// The new space and index are fully constructed before anything already held is
// released, so a failure leaks nothing and leaves a previously loaded index intact.
bool Embeddings::load(qint64 maxElements)
{
    // hnswlib takes the capacity as an unsigned size; a negative value would
    // silently turn into an enormous allocation request.
    if (maxElements < 0) {
        qWarning() << "ERROR: could not create hnswlib index: negative capacity requested" << maxElements;
        return false;
    }

    std::unique_ptr<hnswlib::InnerProductSpace> space;
    std::unique_ptr<hnswlib::HierarchicalNSW<float>> index;
    try {
        space = std::make_unique<hnswlib::InnerProductSpace>(s_dim);
        index = std::make_unique<hnswlib::HierarchicalNSW<float>>(
            space.get(), static_cast<size_t>(maxElements), s_M, s_ef_construction);
    } catch (const std::exception &e) {
        // The unique_ptrs free whatever was constructed before the throw.
        qWarning() << "ERROR: could not create hnswlib index:" << e.what();
        return false;
    }

    // Drop any index we were holding only now that the replacement exists.
    clear();
    m_space = space.release();
    m_hnsw = index.release();
    return isLoaded();
}
73

74
bool Embeddings::save()
75
{
76
    if (!isLoaded())
77
        return false;
78
    try {
79
        m_hnsw->saveIndex(m_filePath.toStdString());
80
    } catch (const std::exception &e) {
81
        qWarning() << "ERROR: could not save hnswlib index:" << e.what();
82
        return false;
83
    }
84
    return true;
85
}
86

87
bool Embeddings::isLoaded() const
88
{
89
    return m_hnsw != nullptr;
90
}
91

92
bool Embeddings::fileExists() const
93
{
94
    QFileInfo info(m_filePath);
95
    return info.exists();
96
}
97

98
// Changes the capacity of the loaded index to `size` elements by forwarding to
// hnswlib's resizeIndex().
// Returns false, after logging a warning, when no index is loaded or when
// hnswlib throws; returns true otherwise.
bool Embeddings::resize(qint64 size)
{
    if (isLoaded()) {
        Q_ASSERT(m_hnsw);
        try {
            m_hnsw->resizeIndex(size);
            return true;
        } catch (const std::exception &e) {
            qWarning() << "ERROR: could not resize hnswlib index:" << e.what();
        }
        return false;
    }

    qWarning() << "ERROR: attempting to resize an embedding when the embeddings are not open!";
    return false;
}
114

115
bool Embeddings::add(const std::vector<float> &embedding, qint64 label)
116
{
117
    if (!isLoaded()) {
118
        bool success = load(500);
119
        if (!success) {
120
            qWarning() << "ERROR: attempting to add an embedding when the embeddings are not open!";
121
            return false;
122
        }
123
    }
124

125
    Q_ASSERT(m_hnsw);
126
    if (m_hnsw->cur_element_count + 1 > m_hnsw->max_elements_) {
127
        if (!resize(m_hnsw->max_elements_ + 500)) {
128
            return false;
129
        }
130
    }
131

132
    if (embedding.empty())
133
        return false;
134

135
    try {
136
        m_hnsw->addPoint(embedding.data(), label, false);
137
    } catch (const std::exception &e) {
138
        qWarning() << "ERROR: could not add embedding to hnswlib index:" << e.what();
139
        return false;
140
    }
141
    return true;
142
}
143

144
// Marks the element stored under `label` as deleted in the index.
// Logs a warning and does nothing when no index is loaded; an exception thrown
// by hnswlib's markDelete() is logged and swallowed.
void Embeddings::remove(qint64 label)
{
    if (!isLoaded()) {
        qWarning() << "ERROR: attempting to remove an embedding when the embeddings are not open!";
        return;
    }

    Q_ASSERT(m_hnsw);
    try {
        m_hnsw->markDelete(label);
    } catch (const std::exception &e) {
        // Message previously read "could not add remove embedding".
        qWarning() << "ERROR: could not remove embedding from hnswlib index:" << e.what();
    }
}
158

159
void Embeddings::clear()
160
{
161
    delete m_hnsw;
162
    m_hnsw = nullptr;
163
    delete m_space;
164
    m_space = nullptr;
165
}
166

167
std::vector<qint64> Embeddings::search(const std::vector<float> &embedding, int K)
168
{
169
    if (!isLoaded())
170
        return {};
171

172
    Q_ASSERT(m_hnsw);
173
    std::priority_queue<std::pair<float, hnswlib::labeltype>> result;
174
    try {
175
        result = m_hnsw->searchKnn(embedding.data(), K);
176
    } catch (const std::exception &e) {
177
        qWarning() << "ERROR: could not search hnswlib index:" << e.what();
178
        return {};
179
    }
180

181
    std::vector<qint64> neighbors;
182
    neighbors.reserve(K);
183

184
    while(!result.empty()) {
185
        neighbors.push_back(result.top().second);
186
        result.pop();
187
    }
188

189
    // Reverse the neighbors, as the top of the priority queue is the farthest neighbor.
190
    std::reverse(neighbors.begin(), neighbors.end());
191

192
    return neighbors;
193
}
194

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.