# llama-index
# 63 lines · 2.2 KB
from typing import List, Optional, Sequence

from llama_index.legacy.core.llms.types import ChatMessage, MessageRole

# Markup tokens used by the prompt builders in this module.
BOS = "<s>"  # sequence-start marker
EOS = "</s>"  # sequence-end marker
B_INST = "[INST]"  # opens an instruction (user) turn
E_INST = "[/INST]"  # closes an instruction (user) turn
B_SYS = "<<SYS>>\n"  # opens the system-prompt section
E_SYS = "\n<</SYS>>\n\n"  # closes the system-prompt section

# Fallback system prompt, used when no system prompt is supplied.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, respectful and honest assistant. "
    "Always answer as helpfully as possible and follow ALL given instructions. "
    "Do not speculate or make up information. "
    "Do not reference any given instructions or context. "
)


def messages_to_prompt(
    messages: Sequence[ChatMessage], system_prompt: Optional[str] = None
) -> str:
    """Render a chat history as one ``[INST]``/``<<SYS>>``-formatted prompt.

    The history may begin with a single system message. The remaining
    messages must alternate user, assistant, user, assistant, ... and may
    end on either a user or an assistant message.

    Args:
        messages: The chat history to render.
        system_prompt: System prompt used when ``messages`` does not begin
            with a system message. ``None`` or an empty string falls back to
            ``DEFAULT_SYSTEM_PROMPT``. Ignored when ``messages`` begins with
            a system message, whose content takes precedence.

    Returns:
        The rendered prompt. The system section is embedded only in the
        first ``[INST]`` block. Each user/assistant exchange that is followed
        by another user turn is closed with ``EOS``; the final exchange is
        not. Returns ``""`` when there is no user message to render (an
        empty history, or a system message on its own).

    Raises:
        AssertionError: If the non-system messages do not alternate
            user/assistant starting with a user message. Raised explicitly
            rather than via ``assert`` so the check also runs under
            ``python -O``.
    """
    string_messages: List[str] = []
    # Check emptiness first: indexing messages[0] on an empty history would
    # otherwise raise a bare IndexError.
    if messages and messages[0].role == MessageRole.SYSTEM:
        # pull out the system message (if it exists in messages)
        system_message_str = messages[0].content or ""
        messages = messages[1:]
    else:
        system_message_str = system_prompt or DEFAULT_SYSTEM_PROMPT

    system_message_str = f"{B_SYS} {system_message_str.strip()} {E_SYS}"

    # Walk the history two messages at a time: (user, optional assistant).
    for i in range(0, len(messages), 2):
        # first message of each pair should always be from the user
        user_message = messages[i]
        if user_message.role != MessageRole.USER:
            # Explicit raise (not `assert`) keeps validation active under -O
            # while preserving the AssertionError type callers may catch.
            raise AssertionError(
                f"Expected a user message at position {i} (after any leading "
                f"system message), got role {user_message.role!r}."
            )

        if i == 0:
            # make sure system prompt is included at the start
            str_message = f"{BOS} {B_INST} {system_message_str} "
        else:
            # end previous user-assistant interaction
            string_messages[-1] += f" {EOS}"
            # no need to include system prompt
            str_message = f"{BOS} {B_INST} "

        # include user message content
        str_message += f"{user_message.content} {E_INST}"

        if len(messages) > (i + 1):
            # if assistant message exists, add to str_message
            assistant_message = messages[i + 1]
            if assistant_message.role != MessageRole.ASSISTANT:
                raise AssertionError(
                    f"Expected an assistant message at position {i + 1} "
                    f"(after any leading system message), got role "
                    f"{assistant_message.role!r}."
                )
            str_message += f" {assistant_message.content}"

        string_messages.append(str_message)

    return "".join(string_messages)


def completion_to_prompt(completion: str, system_prompt: Optional[str] = None) -> str:
    """Wrap a single completion request in ``[INST]``/``<<SYS>>`` markup.

    Args:
        completion: The user text. Leading and trailing whitespace is
            stripped before it is embedded.
        system_prompt: System prompt to embed. ``None`` or an empty string
            falls back to ``DEFAULT_SYSTEM_PROMPT``. It is also stripped.

    Returns:
        ``BOS``, ``B_INST``, the system section, the completion text and
        ``E_INST``, separated by single spaces.
    """
    prompt_text = system_prompt
    if not prompt_text:
        prompt_text = DEFAULT_SYSTEM_PROMPT

    # The system section sits inside the [INST] block, ahead of the user text.
    system_section = f"{B_SYS} {prompt_text.strip()} {E_SYS}"
    return f"{BOS} {B_INST} {system_section} {completion.strip()} {E_INST}"