Langchain-Chatchat

Форк
0
/
message_repository.py 
72 строки · 2.1 Кб
1
from server.db.session import with_session
2
from typing import Dict, List
3
import uuid
4
from server.db.models.message_model import MessageModel
5

6

7
@with_session
8
def add_message_to_db(session, conversation_id: str, chat_type, query, response="", message_id=None,
9
                      metadata: Dict = {}):
10
    """
11
    新增聊天记录
12
    """
13
    if not message_id:
14
        message_id = uuid.uuid4().hex
15
    m = MessageModel(id=message_id, chat_type=chat_type, query=query, response=response,
16
                     conversation_id=conversation_id,
17
                     meta_data=metadata)
18
    session.add(m)
19
    session.commit()
20
    return m.id
21

22

23
@with_session
24
def update_message(session, message_id, response: str = None, metadata: Dict = None):
25
    """
26
    更新已有的聊天记录
27
    """
28
    m = get_message_by_id(message_id)
29
    if m is not None:
30
        if response is not None:
31
            m.response = response
32
        if isinstance(metadata, dict):
33
            m.meta_data = metadata
34
        session.add(m)
35
        session.commit()
36
        return m.id
37

38

39
@with_session
40
def get_message_by_id(session, message_id) -> MessageModel:
41
    """
42
    查询聊天记录
43
    """
44
    m = session.query(MessageModel).filter_by(id=message_id).first()
45
    return m
46

47

48
@with_session
49
def feedback_message_to_db(session, message_id, feedback_score, feedback_reason):
50
    """
51
    反馈聊天记录
52
    """
53
    m = session.query(MessageModel).filter_by(id=message_id).first()
54
    if m:
55
        m.feedback_score = feedback_score
56
        m.feedback_reason = feedback_reason
57
    session.commit()
58
    return m.id
59

60

61
@with_session
62
def filter_message(session, conversation_id: str, limit: int = 10):
63
    messages = (session.query(MessageModel).filter_by(conversation_id=conversation_id).
64
                # 用户最新的query 也会插入到db,忽略这个message record
65
                filter(MessageModel.response != '').
66
                # 返回最近的limit 条记录
67
                order_by(MessageModel.create_time.desc()).limit(limit).all())
68
    # 直接返回 List[MessageModel] 报错
69
    data = []
70
    for m in messages:
71
        data.append({"query": m.query, "response": m.response})
72
    return data
73

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.