// chatgpt.cpp — gpt4all ChatGPT (OpenAI REST API) model backend (264 lines, 8.1 KB)
1#include "chatgpt.h"
2
3#include <string>
4#include <vector>
5#include <iostream>
6
7#include <QCoreApplication>
8#include <QThread>
9#include <QEventLoop>
10#include <QJsonDocument>
11#include <QJsonObject>
12#include <QJsonArray>
13
14//#define DEBUG
15
16ChatGPT::ChatGPT()
17: QObject(nullptr)
18, m_modelName("gpt-3.5-turbo")
19, m_responseCallback(nullptr)
20{
21}
22
23size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
24{
25Q_UNUSED(modelPath);
26Q_UNUSED(n_ctx);
27Q_UNUSED(ngl);
28return 0;
29}
30
31bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx, int ngl)
32{
33Q_UNUSED(modelPath);
34Q_UNUSED(n_ctx);
35Q_UNUSED(ngl);
36return true;
37}
38
39void ChatGPT::setThreadCount(int32_t n_threads)
40{
41Q_UNUSED(n_threads);
42qt_noop();
43}
44
45int32_t ChatGPT::threadCount() const
46{
47return 1;
48}
49
50ChatGPT::~ChatGPT()
51{
52}
53
54bool ChatGPT::isModelLoaded() const
55{
56return true;
57}
58
59// All three of the state virtual functions are handled custom inside of chatllm save/restore
60size_t ChatGPT::stateSize() const
61{
62return 0;
63}
64
65size_t ChatGPT::saveState(uint8_t *dest) const
66{
67Q_UNUSED(dest);
68return 0;
69}
70
71size_t ChatGPT::restoreState(const uint8_t *src)
72{
73Q_UNUSED(src);
74return 0;
75}
76
77void ChatGPT::prompt(const std::string &prompt,
78const std::string &promptTemplate,
79std::function<bool(int32_t)> promptCallback,
80std::function<bool(int32_t, const std::string&)> responseCallback,
81std::function<bool(bool)> recalculateCallback,
82PromptContext &promptCtx,
83bool special,
84std::string *fakeReply) {
85
86Q_UNUSED(promptCallback);
87Q_UNUSED(recalculateCallback);
88Q_UNUSED(special);
89Q_UNUSED(fakeReply); // FIXME(cebtenzzre): I broke ChatGPT
90
91if (!isModelLoaded()) {
92std::cerr << "ChatGPT ERROR: prompt won't work with an unloaded model!\n";
93return;
94}
95
96// FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
97// an error we need to be able to count the tokens in our prompt. The only way to do this is to use
98// the OpenAI tiktokken library or to implement our own tokenization function that matches precisely
99// the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of
100// using the REST API to count tokens in a prompt.
101QJsonObject root;
102root.insert("model", m_modelName);
103root.insert("stream", true);
104root.insert("temperature", promptCtx.temp);
105root.insert("top_p", promptCtx.top_p);
106
107QJsonArray messages;
108for (int i = 0; i < m_context.count() && i < promptCtx.n_past; ++i) {
109QJsonObject message;
110message.insert("role", i % 2 == 0 ? "assistant" : "user");
111message.insert("content", m_context.at(i));
112messages.append(message);
113}
114
115QJsonObject promptObject;
116promptObject.insert("role", "user");
117promptObject.insert("content", QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt)));
118messages.append(promptObject);
119root.insert("messages", messages);
120
121QJsonDocument doc(root);
122
123#if defined(DEBUG)
124qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson());
125#endif
126
127m_responseCallback = responseCallback;
128
129// The following code sets up a worker thread and object to perform the actual api request to
130// chatgpt and then blocks until it is finished
131QThread workerThread;
132ChatGPTWorker worker(this);
133worker.moveToThread(&workerThread);
134connect(&worker, &ChatGPTWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
135connect(this, &ChatGPT::request, &worker, &ChatGPTWorker::request, Qt::QueuedConnection);
136workerThread.start();
137emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact));
138workerThread.wait();
139
140promptCtx.n_past += 1;
141m_context.append(QString::fromStdString(prompt));
142m_context.append(worker.currentResponse());
143m_responseCallback = nullptr;
144
145#if defined(DEBUG)
146qDebug() << "ChatGPT::prompt end network request";
147#endif
148}
149
150bool ChatGPT::callResponse(int32_t token, const std::string& string)
151{
152Q_ASSERT(m_responseCallback);
153if (!m_responseCallback) {
154std::cerr << "ChatGPT ERROR: no response callback!\n";
155return false;
156}
157return m_responseCallback(token, string);
158}
159
160void ChatGPTWorker::request(const QString &apiKey,
161LLModel::PromptContext *promptCtx,
162const QByteArray &array)
163{
164m_ctx = promptCtx;
165
166QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
167const QString authorization = QString("Bearer %1").arg(apiKey).trimmed();
168QNetworkRequest request(openaiUrl);
169request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
170request.setRawHeader("Authorization", authorization.toUtf8());
171m_networkManager = new QNetworkAccessManager(this);
172QNetworkReply *reply = m_networkManager->post(request, array);
173connect(qApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
174connect(reply, &QNetworkReply::finished, this, &ChatGPTWorker::handleFinished);
175connect(reply, &QNetworkReply::readyRead, this, &ChatGPTWorker::handleReadyRead);
176connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPTWorker::handleErrorOccurred);
177}
178
179void ChatGPTWorker::handleFinished()
180{
181QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
182if (!reply) {
183emit finished();
184return;
185}
186
187QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
188Q_ASSERT(response.isValid());
189bool ok;
190int code = response.toInt(&ok);
191if (!ok || code != 200) {
192qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
193.arg(code).arg(reply->errorString()).toStdString();
194}
195reply->deleteLater();
196emit finished();
197}
198
199void ChatGPTWorker::handleReadyRead()
200{
201QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
202if (!reply) {
203emit finished();
204return;
205}
206
207QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
208Q_ASSERT(response.isValid());
209bool ok;
210int code = response.toInt(&ok);
211if (!ok || code != 200) {
212m_chat->callResponse(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
213.arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString());
214emit finished();
215return;
216}
217
218while (reply->canReadLine()) {
219QString jsonData = reply->readLine().trimmed();
220if (jsonData.startsWith("data:"))
221jsonData.remove(0, 5);
222jsonData = jsonData.trimmed();
223if (jsonData.isEmpty())
224continue;
225if (jsonData == "[DONE]")
226continue;
227#if defined(DEBUG)
228qDebug() << "line" << qPrintable(jsonData);
229#endif
230QJsonParseError err;
231const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
232if (err.error != QJsonParseError::NoError) {
233m_chat->callResponse(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
234.arg(err.errorString()).toStdString());
235continue;
236}
237
238const QJsonObject root = document.object();
239const QJsonArray choices = root.value("choices").toArray();
240const QJsonObject choice = choices.first().toObject();
241const QJsonObject delta = choice.value("delta").toObject();
242const QString content = delta.value("content").toString();
243Q_ASSERT(m_ctx);
244m_currentResponse += content;
245if (!m_chat->callResponse(0, content.toStdString())) {
246reply->abort();
247emit finished();
248return;
249}
250}
251}
252
253void ChatGPTWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
254{
255QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
256if (!reply || reply->error() == QNetworkReply::OperationCanceledError /*when we call abort on purpose*/) {
257emit finished();
258return;
259}
260
261qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
262.arg(code).arg(reply->errorString()).toStdString();
263emit finished();
264}
265