llama-index

Форк
0
124 строки · 4.0 Кб
1
import typing
2
from typing import Sequence, Union
3

4
from llama_index.legacy.core.llms.types import MessageRole
5
from llama_index.legacy.llms.base import (
6
    ChatMessage,
7
    ChatResponse,
8
    CompletionResponse,
9
)
10

11
if typing.TYPE_CHECKING:
12
    import google.ai.generativelanguage as glm
13
    import google.generativeai as genai
14

15

16
ROLES_TO_GEMINI = {
17
    MessageRole.USER: "user",
18
    MessageRole.ASSISTANT: "model",
19
    ## Gemini only has user and model roles. Put the rest in user role.
20
    MessageRole.SYSTEM: "user",
21
}
22
ROLES_FROM_GEMINI = {v: k for k, v in ROLES_TO_GEMINI.items()}
23

24

25
def _error_if_finished_early(candidate: "glm.Candidate") -> None:  # type: ignore[name-defined] # only until release
26
    if (finish_reason := candidate.finish_reason) > 1:  # 1=STOP (normally)
27
        reason = finish_reason.name
28

29
        # Safety reasons have more detail, so include that if we can.
30
        if finish_reason == 3:  # 3=Safety
31
            relevant_safety = list(
32
                filter(
33
                    lambda sr: sr.probability > 1,  # 1=Negligible
34
                    candidate.safety_ratings,
35
                )
36
            )
37
            reason += f" {relevant_safety}"
38

39
        raise RuntimeError(f"Response was terminated early: {reason}")
40

41

42
def completion_from_gemini_response(
43
    response: Union[
44
        "genai.types.GenerateContentResponse",
45
        "genai.types.AsyncGenerateContentResponse",
46
    ],
47
) -> CompletionResponse:
48
    top_candidate = response.candidates[0]
49
    _error_if_finished_early(top_candidate)
50

51
    raw = {
52
        **(type(top_candidate).to_dict(top_candidate)),
53
        **(type(response.prompt_feedback).to_dict(response.prompt_feedback)),
54
    }
55
    return CompletionResponse(text=response.text, raw=raw)
56

57

58
def chat_from_gemini_response(
59
    response: Union[
60
        "genai.types.GenerateContentResponse",
61
        "genai.types.AsyncGenerateContentResponse",
62
    ],
63
) -> ChatResponse:
64
    top_candidate = response.candidates[0]
65
    _error_if_finished_early(top_candidate)
66

67
    raw = {
68
        **(type(top_candidate).to_dict(top_candidate)),
69
        **(type(response.prompt_feedback).to_dict(response.prompt_feedback)),
70
    }
71
    role = ROLES_FROM_GEMINI[top_candidate.content.role]
72
    return ChatResponse(message=ChatMessage(role=role, content=response.text), raw=raw)
73

74

75
def chat_message_to_gemini(message: ChatMessage) -> "genai.types.ContentDict":
76
    """Convert ChatMessages to Gemini-specific history, including ImageDocuments."""
77
    parts = [message.content]
78
    if images := message.additional_kwargs.get("images"):
79
        try:
80
            import PIL
81

82
            parts += [PIL.Image.open(doc.resolve_image()) for doc in images]
83
        except ImportError:
84
            # This should have been caught earlier, but tell the user anyway.
85
            raise ValueError("Multi-modal support requires PIL.")
86

87
    return {
88
        "role": ROLES_TO_GEMINI[message.role],
89
        "parts": parts,
90
    }
91

92

93
def merge_neighboring_same_role_messages(
94
    messages: Sequence[ChatMessage],
95
) -> Sequence[ChatMessage]:
96
    # Gemini does not support multiple messages of the same role in a row, so we merge them
97
    merged_messages = []
98
    i = 0
99

100
    while i < len(messages):
101
        current_message = messages[i]
102
        # Initialize merged content with current message content
103
        merged_content = [current_message.content]
104

105
        # Check if the next message exists and has the same role
106
        while (
107
            i + 1 < len(messages)
108
            and ROLES_TO_GEMINI[messages[i + 1].role]
109
            == ROLES_TO_GEMINI[current_message.role]
110
        ):
111
            i += 1
112
            next_message = messages[i]
113
            merged_content.extend([next_message.content])
114

115
        # Create a new ChatMessage or similar object with merged content
116
        merged_message = ChatMessage(
117
            role=current_message.role,
118
            content="\n".join([str(msg_content) for msg_content in merged_content]),
119
            additional_kwargs=current_message.additional_kwargs,
120
        )
121
        merged_messages.append(merged_message)
122
        i += 1
123

124
    return merged_messages
125

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.