colossalai

Форк
0
89 строк · 2.5 Кб
#    Copyright 2023 lm-sys@FastChat
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
14

15
import dataclasses
16
from enum import Enum, auto
17
from typing import List
18

19

20
class SeparatorStyle(Enum):
    """Strategy used by ``Conversation.get_prompt`` to join turns into a prompt."""

    # Each completed turn is rendered as "<role>: <message>" followed by
    # ``Conversation.sep``.
    ADD_EOS_TOKEN = auto()


@dataclasses.dataclass
class Conversation:
    """A chat transcript plus the settings used to render it as a prompt.

    Attributes:
        system: Text placed verbatim at the start of every prompt.
        roles: Role names of the conversation participants.
        messages: ``[role, message]`` pairs in chronological order. A falsy
            message marks an open turn that is still awaiting a reply.
        offset: Number of leading messages hidden from ``to_gradio_chatbot``.
        sep_style: How turns are joined by ``get_prompt``.
        sep: Separator appended after every completed message.
        skip_next: Caller-managed flag; no method in this class reads it.
    """

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
    sep: str = "</s>"

    skip_next: bool = False

    def get_prompt(self):
        """Render the conversation as a single prompt string.

        Returns:
            ``system`` followed by each turn as ``"<role>: <message><sep>"``.
            A turn whose message is falsy is rendered as ``"<role>: "`` with
            no separator, leaving the prompt open for the model to continue.

        Raises:
            ValueError: If ``sep_style`` is not a supported style.
        """
        if self.sep_style != SeparatorStyle.ADD_EOS_TOKEN:
            raise ValueError(f"Invalid style: {self.sep_style}")
        # Collect the pieces and join once, instead of re-copying the growing
        # string on every turn.
        parts = [self.system]
        for role, message in self.messages:
            if message:
                parts.extend((role, ": ", message, self.sep))
            else:
                parts.extend((role, ": "))
        return "".join(parts)

    def append_message(self, role, message):
        """Append a ``[role, message]`` turn to ``messages`` in place.

        Pass a falsy ``message`` (e.g. ``None``) to open a turn that
        ``get_prompt`` leaves unterminated.
        """
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Pair the visible messages into ``[first, reply]`` rows.

        This is presumably meant for a Gradio ``Chatbot`` component.
        Messages before ``offset`` are skipped. Among the rest, even
        positions open a new row, and odd positions fill that row's second
        slot. The second slot stays ``None`` until it is filled.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy whose message list can be mutated independently.

        Each ``[role, message]`` pair is rebuilt as a new list. Appending to
        or editing the copy's ``messages`` therefore leaves this object
        untouched, and a tuple-valued ``messages`` becomes a list. ``roles``
        is shared by reference. Every other field, including ``skip_next``,
        is carried over.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            # Previously dropped, which silently reset the flag on copies.
            skip_next=self.skip_next,
        )

    def dict(self):
        """Return a plain-dict snapshot of the conversation.

        ``sep_style`` and ``skip_next`` are not included. ``messages`` and
        ``roles`` are returned by reference, not copied.
        """
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
        }
77

78

79
# Human/Assistant prompt template in which every completed turn ends with
# "</s>". ``messages`` is an empty tuple, so the template itself has no
# ``append``; call ``conv.copy()`` to obtain a conversation with a mutable
# message list.
conv = Conversation(
    roles=("Human", "Assistant"),
    offset=0,
    messages=(),
    system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
    ),
    sep="</s>",
    sep_style=SeparatorStyle.ADD_EOS_TOKEN,
)

# Public alias for ``conv`` (same object, not a copy).
default_conversation = conv
90

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.