# Langchain-Chatchat — Gemini API model worker
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json, httpx
from typing import List, Dict
from configs import logger, log_verbose

class GeminiWorker(ApiModelWorker):
    """fastchat model worker that proxies chat requests to Google's Gemini
    REST API (generativelanguage.googleapis.com, v1beta ``gemini-pro``)."""

    def __init__(
            self,
            *,
            controller_addr: str = None,
            worker_addr: str = None,
            model_names: List[str] = ["gemini-api"],
            **kwargs,
    ):
        # Forward addresses/model names to ApiModelWorker; the default
        # context window matches the maxOutputTokens used in do_chat.
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 4096)
        super().__init__(**kwargs)

    def create_gemini_messages(self, messages: List[Dict]) -> Dict:
        """Convert OpenAI-style chat messages into a Gemini ``contents`` payload.

        Args:
            messages: list of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            ``{"contents": [...]}`` ready to be POSTed to generateContent.
        """
        # Multi-turn conversations must carry explicit roles; Gemini calls
        # the assistant side "model".
        has_history = any(msg['role'] == 'assistant' for msg in messages)
        gemini_msg = []

        for msg in messages:
            role = msg['role']
            content = msg['content']
            if role == 'system':
                # v1beta generateContent has no system role in `contents`; drop it.
                continue
            if has_history:
                if role == 'assistant':
                    role = "model"
                gemini_msg.append({"role": role, "parts": [{"text": content}]})
            elif role == 'user':
                # Single-turn request: the role field may be omitted.
                gemini_msg.append({"parts": [{"text": content}]})
            # BUG FIX: the original appended `transformed_msg` unconditionally,
            # raising NameError (or appending a stale message) whenever a
            # non-user role appeared without assistant history; now only
            # messages that were actually built are appended.

        return dict(contents=gemini_msg)

    def do_chat(self, params: ApiChatParams) -> Dict:
        """Stream a completion from Gemini's generateContent endpoint.

        Yields fastchat-style dicts ``{"error_code": int, "text": str}``,
        where ``text`` is the accumulated response so far.
        """
        params.load_config(self.model_names[0])
        data = self.create_gemini_messages(messages=params.messages)
        generationConfig = dict(
            temperature=params.temperature,
            topK=1,
            topP=1,
            maxOutputTokens=4096,
            stopSequences=[]
        )

        data['generationConfig'] = generationConfig
        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
        headers = {
            'Content-Type': 'application/json',
        }
        if log_verbose:
            # NOTE(review): the URL embeds the API key as a query parameter,
            # so verbose logging leaks it into the logs — consider masking.
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')
            logger.info(f'{self.__class__.__name__}:data: {data}')

        text = ""
        json_string = ""
        timeout = httpx.Timeout(60.0)
        # BUG FIX: the original never closed the httpx client; use it as a
        # context manager so the connection pool is released.
        with get_httpx_client(timeout=timeout) as client:
            with client.stream("POST", url, headers=headers, json=data) as response:
                # generateContent returns a single JSON document; accumulate
                # the body line by line before parsing.
                for line in response.iter_lines():
                    line = line.strip()
                    if not line or "[DONE]" in line:
                        continue
                    json_string += line

        try:
            resp = json.loads(json_string)
        except json.JSONDecodeError as e:
            # BUG FIX: the original only print()ed decode failures, so the
            # caller silently received an empty stream; surface the error.
            logger.error(f"Failed to decode Gemini response: {e}")
            yield {
                "error_code": 500,
                "text": f"Invalid JSON response from Gemini API: {json_string}",
            }
            return

        for candidate in resp.get('candidates', []):
            for part in candidate.get('content', {}).get('parts', []):
                if 'text' in part:
                    text += part['text']
                    yield {
                        "error_code": 0,
                        "text": text
                    }

    def get_embeddings(self, params):
        # Placeholder: embeddings are not implemented for the Gemini worker.
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        """Build the fastchat conversation template used for this worker."""
        return conv.Conversation(
            name=self.model_names[0],
            system_message="You are a helpful, respectful and honest assistant.",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )
if __name__ == "__main__":
    # Run this worker standalone for local debugging.
    import uvicorn
    from fastchat.serve.base_model_worker import app
    from server.utils import MakeFastAPIOffline

    gemini_worker = GeminiWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21012",
    )
    # fastchat's base worker endpoints look up the active worker on this
    # module attribute, so register ours there before serving.
    sys.modules["fastchat.serve.model_worker"].worker = gemini_worker
    # Presumably patches the FastAPI docs pages to work without a CDN —
    # confirm against server.utils.MakeFastAPIOffline.
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21012)