h2o-llmstudio
226 строк · 8.9 Кб
1import logging2from typing import Dict, List3
4import numpy as np5
6from llm_studio.src.datasets.text_utils import get_texts7from llm_studio.src.utils.utils import PatchedAttribute8
9logger = logging.getLogger(__name__)10
11
class ConversationChainHandler:
    """
    Partition a dataset into chains of conversations.

    Each chain is a list of conversation rounds, and each round is a
    (system, prompt, answer) triplet. Chains are stored as lists of
    positional row indices into the DataFrame.

    The resulting structure of the chains depends on the DataFrame's
    structure and on the configuration:

    - Without a 'parent_id' in the DataFrame, each conversation chain is a
      single round. For every `i`-th row, 0 <= `i` < len(df), the chain is:
          [(system_i, prompt_i, answer_i)]

    - With a 'parent_id' in the DataFrame and
      `cfg.dataset.limit_chained_samples` set to False, there is one chain
      per row `i`, containing all preceding rounds of that conversation:
          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_i, prompt_i, answer_i)]

    - With a 'parent_id' in the DataFrame and
      `cfg.dataset.limit_chained_samples` set to True, only full
      conversations are kept, i.e. chains that end in a row which is not
      the parent of any other row:
          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_end_conversation_i,
            prompt_end_conversation_i,
            answer_end_conversation_i)]
    """

    def __init__(
        self,
        df,
        cfg,
    ):
        # Do not set self.cfg = cfg, as ConversationChainHandler
        # will be used with PatchedAttribute context manager.
        self.conversation_chain_ids = self.get_conversation_chain_ids(cfg, df)
        # prompts must be set before answers/systems: both fall back to
        # len(self.prompts) when their column is missing.
        self.prompts = get_texts(df, cfg, separator="")
        self.answers = self.get_answers(df, cfg)
        self.systems = self.get_systems(cfg, df)

    def get_conversation_chain_ids(self, cfg, df) -> List[List[int]]:
        """
        Build the conversation chains as lists of positional row indices.

        E.g. if conversation_chain_ids = [[13, 44, 8], ...],
        then the first conversation chain consists of
        [df.iloc[13], df.iloc[44], df.iloc[8]], where
        - df.iloc[13] is the first round of the conversation,
        - df.iloc[44] is the second round of the conversation,
        - df.iloc[8] is the end of the conversation.
        If limit_chained_samples is True, df.iloc[13] has no parent_id,
        i.e. it is the start of the conversation.

        Args:
            cfg: Configuration; reads `cfg.dataset.parent_id_column` and
                `cfg.dataset.limit_chained_samples`.
            df: DataFrame with an "id" column when chaining is enabled.

        Returns:
            One list of row indices per conversation chain, ordered from
            the first round to the last.

        Raises:
            AssertionError: If chaining is enabled but df has no "id" column.
            ValueError: If a parent chain exceeds 1000 hops (circular data).
        """
        parent_id_column = cfg.dataset.parent_id_column
        if (
            parent_id_column in ["None", None]
            # Handle case where train DataFrame has conversation chains,
            # but val DataFrame does not
            or parent_id_column not in df.columns
        ):
            # no parent id column, so each triplet (system_i, prompt_i, answer_i)
            # is a conversation chain
            return [[idx] for idx in range(len(df))]

        assert "id" in df.columns, (
            f"id column required for conversation chaining, "
            f"DataFrame only has {df.columns}."
        )
        # sample and parent ids can have any dtype, such as str, int, float, etc.
        # id column can be int, while parent_id column can be float
        # (as some values are NaN) so we cast id to the same dtype
        sample_ids = df["id"].astype(df[parent_id_column].dtype).tolist()
        # Some datasets may include parent ids that are not in the dataset;
        # those are replaced by the "None" sentinel.
        known_ids = set(sample_ids)
        parent_ids = [
            parent_id if parent_id in known_ids else "None"
            for parent_id in df[parent_id_column].tolist()
        ]

        id2parent_id = {
            sample_id: parent_id
            for sample_id, parent_id in zip(sample_ids, parent_ids)
            if self._is_valid_parent_id(parent_id)
        }

        if cfg.dataset.limit_chained_samples:
            # A conversation end is an id that is not the parent of any other id.
            valid_parent_ids = set(id2parent_id.values())
            conversation_end_ids = [
                sample_id
                for sample_id in sample_ids
                if sample_id not in valid_parent_ids
            ]
        else:
            conversation_end_ids = sample_ids

        # map from df["id"] to enumeration index
        # (for duplicated ids the last occurrence wins)
        dataframeid2idx = {
            sample_id: position for position, sample_id in enumerate(sample_ids)
        }
        return [
            [
                dataframeid2idx[conversation_id]
                for conversation_id in self.get_conversation_ids(
                    id2parent_id, conversation_end_id
                )
            ]
            for conversation_end_id in conversation_end_ids
        ]

    @staticmethod
    def _is_valid_parent_id(parent_id) -> bool:
        """Return True unless parent_id is None, "None", NaN or infinite."""
        if parent_id in [None, "None"]:
            return False
        if isinstance(parent_id, float):
            return bool(np.isfinite(parent_id))
        return True

    def get_answers(self, df, cfg) -> List[str]:
        """
        Return the answers as strings, one per row.

        If `cfg.dataset.answer_column` is not a column of df, an empty
        string is returned for every prompt instead.
        """
        answer_column = cfg.dataset.answer_column
        if answer_column in df.columns:
            return df[answer_column].astype(str).tolist()
        return [""] * len(self.prompts)

    def get_systems(self, cfg, df) -> List[str]:
        """
        Return the system texts as strings, one per row.

        An empty string is returned for every prompt when the system column
        is disabled ("None" or None) or, with a warning, when the
        configured column does not exist in df.
        """
        system_column = cfg.dataset.system_column
        # Accept both the "None" string and None, matching how
        # parent_id_column is checked in get_conversation_chain_ids.
        if system_column in ["None", None]:
            return [""] * len(self.prompts)
        if system_column not in df.columns:
            logger.warning(
                "System column %s not found. Disabling functionality.",
                system_column,
            )
            return [""] * len(self.prompts)
        return df[system_column].astype(str).tolist()

    @staticmethod
    def get_conversation_ids(id2parent_id, end_id):
        """
        Gets the conversation chain that ends at the given conversation ID.

        Args:
            id2parent_id: A dictionary containing the mapping of IDs
                to its previous parent ID.
            end_id: The ID of the end of the conversation in the chain.

        Returns:
            A list of conversation IDs representing the conversation chain.
            The chain is ordered from the first conversation id to end_id.

        Raises:
            ValueError: If more than 1000 parents are followed, which
                guards against circular parent chains.
        """
        # prevent infinite loops in case
        # of circular parent chains (dataframe issue)
        max_loop_count = 1000

        # Collect from end to start with O(1) appends, then reverse once,
        # instead of prepending to the list on every hop.
        reversed_chain = [end_id]
        current_id = end_id
        loop_counter = 0
        while current_id in id2parent_id:
            loop_counter += 1
            if loop_counter > max_loop_count:
                raise ValueError(
                    f"Parent chain of sample with idx {end_id} "
                    f"exceeds max loop count of 1000. "
                    f"Please ensure that parent chain is not circular."
                )
            current_id = id2parent_id[current_id]
            reversed_chain.append(current_id)
        reversed_chain.reverse()
        return reversed_chain

    def __len__(self):
        """Return the number of conversation chains."""
        return len(self.conversation_chain_ids)

    def __getitem__(self, idx):
        """
        Gets a single conversation chain.

        The conversation may be:
        - a single (system, prompt, answer) round,
          if cfg.dataset.parent_id_column == "None" or
          there is no parent_id for the conversation
        - a conversation potentially starting somewhere in
          the middle of the conversation, if the conversation
          is chained and limit_chained_samples is set to False
        - always a complete conversation, if the conversation is chained
          and limit_chained_samples is True

        Returns:
            A dict with keys "prompts", "answers" and "systems", each a
            list with one entry per round of the chain.
        """
        row_indices = self.conversation_chain_ids[idx]
        return {
            "prompts": [self.prompts[i] for i in row_indices],
            "answers": [self.answers[i] for i in row_indices],
            "systems": [self.systems[i] for i in row_indices],
        }

    def get_conversation_end_ids(self):
        """
        Gets the row index of the last round of each conversation chain.
        """
        return [
            conversation_chain[-1] for conversation_chain in self.conversation_chain_ids
        ]
216
def get_conversation_chains(
    df, cfg, limit_chained_samples=True
) -> List[Dict[str, List[str]]]:
    """
    Build all conversation chains of a DataFrame as a plain list.

    A ConversationChainHandler is constructed while
    `cfg.dataset.limit_chained_samples` is patched to the given value via
    PatchedAttribute (presumably restored on exit; verify against its
    implementation).

    Args:
        df: DataFrame passed to ConversationChainHandler.
        cfg: Configuration passed to ConversationChainHandler.
        limit_chained_samples: Value patched onto cfg.dataset while the
            handler is built.

    Returns:
        One item per conversation chain, as produced by indexing the handler.
    """
    dataset_cfg = cfg.dataset
    with PatchedAttribute(dataset_cfg, "limit_chained_samples", limit_chained_samples):
        handler = ConversationChainHandler(df, cfg)
    # The handler is iterated through __getitem__ (indices 0, 1, ... until
    # IndexError), as it defines no __iter__.
    return list(handler)  # type: ignore[call-overload]