// gpt4all
// 193 lines · 4.9 KB
#include "embeddings.h"

#include <memory>

#include <QDebug>
#include <QFile>
#include <QFileInfo>

#include "mysettings.h"
#include "hnswlib/hnswlib.h"
9
#define EMBEDDINGS_VERSION 0

// hnswlib index parameters used by every index this file creates.
static constexpr int s_dim = 384;             // Length of each embedding vector (passed to InnerProductSpace)
static constexpr int s_ef_construction = 200; // Trades index build speed against search speed
static constexpr int s_M = 16;                // Tied to the intrinsic dimensionality of the data;
                                              // has a strong effect on memory consumption
16
17Embeddings::Embeddings(QObject *parent)
18: QObject(parent)
19, m_space(nullptr)
20, m_hnsw(nullptr)
21{
22m_filePath = MySettings::globalInstance()->modelPath()
23+ QString("embeddings_v%1.dat").arg(EMBEDDINGS_VERSION);
24}
25
26Embeddings::~Embeddings()
27{
28delete m_hnsw;
29m_hnsw = nullptr;
30delete m_space;
31m_space = nullptr;
32}
33
34bool Embeddings::load()
35{
36QFileInfo info(m_filePath);
37if (!info.exists()) {
38qWarning() << "ERROR: loading embeddings file does not exist" << m_filePath;
39return false;
40}
41
42if (!info.isReadable()) {
43qWarning() << "ERROR: loading embeddings file is not readable" << m_filePath;
44return false;
45}
46
47if (!info.isWritable()) {
48qWarning() << "ERROR: loading embeddings file is not writeable" << m_filePath;
49return false;
50}
51
52try {
53m_space = new hnswlib::InnerProductSpace(s_dim);
54m_hnsw = new hnswlib::HierarchicalNSW<float>(m_space, m_filePath.toStdString(), s_M, s_ef_construction);
55} catch (const std::exception &e) {
56qWarning() << "ERROR: could not load hnswlib index:" << e.what();
57return false;
58}
59return isLoaded();
60}
61
62bool Embeddings::load(qint64 maxElements)
63{
64try {
65m_space = new hnswlib::InnerProductSpace(s_dim);
66m_hnsw = new hnswlib::HierarchicalNSW<float>(m_space, maxElements, s_M, s_ef_construction);
67} catch (const std::exception &e) {
68qWarning() << "ERROR: could not create hnswlib index:" << e.what();
69return false;
70}
71return isLoaded();
72}
73
74bool Embeddings::save()
75{
76if (!isLoaded())
77return false;
78try {
79m_hnsw->saveIndex(m_filePath.toStdString());
80} catch (const std::exception &e) {
81qWarning() << "ERROR: could not save hnswlib index:" << e.what();
82return false;
83}
84return true;
85}
86
87bool Embeddings::isLoaded() const
88{
89return m_hnsw != nullptr;
90}
91
92bool Embeddings::fileExists() const
93{
94QFileInfo info(m_filePath);
95return info.exists();
96}
97
98bool Embeddings::resize(qint64 size)
99{
100if (!isLoaded()) {
101qWarning() << "ERROR: attempting to resize an embedding when the embeddings are not open!";
102return false;
103}
104
105Q_ASSERT(m_hnsw);
106try {
107m_hnsw->resizeIndex(size);
108} catch (const std::exception &e) {
109qWarning() << "ERROR: could not resize hnswlib index:" << e.what();
110return false;
111}
112return true;
113}
114
115bool Embeddings::add(const std::vector<float> &embedding, qint64 label)
116{
117if (!isLoaded()) {
118bool success = load(500);
119if (!success) {
120qWarning() << "ERROR: attempting to add an embedding when the embeddings are not open!";
121return false;
122}
123}
124
125Q_ASSERT(m_hnsw);
126if (m_hnsw->cur_element_count + 1 > m_hnsw->max_elements_) {
127if (!resize(m_hnsw->max_elements_ + 500)) {
128return false;
129}
130}
131
132if (embedding.empty())
133return false;
134
135try {
136m_hnsw->addPoint(embedding.data(), label, false);
137} catch (const std::exception &e) {
138qWarning() << "ERROR: could not add embedding to hnswlib index:" << e.what();
139return false;
140}
141return true;
142}
143
144void Embeddings::remove(qint64 label)
145{
146if (!isLoaded()) {
147qWarning() << "ERROR: attempting to remove an embedding when the embeddings are not open!";
148return;
149}
150
151Q_ASSERT(m_hnsw);
152try {
153m_hnsw->markDelete(label);
154} catch (const std::exception &e) {
155qWarning() << "ERROR: could not add remove embedding from hnswlib index:" << e.what();
156}
157}
158
159void Embeddings::clear()
160{
161delete m_hnsw;
162m_hnsw = nullptr;
163delete m_space;
164m_space = nullptr;
165}
166
167std::vector<qint64> Embeddings::search(const std::vector<float> &embedding, int K)
168{
169if (!isLoaded())
170return {};
171
172Q_ASSERT(m_hnsw);
173std::priority_queue<std::pair<float, hnswlib::labeltype>> result;
174try {
175result = m_hnsw->searchKnn(embedding.data(), K);
176} catch (const std::exception &e) {
177qWarning() << "ERROR: could not search hnswlib index:" << e.what();
178return {};
179}
180
181std::vector<qint64> neighbors;
182neighbors.reserve(K);
183
184while(!result.empty()) {
185neighbors.push_back(result.top().second);
186result.pop();
187}
188
189// Reverse the neighbors, as the top of the priority queue is the farthest neighbor.
190std::reverse(neighbors.begin(), neighbors.end());
191
192return neighbors;
193}
194