gpt4all — fork 0 — chatgpt.cpp — 264 lines · 8.1 KB
(page header from the GitVerse scrape; not part of the source file)
1
#include "chatgpt.h"
2

3
#include <string>
4
#include <vector>
5
#include <iostream>
6

7
#include <QCoreApplication>
8
#include <QThread>
9
#include <QEventLoop>
10
#include <QJsonDocument>
11
#include <QJsonObject>
12
#include <QJsonArray>
13

14
//#define DEBUG
15

16
// Constructs the ChatGPT backend with no parent QObject.
// Defaults to the "gpt-3.5-turbo" model; no response callback is installed
// until prompt() runs.
ChatGPT::ChatGPT()
    : QObject(nullptr),
      m_modelName("gpt-3.5-turbo"),
      m_responseCallback(nullptr)
{
}
22

23
// The model runs remotely, so no local memory is required — always 0.
size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
{
    (void)modelPath;
    (void)n_ctx;
    (void)ngl;
    return 0;
}
30

31
// Nothing to load for a remote API model; report success unconditionally.
bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx, int ngl)
{
    (void)modelPath;
    (void)n_ctx;
    (void)ngl;
    return true;
}
38

39
void ChatGPT::setThreadCount(int32_t n_threads)
40
{
41
    Q_UNUSED(n_threads);
42
    qt_noop();
43
}
44

45
int32_t ChatGPT::threadCount() const
46
{
47
    return 1;
48
}
49

50
// Out-of-line default destructor (anchor for the vtable).
ChatGPT::~ChatGPT() = default;
53

54
bool ChatGPT::isModelLoaded() const
55
{
56
    return true;
57
}
58

59
// All three of the state virtual functions are handled custom inside of chatllm save/restore
60
size_t ChatGPT::stateSize() const
61
{
62
    return 0;
63
}
64

65
size_t ChatGPT::saveState(uint8_t *dest) const
66
{
67
    Q_UNUSED(dest);
68
    return 0;
69
}
70

71
// No-op: chatllm handles state restoration for this backend. Returns 0 bytes
// consumed.
size_t ChatGPT::restoreState(const uint8_t *src)
{
    (void)src;
    return 0;
}
76

77
void ChatGPT::prompt(const std::string &prompt,
78
        const std::string &promptTemplate,
79
        std::function<bool(int32_t)> promptCallback,
80
        std::function<bool(int32_t, const std::string&)> responseCallback,
81
        std::function<bool(bool)> recalculateCallback,
82
        PromptContext &promptCtx,
83
        bool special,
84
        std::string *fakeReply) {
85

86
    Q_UNUSED(promptCallback);
87
    Q_UNUSED(recalculateCallback);
88
    Q_UNUSED(special);
89
    Q_UNUSED(fakeReply); // FIXME(cebtenzzre): I broke ChatGPT
90

91
    if (!isModelLoaded()) {
92
        std::cerr << "ChatGPT ERROR: prompt won't work with an unloaded model!\n";
93
        return;
94
    }
95

96
    // FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
97
    // an error we need to be able to count the tokens in our prompt. The only way to do this is to use
98
    // the OpenAI tiktokken library or to implement our own tokenization function that matches precisely
99
    // the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of
100
    // using the REST API to count tokens in a prompt.
101
    QJsonObject root;
102
    root.insert("model", m_modelName);
103
    root.insert("stream", true);
104
    root.insert("temperature", promptCtx.temp);
105
    root.insert("top_p", promptCtx.top_p);
106

107
    QJsonArray messages;
108
    for (int i = 0; i < m_context.count() && i < promptCtx.n_past; ++i) {
109
        QJsonObject message;
110
        message.insert("role", i % 2 == 0 ? "assistant" : "user");
111
        message.insert("content", m_context.at(i));
112
        messages.append(message);
113
    }
114

115
    QJsonObject promptObject;
116
    promptObject.insert("role", "user");
117
    promptObject.insert("content", QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt)));
118
    messages.append(promptObject);
119
    root.insert("messages", messages);
120

121
    QJsonDocument doc(root);
122

123
#if defined(DEBUG)
124
    qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson());
125
#endif
126

127
    m_responseCallback = responseCallback;
128

129
    // The following code sets up a worker thread and object to perform the actual api request to
130
    // chatgpt and then blocks until it is finished
131
    QThread workerThread;
132
    ChatGPTWorker worker(this);
133
    worker.moveToThread(&workerThread);
134
    connect(&worker, &ChatGPTWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
135
    connect(this, &ChatGPT::request, &worker, &ChatGPTWorker::request, Qt::QueuedConnection);
136
    workerThread.start();
137
    emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact));
138
    workerThread.wait();
