gritlm.cpp
#include "arg.h"
#include "common.h"
#include "llama.h"

#include <algorithm> // std::for_each (used in the GRIT_DEBUG block)
#include <cstdio>    // std::printf, std::fflush
#include <string>
#include <vector>

// #define GRIT_DEBUG

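// This example demonstrates GritLM's two modes with a single model: encode() produces
// mean-pooled, L2-normalized sentence embeddings (non-causal attention, with the
// instruction prefix excluded from the pool), while generate() streams text with causal
// attention using the sampler built in main() (greedy here). The context is switched
// between the two modes before each call.
//
// A typical invocation (hypothetical binary name and model file; both depend on how
// llama.cpp was built) would be something like: ./llama-gritlm -m gritlm-7b.gguf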
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
    std::vector<std::vector<float>> result;

    const llama_model * model = llama_get_model(ctx);

    llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

    for (uint64_t i = 0; i < sentences.size(); i++) {
        llama_batch_clear(batch);

        const std::string input_string = instruction + sentences[i];

        std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);

        const int32_t n_toks = inputs.size();

        // GritLM seems to have EOS = ""
        // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
        // inputs.push_back(llama_token_eos(model));

        // we want to ignore instruction tokens for mean pooling
        const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();

#ifdef GRIT_DEBUG
        // debug tokens - should be matching as referenced in the GritLM sample
        std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
            std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
        });
        std::printf("\n");
#endif

        // add input to batch (this increments n_tokens); the last argument requests
        // output for that position, so only the sentence tokens (j >= n_inst) can be
        // read back and mean-pooled below
        for (int32_t j = 0; j < n_toks; j++) {
            llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
        }

        // clear previous kv_cache values (irrelevant for embeddings)
        llama_kv_cache_clear(ctx);
        llama_set_embeddings(ctx, true);
        llama_set_causal_attn(ctx, false);

        // run model
        llama_decode(ctx, batch);

        // get embedding dimensions
        uint64_t n_embd = llama_n_embd(model);

        // allocate embedding output
        std::vector<float> emb_unorm(n_embd, 0.0f);

        // sum up all token embeddings
        for (int32_t k = n_inst; k < n_toks; k++) {
            float * emb = llama_get_embeddings_ith(ctx, k);
            for (uint64_t j = 0; j < n_embd; j++) {
                emb_unorm[j] += emb[j];
            }
        }

        // divide by number of tokens (mean pooling)
        {
            const uint64_t n_sent = n_toks - n_inst;

            for (uint64_t j = 0; j < n_embd; j++) {
                emb_unorm[j] /= n_sent;
            }
        }

        std::vector<float> emb_norm(emb_unorm.size());
        llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
        result.push_back(emb_norm);

#ifdef GRIT_DEBUG
        // print out emb_norm
        std::printf("embedding %llu: ", (unsigned long long) i);
        for (uint64_t j = 0; j < n_embd; j++) {
            std::printf("%.5f ", emb_norm[j]);
        }
        std::printf("\n\n");
#endif
    }

    llama_batch_free(batch);

    return result;
}
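
// note: each embedding returned by encode() has llama_n_embd(model) elements and is
// L2-normalized; the instruction prefix provides context but does not enter the mean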

static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
    std::string result;

    const llama_model * model = llama_get_model(ctx);
    llama_token eos_token = llama_token_eos(model);

    // switch the context to causal text generation mode (no embeddings)
    llama_kv_cache_clear(ctx);
    llama_set_embeddings(ctx, false);
    llama_set_causal_attn(ctx, true);

    llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

    std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
    int32_t i_current_token = 0;

    while (true) {
        llama_batch_clear(bat);
        {
            const int32_t n_inputs = inputs.size();

            for (int32_t i = 0; i < n_inputs; i++) {
                llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
            }
        }
        // after the first iteration only the newly sampled token is fed back
        inputs.clear();

        llama_decode(ctx, bat);

        // sample the next token from the logits of the last position in the batch
        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);

        if (token == eos_token) {
            break;
        }

        std::string piece = llama_token_to_piece(ctx, token);
        if (stream) {
            std::printf("%s", piece.c_str());
            std::fflush(stdout);
        }

        inputs.push_back(token);

        result += piece;
    }

    if (stream) {
        std::printf("\n");
    }

    llama_batch_free(bat);

    return result;
}

static std::string gritlm_instruction(const std::string & instruction) {
    return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n";
}
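// e.g. gritlm_instruction("Given a scientific paper title, retrieve the paper's abstract")
// yields "<|user|>\nGiven a scientific paper title, retrieve the paper's abstract\n<|embed|>\n",
// while gritlm_instruction("") yields just "<|embed|>\n" (used for the retrieval documents)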

int main(int argc, char * argv[]) {
    gpt_params params;

    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }

    gpt_init();

    llama_model_params mparams = llama_model_params_from_gpt_params(params);
    llama_context_params cparams = llama_context_params_from_gpt_params(params);

    llama_backend_init();

    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);

    // create the context; it is shared by the embedding and generation calls below,
    // which switch modes via llama_set_embeddings / llama_set_causal_attn
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    auto sparams = llama_sampler_chain_default_params();

    sparams.no_perf = false;

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

    // ### Embedding/Representation ###
    // samples taken from: https://github.com/ContextualAI/gritlm#basic
    {
        const std::string instruction = "Given a scientific paper title, retrieve the paper's abstract";

        const std::vector<std::string> queries = {
            "Bitcoin: A Peer-to-Peer Electronic Cash System",
            "Generative Representational Instruction Tuning",
        };

        const std::vector<std::string> documents = {
            "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
            "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
        };

        // No need to add instruction for retrieval documents
        const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
        const std::vector<std::vector<float>> q_rep = encode(ctx, queries,   gritlm_instruction(instruction));

        const int n_embd = llama_n_embd(model);

        const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
        const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
        const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
        const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);

        std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
        std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
        std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[0].c_str(), cosine_sim_q1_d0);
        std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
    }

    // ### Generation ###
    // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
    {
        const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
        std::string response = generate(ctx, smpl, prompt, true);
    }

    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();

    return 0;
}