h2o-llmstudio

Форк
0
/
conversation_chain_handler.py 
226 строк · 8.9 Кб
1
import logging
2
from typing import Dict, List
3

4
import numpy as np
5

6
from llm_studio.src.datasets.text_utils import get_texts
7
from llm_studio.src.utils.utils import PatchedAttribute
8

9
# Module-level logger named after this module; used below to warn when the
# configured system column is missing from the DataFrame.
logger = logging.getLogger(__name__)
10

11

12
class ConversationChainHandler:
    """
    Partition a dataset into chains of conversations.

    Each chain is a list of conversation rounds, and each round is the triplet
    (system, prompt, answer) taken from one DataFrame row.

    The shape of the chains depends on the DataFrame and the configuration:

    - Without a usable parent id column (``cfg.dataset.parent_id_column`` is
      "None"/None, or the column is absent from the DataFrame), every row is
      its own single-round chain. For the `i`-th row, 0 <= `i` < len(df),
      the chain is: [(system_i, prompt_i, answer_i)]

    - With a parent id column and ``cfg.dataset.limit_chained_samples`` False,
      there is one chain per row `i`, 0 <= `i` < len(df). It starts at the
      first round of that row's conversation and ends at row `i` itself:
          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_i, prompt_i, answer_i)]

    - With a parent id column and ``cfg.dataset.limit_chained_samples`` True,
      only complete conversations are kept, i.e. one chain per row that is
      not the parent of any other row:
          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_end_conversation_i,
            prompt_end_conversation_i,
            answer_end_conversation_i)]
    """

    def __init__(
        self,
        df,
        cfg,
    ):
        # Read everything needed from cfg here and do not keep a reference
        # (no self.cfg): the handler is built inside a PatchedAttribute context
        # (see get_conversation_chains below), so patched config values are
        # only guaranteed to be in effect during construction.
        self.conversation_chain_ids = self.get_conversation_chain_ids(cfg, df)
        self.prompts = get_texts(df, cfg, separator="")
        self.answers = self.get_answers(df, cfg)
        self.systems = self.get_systems(cfg, df)

    def get_conversation_chain_ids(self, cfg, df) -> List[List[int]]:
        """
        Get the conversation chains of ``df`` as lists of positional row indices.

        E.g. if conversation_chain_ids = [[13, 44, 8], ...],
        then the first conversation chain consists of
        [df.iloc[13], df.iloc[44], df.iloc[8]]
        with
            - df.iloc[13] the first round of the conversation
              (it has no parent id that exists in the dataset)
            - df.iloc[44] the second round of the conversation
            - df.iloc[8] the last round of the chain
        If limit_chained_samples is True, df.iloc[8] is additionally not the
        parent of any other row, so the chain is a complete conversation.

        Args:
            cfg: Config; reads ``dataset.parent_id_column`` and
                ``dataset.limit_chained_samples``.
            df: DataFrame with one conversation round per row. Requires an
                "id" column whenever the parent id column is present.

        Returns:
            One list of positional row indices per chain, ordered from the
            first round to the last.

        Raises:
            AssertionError: If chaining is enabled but ``df`` has no "id" column.
            ValueError: If a parent chain is circular.
        """
        parent_id_column = cfg.dataset.parent_id_column
        if (
            parent_id_column in ["None", None]
            # Handle case where train DataFrame has conversation chains,
            # but val DataFrame does not.
            or parent_id_column not in df.columns
        ):
            # No parent id column, so each triplet (system_i, prompt_i, answer_i)
            # is a conversation chain on its own.
            return [[idx] for idx in range(len(df))]

        assert "id" in df.columns, (
            f"id column required for conversation chaining, "
            f"DataFrame only has {df.columns}."
        )
        # Sample and parent ids can have any dtype, such as str, int, float, etc.
        # The id column can be int while the parent id column is float
        # (as some values are NaN), so cast ids to the parent id dtype to make
        # the two comparable.
        sample_ids = df["id"].astype(df[parent_id_column].dtype).tolist()
        # Some datasets include parent ids that are not in the dataset;
        # treat those as missing.
        known_ids = set(sample_ids)
        parent_ids = [
            parent_id if parent_id in known_ids else "None"
            for parent_id in df[parent_id_column].tolist()
        ]

        # Keep only rows whose parent id is an actual, finite value.
        id2parent_id = {
            sample_id: parent_id
            for sample_id, parent_id in zip(sample_ids, parent_ids)
            if parent_id not in [None, "None"]
            and (
                not isinstance(parent_id, float)
                or (not np.isnan(parent_id) and not np.isinf(parent_id))
            )
        }
        if cfg.dataset.limit_chained_samples:
            # A conversation ends at an id that is not the parent of any other id.
            valid_parent_ids = set(id2parent_id.values())
            conversation_end_ids = [
                sample_id
                for sample_id in sample_ids
                if sample_id not in valid_parent_ids
            ]
        else:
            conversation_end_ids = sample_ids
        conversation_chain_ids = [
            self.get_conversation_ids(id2parent_id, conversation_end_id)
            for conversation_end_id in conversation_end_ids
        ]
        # Map df["id"] values to positional row indices.
        # NOTE: if df["id"] contains duplicates, the last occurrence wins.
        id2row_idx = {
            sample_id: row_idx for row_idx, sample_id in enumerate(sample_ids)
        }
        return [
            [id2row_idx[conversation_id] for conversation_id in conversation_ids]
            for conversation_ids in conversation_chain_ids
        ]

    def get_answers(self, df, cfg) -> List[str]:
        """
        Get the answer of every row as a string.

        Returns one empty string per row when ``cfg.dataset.answer_column`` is
        not a column of ``df``.
        """
        answer_column = cfg.dataset.answer_column
        if answer_column in df.columns:
            return df[answer_column].astype(str).tolist()
        # Size by the DataFrame (one entry per row), so this does not depend on
        # attributes set earlier in __init__.
        return ["" for _ in range(len(df))]

    def get_systems(self, cfg, df) -> List[str]:
        """
        Get the system prompt of every row as a string.

        Returns one empty string per row when ``cfg.dataset.system_column`` is
        "None"/None. Does the same, after logging a warning, when the configured
        column is missing from ``df``.
        """
        system_column = cfg.dataset.system_column
        # Accept both the "None" string and Python None, like parent_id_column.
        if system_column in ["None", None]:
            return ["" for _ in range(len(df))]
        if system_column not in df.columns:
            logger.warning(
                f"System column {system_column} not found. "
                f"Disabling functionality."
            )
            return ["" for _ in range(len(df))]
        return df[system_column].astype(str).tolist()

    @staticmethod
    def get_conversation_ids(id2parent_id, end_id):
        """
        Get the conversation chain that ends at ``end_id``.

        Args:
            id2parent_id: A dictionary mapping each ID to its parent ID
                (the previous round of the conversation).
            end_id: The ID of the last conversation round in the chain.

        Returns:
            A list of conversation IDs, ordered from the first round of the
            conversation (an ID without a parent) to ``end_id``.

        Raises:
            ValueError: If the parent chain is circular.
        """
        # Walk from end_id towards the root, then reverse once. Appending
        # avoids the O(n^2) cost of repeatedly prepending.
        conversation_chain_ids = [end_id]
        # Track visited ids to detect circular parent chains (a dataframe
        # issue) exactly. Acyclic chains of any length are accepted.
        visited = {end_id}
        current_id = end_id
        while current_id in id2parent_id:
            current_id = id2parent_id[current_id]
            if current_id in visited:
                raise ValueError(
                    f"Parent chain of sample with idx {end_id} is circular: "
                    f"id {current_id} is reached twice. "
                    f"Please ensure that parent chain is not circular."
                )
            visited.add(current_id)
            conversation_chain_ids.append(current_id)
        conversation_chain_ids.reverse()
        return conversation_chain_ids

    def __len__(self):
        return len(self.conversation_chain_ids)

    def __getitem__(self, idx) -> Dict[str, List[str]]:
        """
        Get a single conversation chain.

        The conversation may be:
        - a single (system, prompt, answer) round,
          if cfg.dataset.parent_id_column == "None" or
          there is no parent_id for the conversation
        - a conversation that starts at the first round but potentially ends
          in the middle of a longer conversation, if the conversation is
          chained and limit_chained_samples is set to False
        - always a complete conversation, if the conversation is chained
          and limit_chained_samples is True

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        chain = self.conversation_chain_ids[idx]
        return {
            "prompts": [self.prompts[i] for i in chain],
            "answers": [self.answers[i] for i in chain],
            "systems": [self.systems[i] for i in chain],
        }

    def __iter__(self):
        """Iterate over all conversation chains in order."""
        for idx in range(len(self)):
            yield self[idx]

    def get_conversation_end_ids(self) -> List[int]:
        """
        Get the positional row index of the last round of each conversation chain.
        """
        return [
            conversation_chain[-1] for conversation_chain in self.conversation_chain_ids
        ]
215

216

217
def get_conversation_chains(
    df, cfg, limit_chained_samples=True
) -> List[Dict[str, List[str]]]:
    """
    Build every conversation chain of ``df`` as a list of dicts.

    ``cfg.dataset.limit_chained_samples`` is temporarily overridden with
    ``limit_chained_samples`` while the handler is constructed. The handler
    reads the config only during construction, so the returned chains reflect
    the overridden value.

    Returns:
        One dict per conversation chain with the keys "prompts", "answers" and
        "systems", each holding one entry per round of that chain.
    """
    with PatchedAttribute(cfg.dataset, "limit_chained_samples", limit_chained_samples):
        handler = ConversationChainHandler(df, cfg)
    # Index explicitly over the handler's length instead of relying on the
    # implicit __getitem__ iteration protocol.
    return [handler[chain_idx] for chain_idx in range(len(handler))]
227

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.