139

140
    promptCtx.n_past += 1;
141
    m_context.append(QString::fromStdString(prompt));
142
    m_context.append(worker.currentResponse());
143
    m_responseCallback = nullptr;
144

145
#if defined(DEBUG)
146
    qDebug() << "ChatGPT::prompt end network request";
147
#endif
148
}
149

150
bool ChatGPT::callResponse(int32_t token, const std::string& string)
151
{
152
    Q_ASSERT(m_responseCallback);
153
    if (!m_responseCallback) {
154
        std::cerr << "ChatGPT ERROR: no response callback!\n";
155
        return false;
156
    }
157
    return m_responseCallback(token, string);
158
}
159

160
void ChatGPTWorker::request(const QString &apiKey,
161
        LLModel::PromptContext *promptCtx,
162
        const QByteArray &array)
163
{
164
    m_ctx = promptCtx;
165

166
    QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
167
    const QString authorization = QString("Bearer %1").arg(apiKey).trimmed();
168
    QNetworkRequest request(openaiUrl);
169
    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
170
    request.setRawHeader("Authorization", authorization.toUtf8());
171
    m_networkManager = new QNetworkAccessManager(this);
172
    QNetworkReply *reply = m_networkManager->post(request, array);
173
    connect(qApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
174
    connect(reply, &QNetworkReply::finished, this, &ChatGPTWorker::handleFinished);
175
    connect(reply, &QNetworkReply::readyRead, this, &ChatGPTWorker::handleReadyRead);
176
    connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPTWorker::handleErrorOccurred);
177
}
178

179
void ChatGPTWorker::handleFinished()
180
{
181
    QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
182
    if (!reply) {
183
        emit finished();
184
        return;
185
    }
186

187
    QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
188
    Q_ASSERT(response.isValid());
189
    bool ok;
190
    int code = response.toInt(&ok);
191
    if (!ok || code != 200) {
192
        qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
193
            .arg(code).arg(reply->errorString()).toStdString();
194
    }
195
    reply->deleteLater();
196
    emit finished();
197
}
198

199
void ChatGPTWorker::handleReadyRead()
200
{
201
    QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
202
    if (!reply) {
203
        emit finished();
204
        return;
205
    }
206

207
    QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
208
    Q_ASSERT(response.isValid());
209
    bool ok;
210
    int code = response.toInt(&ok);
211
    if (!ok || code != 200) {
212
        m_chat->callResponse(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
213
            .arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString());
214
        emit finished();
215
        return;
216
    }
217

218
    while (reply->canReadLine()) {
219
        QString jsonData = reply->readLine().trimmed();
220
        if (jsonData.startsWith("data:"))
221
            jsonData.remove(0, 5);
222
        jsonData = jsonData.trimmed();
223
        if (jsonData.isEmpty())
224
            continue;
225
        if (jsonData == "[DONE]")
226
            continue;
227
#if defined(DEBUG)
228
        qDebug() << "line" << qPrintable(jsonData);
229
#endif
230
        QJsonParseError err;
231
        const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
232
        if (err.error != QJsonParseError::NoError) {
233
            m_chat->callResponse(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
234
                .arg(err.errorString()).toStdString());
235
            continue;
236
        }
237

238
        const QJsonObject root = document.object();
239
        const QJsonArray choices = root.value("choices").toArray();
240
        const QJsonObject choice = choices.first().toObject();
241
        const QJsonObject delta = choice.value("delta").toObject();
242
        const QString content = delta.value("content").toString();
243
        Q_ASSERT(m_ctx);
244
        m_currentResponse += content;
245
        if (!m_chat->callResponse(0, content.toStdString())) {
246
            reply->abort();
247
            emit finished();
248
            return;
249
        }
250
    }
251
}
252

253
// Slot for QNetworkReply::errorOccurred: logs network-level failures.
// Deliberate aborts (user cancel or app shutdown) arrive as
// OperationCanceledError and are silently ignored.
void ChatGPTWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
{
    auto *reply = qobject_cast<QNetworkReply *>(sender());
    const bool deliberateAbort =
        reply && reply->error() == QNetworkReply::OperationCanceledError; /*when we call abort on purpose*/
    if (!reply || deliberateAbort) {
        emit finished();
        return;
    }

    qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
                      .arg(code).arg(reply->errorString()).toStdString();
    emit finished();
}
265

Cookie usage notice (GitVerse page footer, translated; not part of the source file)

We use cookies in accordance with the Privacy Policy and the Cookie Policy.

By clicking "Accept", you give JSC "SberTech" consent to process your personal data in order to improve our website and the GitVerse service, and to make them more convenient to use.

You can disable cookies yourself in your browser settings.