dyn_ext_server.go 
389 lines · 11.1 KB
package llm

/*
#cgo CFLAGS: -I${SRCDIR}/ext_server -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/common -I${SRCDIR}/llama.cpp/examples/server
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo CFLAGS: -Wmissing-noreturn -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations
#cgo darwin CFLAGS: -D_DARWIN_C_SOURCE
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
#cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
#cgo darwin LDFLAGS: -lc++ -framework Accelerate
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
#cgo linux CFLAGS: -D_GNU_SOURCE
#cgo linux LDFLAGS: -lrt -ldl -lstdc++ -lm
#cgo linux windows LDFLAGS: -lpthread

#include <stdlib.h>
#include "dyn_ext_server.h"

*/
import "C"

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"
	"unsafe"

	"github.com/jmorganca/ollama/api"
)

type dynExtServer struct {
	s       C.struct_dynamic_llama_server
	options api.Options
}

// Note: current implementation does not support concurrent instantiations
var mutex sync.Mutex

// newExtServerResp allocates a C-owned message buffer of len bytes; release it
// with freeExtServerResp.
func newExtServerResp(len C.size_t) C.ext_server_resp_t {
	var resp C.ext_server_resp_t
	resp.msg_len = len
	bytes := make([]byte, len)
	resp.msg = (*C.char)(C.CBytes(bytes))
	return resp
}

func freeExtServerResp(resp C.ext_server_resp_t) {
	if resp.msg_len == 0 {
		return
	}
	C.free(unsafe.Pointer(resp.msg))
}

// extServerResponseToErr converts the C response message into a Go error.
func extServerResponseToErr(resp C.ext_server_resp_t) error {
	return errors.New(C.GoString(resp.msg))
}
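
// Usage sketch of the helpers above, mirroring the calls made by the methods
// later in this file (the tokenize call is just one example):
//
//	resp := newExtServerResp(128)
//	defer freeExtServerResp(resp)
//	C.dyn_llama_server_tokenize(llm.s, req, &json_resp, &resp)
//	if resp.id < 0 {
//		return nil, extServerResponseToErr(resp)
//	}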

// Note: current implementation does not support concurrent instantiations
var llm *dynExtServer

// newDynExtServer loads the llama.cpp runner library at the given path,
// initializes it with the model, LoRA adapters and projector, and starts its
// main loop. Only one instance may be live at a time; the mutex is held until
// Close is called.
func newDynExtServer(library, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
	if !mutex.TryLock() {
		slog.Info("concurrent llm servers not yet supported, waiting for prior server to complete")
		mutex.Lock()
	}
	updatePath(filepath.Dir(library))
	libPath := C.CString(library)
	defer C.free(unsafe.Pointer(libPath))
	resp := newExtServerResp(512)
	defer freeExtServerResp(resp)
	var srv C.struct_dynamic_llama_server
	C.dyn_init(libPath, &srv, &resp)
	if resp.id < 0 {
		mutex.Unlock()
		return nil, fmt.Errorf("Unable to load dynamic library: %s", C.GoString(resp.msg))
	}
	llm = &dynExtServer{
		s:       srv,
		options: opts,
	}
	slog.Info(fmt.Sprintf("Loading Dynamic llm server: %s", library))

	var sparams C.ext_server_params_t
	sparams.model = C.CString(model)
	defer C.free(unsafe.Pointer(sparams.model))

	sparams.embedding = true
	sparams.n_ctx = C.uint(opts.NumCtx)
	sparams.n_batch = C.uint(opts.NumBatch)
	sparams.n_gpu_layers = C.int(opts.NumGPU)
	sparams.main_gpu = C.int(opts.MainGPU)
	sparams.n_parallel = 1 // TODO - wire up concurrency

	// Always use the value encoded in the model
	sparams.rope_freq_base = 0.0
	sparams.rope_freq_scale = 0.0
	sparams.memory_f16 = C.bool(opts.F16KV)
	sparams.use_mlock = C.bool(opts.UseMLock)
	sparams.use_mmap = C.bool(opts.UseMMap)

	if opts.UseNUMA {
		sparams.numa = C.int(1)
	} else {
		sparams.numa = C.int(0)
	}

	// Chain the LoRA adapters into the singly linked list expected by the C API,
	// appending each new node at the tail.
	sparams.lora_adapters = nil
	for i := 0; i < len(adapters); i++ {
		la := (*C.ext_server_lora_adapter_t)(C.malloc(C.sizeof_ext_server_lora_adapter_t))
		defer C.free(unsafe.Pointer(la))
		la.adapter = C.CString(adapters[i])
		defer C.free(unsafe.Pointer(la.adapter))
		la.scale = C.float(1.0) // TODO expose scale/weights up through ollama UX
		la.next = nil
		if i == 0 {
			sparams.lora_adapters = la
		} else {
			tmp := sparams.lora_adapters
			for ; tmp.next != nil; tmp = tmp.next {
			}
			tmp.next = la
		}
	}

	if len(projectors) > 0 {
		// TODO: applying multiple projectors is not supported by the llama.cpp server yet
		sparams.mmproj = C.CString(projectors[0])
		defer C.free(unsafe.Pointer(sparams.mmproj))
	} else {
		sparams.mmproj = nil
	}

	sparams.n_threads = C.uint(opts.NumThread)

	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
		sparams.verbose_logging = C.bool(true)
	} else {
		sparams.verbose_logging = C.bool(false)
	}

	slog.Info("Initializing llama server")
	initResp := newExtServerResp(128)
	defer freeExtServerResp(initResp)
	C.dyn_llama_server_init(llm.s, &sparams, &initResp)
	if initResp.id < 0 {
		mutex.Unlock()
		err := extServerResponseToErr(initResp)
		slog.Debug(fmt.Sprintf("failure during initialization: %s", err))
		return nil, err
	}

	slog.Info("Starting llama main loop")
	C.dyn_llama_server_start(llm.s)
	return llm, nil
}
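
// Call sketch for newDynExtServer; the library and model paths below are
// hypothetical, and opts stands for whatever api.Options the caller has built:
//
//	srv, err := newDynExtServer("/tmp/ollama/cpu/libext_server.so", "/models/llama2.gguf", nil, nil, opts)
//	if err != nil {
//		return err
//	}
//	defer srv.Close()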

// Predict streams a completion for predict.Prompt, invoking fn for each result
// chunk and once more with Done set. When the server reports an unavailable
// slot it retries with exponential backoff, up to maxRetries attempts.
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
	resp := newExtServerResp(128)
	defer freeExtServerResp(resp)

	if len(predict.Images) > 0 {
		slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
	}

	request := map[string]any{
		"prompt":            predict.Prompt,
		"stream":            true,
		"n_predict":         predict.Options.NumPredict,
		"n_keep":            predict.Options.NumKeep,
		"temperature":       predict.Options.Temperature,
		"top_k":             predict.Options.TopK,
		"top_p":             predict.Options.TopP,
		"tfs_z":             predict.Options.TFSZ,
		"typical_p":         predict.Options.TypicalP,
		"repeat_last_n":     predict.Options.RepeatLastN,
		"repeat_penalty":    predict.Options.RepeatPenalty,
		"presence_penalty":  predict.Options.PresencePenalty,
		"frequency_penalty": predict.Options.FrequencyPenalty,
		"mirostat":          predict.Options.Mirostat,
		"mirostat_tau":      predict.Options.MirostatTau,
		"mirostat_eta":      predict.Options.MirostatEta,
		"penalize_nl":       predict.Options.PenalizeNewline,
		"seed":              predict.Options.Seed,
		"stop":              predict.Options.Stop,
		"image_data":        predict.Images,
		"cache_prompt":      true,
	}

	if predict.Format == "json" {
		request["grammar"] = jsonGrammar
	}

	retryDelay := 100 * time.Microsecond
	for retries := 0; retries < maxRetries; retries++ {
		if retries > 0 {
			time.Sleep(retryDelay) // wait before retrying
			retryDelay *= 2        // exponential backoff
		}

		// Marshal the request with HTML escaping disabled so special characters
		// reach the server unescaped.
		buffer := &bytes.Buffer{}
		enc := json.NewEncoder(buffer)
		enc.SetEscapeHTML(false)

		if err := enc.Encode(request); err != nil {
			return fmt.Errorf("failed to marshal data: %w", err)
		}

		req := C.CString(buffer.String())
		defer C.free(unsafe.Pointer(req))

		C.dyn_llama_server_completion(llm.s, req, &resp)
		if resp.id < 0 {
			return extServerResponseToErr(resp)
		}

		retryNeeded := false
	out:
		for {
			select {
			case <-ctx.Done():
				// This handles the request cancellation
				C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
				if resp.id < 0 {
					return extServerResponseToErr(resp)
				} else {
					return nil
				}
			default:
				var result C.ext_server_task_result_t
				C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
				json_resp := C.GoString(result.json_resp)
				C.dyn_llama_server_release_task_result(llm.s, &result)

				var p prediction
				if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
					C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
					if resp.id < 0 {
						return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
					} else {
						return fmt.Errorf("error unmarshaling llm prediction response: %w", err)
					}
				}

				if bool(result.error) && strings.Contains(json_resp, "slot unavailable") {
					retryNeeded = true
					// task will already be canceled
					break out
				}

				if p.Content != "" {
					fn(PredictResult{
						Content: p.Content,
					})
				}

				if p.Stop || bool(result.stop) {
					fn(PredictResult{
						Done:               true,
						PromptEvalCount:    p.Timings.PromptN,
						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
						EvalCount:          p.Timings.PredictedN,
						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
					})
					return nil
				}
			}
		}
		if !retryNeeded {
			return nil // success
		}
	}

	// reached only if every attempt needed a retry
	return fmt.Errorf("max retries exceeded")
}
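
// Streaming sketch: Predict calls fn once per generated chunk and a final time
// with Done set; the prompt and printing below are illustrative only:
//
//	err := llm.Predict(ctx, PredictOpts{Prompt: "why is the sky blue?"}, func(r PredictResult) {
//		if r.Done {
//			fmt.Printf("\n(%d tokens in %s)\n", r.EvalCount, r.EvalDuration)
//			return
//		}
//		fmt.Print(r.Content)
//	})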

// Encode tokenizes prompt using the loaded model's tokenizer.
func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
	data, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}
	req := C.CString(string(data))
	defer C.free(unsafe.Pointer(req))
	var json_resp *C.char
	resp := newExtServerResp(128)
	defer freeExtServerResp(resp)
	C.dyn_llama_server_tokenize(llm.s, req, &json_resp, &resp)
	if resp.id < 0 {
		return nil, extServerResponseToErr(resp)
	}
	defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)

	var encoded TokenizeResponse
	if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &encoded); err2 != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err2)
	}

	return encoded.Tokens, nil
}

// Decode converts token IDs back into text.
func (llm *dynExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req := C.CString(string(data))
	defer C.free(unsafe.Pointer(req))
	var json_resp *C.char
	resp := newExtServerResp(128)
	defer freeExtServerResp(resp)
	C.dyn_llama_server_detokenize(llm.s, req, &json_resp, &resp)
	if resp.id < 0 {
		return "", extServerResponseToErr(resp)
	}
	defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)

	var decoded DetokenizeResponse
	if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &decoded); err2 != nil {
		return "", fmt.Errorf("unmarshal decode response: %w", err2)
	}

	return decoded.Content, nil
}
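
// Round-trip sketch: Encode tokenizes text and Decode reverses it (illustrative):
//
//	tokens, _ := llm.Encode(ctx, "hello world")
//	text, _ := llm.Decode(ctx, tokens) // text == "hello world" (modulo tokenizer normalization)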

// Embedding returns the embedding vector for input.
func (llm *dynExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
	data, err := json.Marshal(TokenizeRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req := C.CString(string(data))
	defer C.free(unsafe.Pointer(req))
	var json_resp *C.char
	resp := newExtServerResp(128)
	defer freeExtServerResp(resp)
	C.dyn_llama_server_embedding(llm.s, req, &json_resp, &resp)
	if resp.id < 0 {
		return nil, extServerResponseToErr(resp)
	}
	defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)

	var embedding EmbeddingResponse
	if err := json.Unmarshal([]byte(C.GoString(json_resp)), &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}

// Close stops the llama server main loop and releases the single-instance lock
// taken in newDynExtServer.
func (llm *dynExtServer) Close() {
	C.dyn_llama_server_stop(llm.s)
	mutex.Unlock()
}

// updatePath ensures dir is the first PATH entry on Windows so the runner's
// DLL dependencies are resolved from there.
func updatePath(dir string) {
	if runtime.GOOS == "windows" {
		tmpDir := filepath.Dir(dir)
		pathComponents := strings.Split(os.Getenv("PATH"), ";")
		i := 0
		for _, comp := range pathComponents {
			if strings.EqualFold(comp, dir) {
				return
			}
			// Remove any other prior paths to our temp dir
			if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) {
				pathComponents[i] = comp
				i++
			}
		}
		newPath := strings.Join(append([]string{dir}, pathComponents[:i]...), ";")
		slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
		os.Setenv("PATH", newPath)
	}
	// linux and darwin rely on rpath
}
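
// Worked example of the Windows PATH rewrite above (directories are hypothetical):
//
//	dir    = C:\Users\me\AppData\Local\Temp\ollama123\cpu_avx2
//	PATH   = C:\Users\me\AppData\Local\Temp\ollama123\cpu;C:\Windows\system32
//	result = C:\Users\me\AppData\Local\Temp\ollama123\cpu_avx2;C:\Windows\system32
//
// The selected runner directory is prepended, and stale entries under the same
// temp parent (here ...\Temp\ollama123) are dropped so its DLLs resolve first.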