h2o-llmstudio
517 lines · 19.6 KB
1import codecs
2import collections.abc
3import logging
4from typing import Any, Dict, List, Tuple, Union
5
6import numpy as np
7import pandas as pd
8import torch
9from torch.utils.data import Dataset
10
11from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
12from llm_studio.src.datasets.text_utils import get_tokenizer
13
14logger = logging.getLogger(__name__)
15
16
class CustomDataset(Dataset):
    """Dataset for Causal Language modeling."""

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """
        Args:
            df: input DataFrame
            cfg: config with all the hyperparameters
            mode: dataset mode. One of {"train", "validation"}
        """
        self.cfg = cfg
        self.mode = mode
        # Copy so later preprocessing never mutates the caller's dataframe.
        self.df = df.copy()
        self.tokenizer = get_tokenizer(self.cfg)
        # Indexable view over the dataframe that yields whole conversation
        # chains ("systems" / "prompts" / "answers" lists, see __getitem__).
        self.conversation_chain_handler = ConversationChainHandler(self.df, cfg)
32
    def __len__(self) -> int:
        """Number of conversation chains exposed by this dataset."""
        return len(self.conversation_chain_handler)
35
    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation.

        Builds, for one conversation chain:
          - `input_ids` / `attention_mask`: full conversation, left-padded
          - `labels`: loss targets (from get_labels)
          - `answer_input_ids` / `answer_attention_mask`: last answer, right-padded
          - `prompt_input_ids` / `prompt_attention_mask`: conversation without the
            last answer, for inference-time generation
        """
        input_text_dict = self.conversation_chain_handler[idx]
        # Decorate raw texts with the configured special tokens.
        input_text_dict["systems"] = [
            self.parse_system(self.cfg, system) for system in input_text_dict["systems"]
        ]
        input_text_dict["prompts"] = [
            self.parse_prompt(self.cfg, prompt) for prompt in input_text_dict["prompts"]
        ]

        sample = dict()
        system_encoding, prompt_encodings, answer_encodings = self.get_encodings(
            input_text_dict=input_text_dict
        )

        # Interleave prompt/answer turns into one flat token sequence.
        input_ids = torch.cat(
            [
                torch.cat([prompt_encoding, answer_encoding])
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        )

        sample.update(self.get_labels(prompt_encodings, answer_encodings))
        sample.update(
            self.pad_tokens(
                input_ids,
                attention_mask=torch.ones_like(input_ids),
                max_length=self.cfg.tokenizer.max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        )

        # get answer encodings
        sample.update(
            self.pad_tokens(
                answer_encodings[-1],
                attention_mask=torch.ones_like(answer_encodings[-1]),
                max_length=self.cfg.tokenizer.max_length_answer,
                pad_token_id=self.tokenizer.pad_token_id,
                direction="right",
                prefix="answer_",
            )
        )

        # Remove last answer from encoding to create the prompt for inference
        answer_encodings[-1] = torch.empty(0)
        prompt_input_ids = torch.cat(
            [
                torch.cat([prompt_encoding, answer_encoding])
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        )
        sample.update(
            self.pad_tokens(
                prompt_input_ids,
                attention_mask=torch.ones_like(prompt_input_ids),
                max_length=self.cfg.tokenizer.max_length,
                pad_token_id=self.tokenizer.pad_token_id,
                prefix="prompt_",
            )
        )

        # make sure system encoding is always prepended if max_length exceeded
        # (a non-pad first token means the sequence was left-truncated, so the
        # system tokens were cut off — overwrite the head with them again).
        if sample["input_ids"][0] != self.tokenizer.pad_token_id:
            sample["input_ids"][: len(system_encoding)] = system_encoding
            if self.cfg.dataset.mask_prompt_labels and "labels" in sample.keys():
                sample["labels"][: len(system_encoding)] = -100
        if sample["prompt_input_ids"][0] != self.tokenizer.pad_token_id:
            sample["prompt_input_ids"][: len(system_encoding)] = system_encoding

        return sample
111
112@staticmethod
113def parse_prompt(cfg: Any, prompt: str):
114prompt = (
115f"{codecs.decode(cfg.dataset.text_prompt_start, 'unicode_escape')}{prompt}"
116)
117if cfg.dataset.add_eos_token_to_prompt:
118prompt += cfg._tokenizer_eos_token
119prompt = (
120f"{prompt}"
121f"{codecs.decode(cfg.dataset.text_answer_separator, 'unicode_escape')}"
122)
123return prompt
124
125@staticmethod
126def parse_system(cfg: Any, system: str):
127# no system tokens if empty
128if system == "":
129return system
130system = (
131f"{codecs.decode(cfg.dataset.text_system_start, 'unicode_escape')}{system}"
132)
133if cfg.dataset.add_eos_token_to_system:
134system += cfg._tokenizer_eos_token
135return system
136
137@staticmethod
138def batch_to_device(
139batch: Union[Dict, List, torch.Tensor], device: str
140) -> Union[Dict, List, torch.Tensor, str]:
141"""Function to send the batch to the device specified
142
143Args:
144batch: input batch
145device: device to send the data to
146Returns:
147batch with the elements on the device specified
148"""
149if isinstance(batch, torch.Tensor):
150return batch.to(device)
151elif isinstance(batch, (list, tuple)) and all(
152isinstance(item, str) for item in batch
153):
154# Do not move list of strings to device
155return batch
156elif isinstance(batch, collections.abc.Mapping):
157return {
158key: CustomDataset.batch_to_device(value, device)
159for key, value in batch.items()
160}
161elif isinstance(batch, collections.abc.Sequence):
162return [CustomDataset.batch_to_device(value, device) for value in batch]
163else:
164raise ValueError(f"Can not move {type(batch)} to device.")
165
166@staticmethod
167def preprocess_dataframe(df: pd.DataFrame, cfg: Any, mode: str) -> pd.DataFrame:
168"""
169Preprocesses the input dataframe
170
171Args:
172df: the full training dataframe
173cfg: config
174mode: the mode. One of {"train", "validation"}
175Returns:
176the processed dataframe
177"""
178
179def personalize(text):
180text = text.replace("Open Assistant", cfg.dataset.chatbot_name)
181text = text.replace("Open-Assistant", cfg.dataset.chatbot_name)
182text = text.replace("open-assistant", cfg.dataset.chatbot_name)
183text = text.replace("OpenAssistant", cfg.dataset.chatbot_name)
184text = text.replace("open assistant", cfg.dataset.chatbot_name)
185text = text.replace("Open Assistand", cfg.dataset.chatbot_name)
186text = text.replace("Open Assitant", cfg.dataset.chatbot_name)
187text = text.replace("Open Assistent", cfg.dataset.chatbot_name)
188text = text.replace("Open Assisstant", cfg.dataset.chatbot_name)
189text = text.replace("Open Assitent", cfg.dataset.chatbot_name)
190text = text.replace("Open Assitiant", cfg.dataset.chatbot_name)
191text = text.replace("Open Assistiant", cfg.dataset.chatbot_name)
192text = text.replace("Open Assitan ", cfg.dataset.chatbot_name + " ")
193text = text.replace("Open Assistan ", cfg.dataset.chatbot_name + " ")
194text = text.replace("Open Asistant", cfg.dataset.chatbot_name)
195text = text.replace("Open Assiant", cfg.dataset.chatbot_name)
196text = text.replace("Assistant", cfg.dataset.chatbot_name)
197text = text.replace("LAION AI", cfg.dataset.chatbot_author)
198text = text.replace("LAION-AI", cfg.dataset.chatbot_author)
199text = text.replace("LAION,", cfg.dataset.chatbot_author + ",")
200text = text.replace("LAION.ai", cfg.dataset.chatbot_author)
201text = text.replace("LAION.", cfg.dataset.chatbot_author + ".")
202text = text.replace("LAION", cfg.dataset.chatbot_author)
203return text
204
205if cfg.dataset.personalize:
206for prompt_col in cfg.dataset.prompt_column:
207df[prompt_col] = df[prompt_col].apply(personalize)
208df[cfg.dataset.answer_column] = df[cfg.dataset.answer_column].apply(
209personalize
210)
211
212return df
213
    def get_train_collate_fn(self):
        """
        Returns train batch collate function for the PyTorch Dataloader.
        By default returns None that uses the default PyTorch collate
        """

        return None
221
    def get_validation_collate_fn(self):
        """
        Return validation batch collate function for the PyTorch Dataloader.
        By default returns None that uses the default PyTorch collate
        """

        return None
229
230def postprocess_batch_predictions(self, output: Dict) -> Dict:
231if "predicted_answer_ids" in output.keys():
232predicted_text = [
233self.tokenizer.decode(ids, skip_special_tokens=True).strip()
234for ids in output["predicted_answer_ids"]
235]
236
237output["predicted_text"] = np.array(predicted_text)
238del output["predicted_answer_ids"]
239return output
240
241@staticmethod
242def clean_output(
243output: Dict,
244cfg: Any,
245):
246output["predicted_text"] = output["predicted_text"].tolist()
247for j in range(len(output["predicted_text"])):
248curr_text = output["predicted_text"][j].strip()
249for stop_token in cfg.tokenizer._stop_words:
250if curr_text.find(stop_token) != -1:
251curr_text = curr_text[: curr_text.find(stop_token)]
252output["predicted_text"][j] = curr_text.strip()
253
254return output
255
256def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
257if not cfg.prediction.metric == "Perplexity":
258output = self.clean_output(output, cfg)
259
260output["target_text"] = self.conversation_chain_handler.answers
261
262metric_func, _, _ = cfg.prediction.metric_class.get(cfg.prediction.metric)
263
264if "GPT" in cfg.prediction.metric:
265metrics, explanations = metric_func(
266cfg,
267output,
268df,
269raw_results=True,
270)
271output["explanations"] = explanations
272else:
273metrics = metric_func(
274cfg,
275output,
276df,
277)
278output["metrics"] = metrics
279
280return output
281
    def format_output(
        self, cfg, df: pd.DataFrame, output: Dict
    ) -> Tuple[Dict, pd.DataFrame]:
        """Strip training-only keys from the output and merge predictions into df.

        Args:
            cfg: config with dataset column settings
            df: source dataframe (gets a `pred_<answer_column>` column added)
            output: prediction dict
        Returns:
            (filtered output dict, dataframe with predictions written to the
            rows that end a conversation)
        """
        # Per-step training artifacts are not part of the exported output.
        output = {
            key: value
            for key, value in output.items()
            if key not in ["loss", "target", "losses"]
        }
        output.pop("target_text", None)

        # in case limit_chained_samples is True, only last answer is predicted
        end_conversation_ids = (
            self.conversation_chain_handler.get_conversation_end_ids()
        )

        if "predicted_text" in output.keys():
            output["predicted_text"] = np.array(output["predicted_text"])

        if "logits" in output.keys():
            # .float() — presumably upcasting from half precision before the
            # numpy conversion; confirm against the model's output dtype.
            output["logits"] = np.array(output["logits"].float())

        if isinstance(cfg.dataset.prompt_column, tuple):
            for col in cfg.dataset.prompt_column:
                output[col] = df.loc[end_conversation_ids, col].values
        else:
            output[cfg.dataset.prompt_column] = df.loc[
                end_conversation_ids, cfg.dataset.prompt_column
            ].values

        if "predicted_text" in output.keys():
            # Placeholder for rows mid-conversation; real predictions are only
            # written to the conversation-ending rows below.
            df[f"pred_{cfg.dataset.answer_column}"] = (
                "NO ANSWER GENERATED. "
                "ONLY LAST ANSWER OF A CONVERSATION IS PREDICTED."
            )
            df.loc[end_conversation_ids, f"pred_{cfg.dataset.answer_column}"] = output[
                "predicted_text"
            ]
        return output, df
320
321@classmethod
322def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
323"""
324Quick check whether Dataframe and configurations are correctly set.
325"""
326if (
327cfg.dataset.parent_id_column is not None
328and cfg.dataset.parent_id_column in df.columns
329and "id" in df.columns
330):
331assert (
332df[cfg.dataset.parent_id_column] != df["id"]
333).all(), "Parent id column is the same as id column for some rows"
334assert (df[cfg.dataset.parent_id_column].fillna("") == "").sum() > 0, (
335"Did not find any conversation start. "
336"Please ensure that some parent ids are empty."
337)
338
339assert cfg.dataset.answer_column in df.columns, (
340f"Answer column {cfg.dataset.answer_column} not found in the "
341f"{mode} DataFrame."
342)
343assert df.shape[0] == df[[cfg.dataset.answer_column]].dropna().shape[0], (
344f"The {mode} DataFrame"
345f" column {cfg.dataset.answer_column}"
346" contains missing values."
347)
348if cfg.dataset.parent_id_column != "None":
349assert (
350"id" in df.columns
351), "When using parent column, the dataframe requires an 'id' column. "
352
    def get_labels(self, prompt_encodings, answer_encodings):
        """Build the masked, left-truncated, left-padded label tensor for one sample.

        Args:
            prompt_encodings: list of 1-D prompt token tensors, one per turn
            answer_encodings: list of 1-D answer token tensors, one per turn
        Returns:
            dict with a single "labels" tensor of length cfg.tokenizer.max_length
        """
        # Labels cover the full conversation: prompt+answer turns concatenated.
        labels = torch.cat(
            [
                torch.cat([prompt_encoding, answer_encoding])
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        ).clone()

        if self.cfg.dataset.mask_prompt_labels:
            # True on prompt positions, False on answer positions.
            prompt_mask = torch.cat(
                [
                    torch.cat(
                        [
                            torch.ones_like(prompt_encoding),
                            torch.zeros_like(answer_encoding),
                        ]
                    )
                    for prompt_encoding, answer_encoding in zip(
                        prompt_encodings, answer_encodings
                    )
                ]
            ).to(torch.bool)
            # -100 is the conventional ignore_index for the LM loss.
            labels.masked_fill_(prompt_mask, -100)
        if self.cfg.dataset.add_eos_token_to_answer:
            # eos_token may be equal to pad_token. Add the label back manually.
            labels[-1] = self.tokenizer.eos_token_id
        # Truncate from the left so the final answer tokens always survive.
        if self.cfg.tokenizer.max_length < len(labels):
            labels = labels[-self.cfg.tokenizer.max_length :]

        # Left-pad with -100 up to max_length (matches left-padded input_ids).
        sample = dict(labels=torch.full((self.cfg.tokenizer.max_length,), -100))
        sample["labels"][-len(labels) :] = labels
        return sample
387
388def get_encodings(self, input_text_dict: Dict[str, List[str]]):
389"""
390Get encodings for a single conversation history.
391Args:
392input_text_dict: A dictionary containing the input text for a single sample.
393Contains the keys "systems", "prompts", "answers".
394System may be an empty string.
395"""
396encodings = [
397self._get_sample_encoding(system, prompt, answer)
398for idx, (system, prompt, answer) in enumerate(
399zip(
400input_text_dict["systems"],
401input_text_dict["prompts"],
402input_text_dict["answers"],
403)
404)
405]
406
407if self.mode == "train":
408encodings = self.augment_data(encodings)
409
410system_encoding = encodings[0][0]
411prompt_encodings = [encoding[1] for encoding in encodings]
412answer_encodings = [encoding[2] for encoding in encodings]
413# concatenate system encoding with root prompt encoding
414prompt_encodings[0] = torch.cat([system_encoding, prompt_encodings[0]])
415return (
416system_encoding,
417prompt_encodings,
418answer_encodings,
419)
420
    def augment_data(self, encodings):
        """Train-time augmentation: randomly drop or replace parent turns.

        Args:
            encodings: list of [system, prompt, answer] encoding triples; the
                last element is the sample's own turn and is never touched.
        Returns:
            possibly shortened/modified list of encoding triples
        """
        parent_encodings = encodings[:-1]
        # randomly skip parent
        parent_encodings = [
            encoding
            for idx, encoding in enumerate(parent_encodings)
            if np.random.random() > self.cfg.augmentation.skip_parent_probability
        ]
        # randomly replace parent with another parent
        if np.random.random() < self.cfg.augmentation.random_parent_probability:
            # NOTE(review): this replaces the FIRST remaining parent (via the
            # `[1:]` splice) with a random turn drawn from the whole dataset;
            # if all parents were skipped above it prepends one instead.
            idx = np.random.randint(len(self.conversation_chain_handler.prompts))
            parent_encodings = [
                self._get_sample_encoding(
                    self.parse_system(
                        self.cfg, self.conversation_chain_handler.systems[idx]
                    ),
                    self.parse_prompt(
                        self.cfg, self.conversation_chain_handler.prompts[idx]
                    ),
                    self.conversation_chain_handler.answers[idx],
                )
            ] + parent_encodings[1:]
        encodings = parent_encodings + [encodings[-1]]
        return encodings
445
    def _get_sample_encoding(self, system: str, prompt: str, answer: str) -> List:
        """Tokenize one (system, prompt, answer) turn into three 1-D tensors."""
        if len(system) > 0:
            system_encoding = self.encode(
                self.tokenizer, system, self.cfg.tokenizer.max_length_prompt, "right"
            )["input_ids"]
        else:
            # An empty system message contributes zero tokens.
            system_encoding = torch.empty(0)
        # Prompts truncate from the left so the most recent text is kept.
        prompt_encoding = self.encode(
            self.tokenizer, prompt, self.cfg.tokenizer.max_length_prompt, "left"
        )["input_ids"]
        # Reserve one position for the EOS token appended below.
        max_length_answer = self.cfg.tokenizer.max_length_answer - int(
            self.cfg.dataset.add_eos_token_to_answer
        )
        answer_encoding = self.encode(
            self.tokenizer, answer, max_length_answer, "right"
        )["input_ids"]
        if self.cfg.dataset.add_eos_token_to_answer:
            # NOTE(review): torch.Tensor([...]) creates a float32 tensor, so the
            # concatenation promotes the answer encoding to float — confirm that
            # downstream consumers cast ids back to long before embedding lookup.
            answer_encoding = torch.cat(
                [
                    answer_encoding,
                    torch.Tensor([self.tokenizer.eos_token_id]),
                ],
                dim=0,
            )

        return [system_encoding, prompt_encoding, answer_encoding]
472
473@staticmethod
474def pad_tokens(
475input_ids,
476attention_mask,
477max_length,
478pad_token_id,
479direction="left",
480prefix="",
481):
482sample = {}
483
484if max_length < len(input_ids):
485input_ids = input_ids[-max_length:]
486attention_mask = attention_mask[-max_length:]
487
488if len(input_ids) > 0:
489if direction == "left":
490sample[f"{prefix}input_ids"] = torch.full((max_length,), pad_token_id)
491sample[f"{prefix}input_ids"][-len(input_ids) :] = input_ids
492sample[f"{prefix}attention_mask"] = torch.zeros(max_length)
493sample[f"{prefix}attention_mask"][-len(input_ids) :] = attention_mask
494else:
495sample[f"{prefix}input_ids"] = torch.full((max_length,), pad_token_id)
496sample[f"{prefix}input_ids"][: len(input_ids)] = input_ids
497sample[f"{prefix}attention_mask"] = torch.zeros(max_length)
498sample[f"{prefix}attention_mask"][: len(input_ids)] = attention_mask
499else:
500# Pad everything if empty (continued pretraining)
501sample[f"{prefix}input_ids"] = torch.full((max_length,), pad_token_id)
502sample[f"{prefix}attention_mask"] = torch.zeros(max_length)
503
504return sample
505
506@staticmethod
507def encode(tokenizer, text: str, max_length: int, truncation_side: str) -> Dict:
508encodings = tokenizer(text, return_tensors="pt", add_special_tokens=False)
509encodings["input_ids"] = encodings["input_ids"][0]
510encodings["attention_mask"] = encodings["attention_mask"][0]
511if truncation_side == "right":
512encodings["input_ids"] = encodings["input_ids"][:max_length]
513encodings["attention_mask"] = encodings["attention_mask"][:max_length]
514else:
515encodings["input_ids"] = encodings["input_ids"][-max_length:]
516encodings["attention_mask"] = encodings["attention_mask"][-max_length:]
517return encodings
518