h2o-llmstudio

text_causal_language_modeling_ds.py 
517 lines · 19.6 KB
import codecs
import collections.abc
import logging
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.datasets.text_utils import get_tokenizer

logger = logging.getLogger(__name__)


class CustomDataset(Dataset):
    """Dataset for Causal Language modeling."""

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """
        Args:
            df: input DataFrame
            cfg: config with all the hyperparameters
            mode: dataset mode. One of {"train", "validation"}
        """
        self.cfg = cfg
        self.mode = mode
        self.df = df.copy()
        self.tokenizer = get_tokenizer(self.cfg)
        self.conversation_chain_handler = ConversationChainHandler(self.df, cfg)

    def __len__(self) -> int:
        return len(self.conversation_chain_handler)

    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation."""
        input_text_dict = self.conversation_chain_handler[idx]
        input_text_dict["systems"] = [
            self.parse_system(self.cfg, system) for system in input_text_dict["systems"]
        ]
        input_text_dict["prompts"] = [
            self.parse_prompt(self.cfg, prompt) for prompt in input_text_dict["prompts"]
        ]

        sample = dict()
        system_encoding, prompt_encodings, answer_encodings = self.get_encodings(
            input_text_dict=input_text_dict
        )

        input_ids = torch.cat(
            [
                torch.cat([prompt_encoding, answer_encoding])
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        )

        sample.update(self.get_labels(prompt_encodings, answer_encodings))
        sample.update(
            self.pad_tokens(
                input_ids,
                attention_mask=torch.ones_like(input_ids),
                max_length=self.cfg.tokenizer.max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        )

        # get answer encodings
        sample.update(
            self.pad_tokens(
                answer_encodings[-1],
                attention_mask=torch.ones_like(answer_encodings[-1]),
                max_length=self.cfg.tokenizer.max_length_answer,
                pad_token_id=self.tokenizer.pad_token_id,
                direction="right",
                prefix="answer_",
            )
        )

        # Remove last answer from encoding to create the prompt for inference
        answer_encodings[-1] = torch.empty(0)
        prompt_input_ids = torch.cat(
            [
                torch.cat([prompt_encoding, answer_encoding])
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        )
        sample.update(
            self.pad_tokens(
                prompt_input_ids,
                attention_mask=torch.ones_like(prompt_input_ids),
                max_length=self.cfg.tokenizer.max_length,
                pad_token_id=self.tokenizer.pad_token_id,
                prefix="prompt_",
            )
        )

        # make sure system encoding is always prepended if max_length exceeded
        if sample["input_ids"][0] != self.tokenizer.pad_token_id:
            sample["input_ids"][: len(system_encoding)] = system_encoding
            if self.cfg.dataset.mask_prompt_labels and "labels" in sample.keys():
                sample["labels"][: len(system_encoding)] = -100
        if sample["prompt_input_ids"][0] != self.tokenizer.pad_token_id:
            sample["prompt_input_ids"][: len(system_encoding)] = system_encoding

        return sample
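
    # The returned sample holds parallel, fixed-length tensors (key prefixes
    # come from the pad_tokens calls above):
    #   "input_ids"/"attention_mask"               - full conversation, left-padded
    #   "labels"                                   - targets aligned with input_ids,
    #                                                -100 on masked positions
    #   "answer_input_ids"/"answer_attention_mask" - last answer only, right-padded
    #   "prompt_input_ids"/"prompt_attention_mask" - conversation without the last
    #                                                answer, used for generation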

    @staticmethod
    def parse_prompt(cfg: Any, prompt: str):
        prompt = (
            f"{codecs.decode(cfg.dataset.text_prompt_start, 'unicode_escape')}{prompt}"
        )
        if cfg.dataset.add_eos_token_to_prompt:
            prompt += cfg._tokenizer_eos_token
        prompt = (
            f"{prompt}"
            f"{codecs.decode(cfg.dataset.text_answer_separator, 'unicode_escape')}"
        )
        return prompt
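
    # Illustrative example (the tag strings and EOS token are assumed config
    # values, not fixed by this file): with text_prompt_start="<|prompt|>",
    # text_answer_separator="<|answer|>", add_eos_token_to_prompt=True, and an
    # EOS token of "</s>", parse_prompt("Hi there") returns
    # "<|prompt|>Hi there</s><|answer|>".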

    @staticmethod
    def parse_system(cfg: Any, system: str):
        # no system tokens if empty
        if system == "":
            return system
        system = (
            f"{codecs.decode(cfg.dataset.text_system_start, 'unicode_escape')}{system}"
        )
        if cfg.dataset.add_eos_token_to_system:
            system += cfg._tokenizer_eos_token
        return system

    @staticmethod
    def batch_to_device(
        batch: Union[Dict, List, torch.Tensor], device: str
    ) -> Union[Dict, List, torch.Tensor, str]:
        """Function to send the batch to the device specified

        Args:
            batch: input batch
            device: device to send the data to
        Returns:
            batch with the elements on the device specified
        """
        if isinstance(batch, torch.Tensor):
            return batch.to(device)
        elif isinstance(batch, (list, tuple)) and all(
            isinstance(item, str) for item in batch
        ):
            # Do not move list of strings to device
            return batch
        elif isinstance(batch, collections.abc.Mapping):
            return {
                key: CustomDataset.batch_to_device(value, device)
                for key, value in batch.items()
            }
        elif isinstance(batch, collections.abc.Sequence):
            return [CustomDataset.batch_to_device(value, device) for value in batch]
        else:
            raise ValueError(f"Can not move {type(batch)} to device.")
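
    # batch_to_device recurses through nested mappings and sequences, moving
    # tensors while leaving lists of strings in place, e.g. (illustrative):
    #   CustomDataset.batch_to_device(
    #       {"input_ids": torch.zeros(2), "texts": ["a", "b"]}, "cuda:0"
    #   )
    # returns the tensor on cuda:0 and the string list untouched.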

    @staticmethod
    def preprocess_dataframe(df: pd.DataFrame, cfg: Any, mode: str) -> pd.DataFrame:
        """
        Preprocesses the input dataframe

        Args:
            df: the full training dataframe
            cfg: config
            mode: the mode. One of {"train", "validation"}
        Returns:
            the processed dataframe
        """

        def personalize(text):
            text = text.replace("Open Assistant", cfg.dataset.chatbot_name)
            text = text.replace("Open-Assistant", cfg.dataset.chatbot_name)
            text = text.replace("open-assistant", cfg.dataset.chatbot_name)
            text = text.replace("OpenAssistant", cfg.dataset.chatbot_name)
            text = text.replace("open assistant", cfg.dataset.chatbot_name)
            text = text.replace("Open Assistand", cfg.dataset.chatbot_name)
            text = text.replace("Open Assitant", cfg.dataset.chatbot_name)
            text = text.replace("Open Assistent", cfg.dataset.chatbot_name)
            text = text.replace("Open Assisstant", cfg.dataset.chatbot_name)
            text = text.replace("Open Assitent", cfg.dataset.chatbot_name)
            text = text.replace("Open Assitiant", cfg.dataset.chatbot_name)
            text = text.replace("Open Assistiant", cfg.dataset.chatbot_name)
            text = text.replace("Open Assitan ", cfg.dataset.chatbot_name + " ")
            text = text.replace("Open Assistan ", cfg.dataset.chatbot_name + " ")
            text = text.replace("Open Asistant", cfg.dataset.chatbot_name)
            text = text.replace("Open Assiant", cfg.dataset.chatbot_name)
            text = text.replace("Assistant", cfg.dataset.chatbot_name)
            text = text.replace("LAION AI", cfg.dataset.chatbot_author)
            text = text.replace("LAION-AI", cfg.dataset.chatbot_author)
            text = text.replace("LAION,", cfg.dataset.chatbot_author + ",")
            text = text.replace("LAION.ai", cfg.dataset.chatbot_author)
            text = text.replace("LAION.", cfg.dataset.chatbot_author + ".")
            text = text.replace("LAION", cfg.dataset.chatbot_author)
            return text

        if cfg.dataset.personalize:
            for prompt_col in cfg.dataset.prompt_column:
                df[prompt_col] = df[prompt_col].apply(personalize)
            df[cfg.dataset.answer_column] = df[cfg.dataset.answer_column].apply(
                personalize
            )

        return df
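
    # Illustrative: with cfg.dataset.chatbot_name="h2oGPT" and
    # cfg.dataset.chatbot_author="H2O.ai" (assumed values), personalize maps
    # "I am Open Assistant, trained by LAION." to
    # "I am h2oGPT, trained by H2O.ai.".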

    def get_train_collate_fn(self):
        """
        Returns train batch collate function for the PyTorch Dataloader.
        By default, returns None, which uses the default PyTorch collate function.
        """

        return None

    def get_validation_collate_fn(self):
        """
        Returns validation batch collate function for the PyTorch Dataloader.
        By default, returns None, which uses the default PyTorch collate function.
        """

        return None

    def postprocess_batch_predictions(self, output: Dict) -> Dict:
        if "predicted_answer_ids" in output.keys():
            predicted_text = [
                self.tokenizer.decode(ids, skip_special_tokens=True).strip()
                for ids in output["predicted_answer_ids"]
            ]

            output["predicted_text"] = np.array(predicted_text)
            del output["predicted_answer_ids"]
        return output

    @staticmethod
    def clean_output(
        output: Dict,
        cfg: Any,
    ):
        output["predicted_text"] = output["predicted_text"].tolist()
        for j in range(len(output["predicted_text"])):
            curr_text = output["predicted_text"][j].strip()
            for stop_token in cfg.tokenizer._stop_words:
                if curr_text.find(stop_token) != -1:
                    curr_text = curr_text[: curr_text.find(stop_token)]
            output["predicted_text"][j] = curr_text.strip()

        return output

    def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
        if cfg.prediction.metric != "Perplexity":
            output = self.clean_output(output, cfg)

        output["target_text"] = self.conversation_chain_handler.answers

        metric_func, _, _ = cfg.prediction.metric_class.get(cfg.prediction.metric)

        if "GPT" in cfg.prediction.metric:
            metrics, explanations = metric_func(
                cfg,
                output,
                df,
                raw_results=True,
            )
            output["explanations"] = explanations
        else:
            metrics = metric_func(
                cfg,
                output,
                df,
            )
        output["metrics"] = metrics

        return output

    def format_output(
        self, cfg, df: pd.DataFrame, output: Dict
    ) -> Tuple[Dict, pd.DataFrame]:
        output = {
            key: value
            for key, value in output.items()
            if key not in ["loss", "target", "losses"]
        }
        output.pop("target_text", None)

        # in case limit_chained_samples is True, only last answer is predicted
        end_conversation_ids = (
            self.conversation_chain_handler.get_conversation_end_ids()
        )

        if "predicted_text" in output.keys():
            output["predicted_text"] = np.array(output["predicted_text"])

        if "logits" in output.keys():
            output["logits"] = np.array(output["logits"].float())

        if isinstance(cfg.dataset.prompt_column, tuple):
            for col in cfg.dataset.prompt_column:
                output[col] = df.loc[end_conversation_ids, col].values
        else:
            output[cfg.dataset.prompt_column] = df.loc[
                end_conversation_ids, cfg.dataset.prompt_column
            ].values

        if "predicted_text" in output.keys():
            df[f"pred_{cfg.dataset.answer_column}"] = (
                "NO ANSWER GENERATED. "
                "ONLY LAST ANSWER OF A CONVERSATION IS PREDICTED."
            )
            df.loc[end_conversation_ids, f"pred_{cfg.dataset.answer_column}"] = output[
                "predicted_text"
            ]
        return output, df

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """
        Quick check whether DataFrame and configurations are correctly set.
        """
        if (
            cfg.dataset.parent_id_column is not None
            and cfg.dataset.parent_id_column in df.columns
            and "id" in df.columns
        ):
            assert (
                df[cfg.dataset.parent_id_column] != df["id"]
            ).all(), "Parent id column is the same as id column for some rows"
            assert (df[cfg.dataset.parent_id_column].fillna("") == "").sum() > 0, (
                "Did not find any conversation start. "
                "Please ensure that some parent ids are empty."
            )

        assert cfg.dataset.answer_column in df.columns, (
            f"Answer column {cfg.dataset.answer_column} not found in the "
            f"{mode} DataFrame."
        )
        assert df.shape[0] == df[[cfg.dataset.answer_column]].dropna().shape[0], (
            f"The {mode} DataFrame"
            f" column {cfg.dataset.answer_column}"
            " contains missing values."
        )
        if cfg.dataset.parent_id_column != "None":
            assert (
                "id" in df.columns
            ), "When using parent column, the dataframe requires an 'id' column."

    def get_labels(self, prompt_encodings, answer_encodings):
        labels = torch.cat(
            [
                torch.cat([prompt_encoding, answer_encoding])
                for prompt_encoding, answer_encoding in zip(
                    prompt_encodings, answer_encodings
                )
            ]
        ).clone()

        if self.cfg.dataset.mask_prompt_labels:
            prompt_mask = torch.cat(
                [
                    torch.cat(
                        [
                            torch.ones_like(prompt_encoding),
                            torch.zeros_like(answer_encoding),
                        ]
                    )
                    for prompt_encoding, answer_encoding in zip(
                        prompt_encodings, answer_encodings
                    )
                ]
            ).to(torch.bool)
            labels.masked_fill_(prompt_mask, -100)
        if self.cfg.dataset.add_eos_token_to_answer:
            # eos_token may be equal to pad_token. Add the label back manually.
            labels[-1] = self.tokenizer.eos_token_id
        if self.cfg.tokenizer.max_length < len(labels):
            labels = labels[-self.cfg.tokenizer.max_length :]

        sample = dict(labels=torch.full((self.cfg.tokenizer.max_length,), -100))
        sample["labels"][-len(labels) :] = labels
        return sample
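
    # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so masked
    # prompt/system positions and left padding contribute nothing to the loss;
    # only answer tokens are learned.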

    def get_encodings(self, input_text_dict: Dict[str, List[str]]):
        """
        Get encodings for a single conversation history.
        Args:
            input_text_dict: A dictionary containing the input text for a single sample.
            Contains the keys "systems", "prompts", "answers".
            System may be an empty string.
        """
        encodings = [
            self._get_sample_encoding(system, prompt, answer)
            for idx, (system, prompt, answer) in enumerate(
                zip(
                    input_text_dict["systems"],
                    input_text_dict["prompts"],
                    input_text_dict["answers"],
                )
            )
        ]

        if self.mode == "train":
            encodings = self.augment_data(encodings)

        system_encoding = encodings[0][0]
        prompt_encodings = [encoding[1] for encoding in encodings]
        answer_encodings = [encoding[2] for encoding in encodings]
        # concatenate system encoding with root prompt encoding
        prompt_encodings[0] = torch.cat([system_encoding, prompt_encodings[0]])
        return (
            system_encoding,
            prompt_encodings,
            answer_encodings,
        )

    def augment_data(self, encodings):
        parent_encodings = encodings[:-1]
        # randomly skip parent
        parent_encodings = [
            encoding
            for idx, encoding in enumerate(parent_encodings)
            if np.random.random() > self.cfg.augmentation.skip_parent_probability
        ]
        # randomly replace parent with another parent
        if np.random.random() < self.cfg.augmentation.random_parent_probability:
            idx = np.random.randint(len(self.conversation_chain_handler.prompts))
            parent_encodings = [
                self._get_sample_encoding(
                    self.parse_system(
                        self.cfg, self.conversation_chain_handler.systems[idx]
                    ),
                    self.parse_prompt(
                        self.cfg, self.conversation_chain_handler.prompts[idx]
                    ),
                    self.conversation_chain_handler.answers[idx],
                )
            ] + parent_encodings[1:]
        encodings = parent_encodings + [encodings[-1]]
        return encodings

    def _get_sample_encoding(self, system: str, prompt: str, answer: str) -> List:
        if len(system) > 0:
            system_encoding = self.encode(
                self.tokenizer, system, self.cfg.tokenizer.max_length_prompt, "right"
            )["input_ids"]
        else:
            system_encoding = torch.empty(0)
        prompt_encoding = self.encode(
            self.tokenizer, prompt, self.cfg.tokenizer.max_length_prompt, "left"
        )["input_ids"]
        max_length_answer = self.cfg.tokenizer.max_length_answer - int(
            self.cfg.dataset.add_eos_token_to_answer
        )
        answer_encoding = self.encode(
            self.tokenizer, answer, max_length_answer, "right"
        )["input_ids"]
        if self.cfg.dataset.add_eos_token_to_answer:
            answer_encoding = torch.cat(
                [
                    answer_encoding,
                    torch.Tensor([self.tokenizer.eos_token_id]),
                ],
                dim=0,
            )

        return [system_encoding, prompt_encoding, answer_encoding]

    @staticmethod
    def pad_tokens(
        input_ids,
        attention_mask,
        max_length,
        pad_token_id,
        direction="left",
        prefix="",
    ):
        sample = {}

        if max_length < len(input_ids):
            input_ids = input_ids[-max_length:]
            attention_mask = attention_mask[-max_length:]

        if len(input_ids) > 0:
            if direction == "left":
                sample[f"{prefix}input_ids"] = torch.full((max_length,), pad_token_id)
                sample[f"{prefix}input_ids"][-len(input_ids) :] = input_ids
                sample[f"{prefix}attention_mask"] = torch.zeros(max_length)
                sample[f"{prefix}attention_mask"][-len(input_ids) :] = attention_mask
            else:
                sample[f"{prefix}input_ids"] = torch.full((max_length,), pad_token_id)
                sample[f"{prefix}input_ids"][: len(input_ids)] = input_ids
                sample[f"{prefix}attention_mask"] = torch.zeros(max_length)
                sample[f"{prefix}attention_mask"][: len(input_ids)] = attention_mask
        else:
            # Pad everything if empty (continued pretraining)
            sample[f"{prefix}input_ids"] = torch.full((max_length,), pad_token_id)
            sample[f"{prefix}attention_mask"] = torch.zeros(max_length)

        return sample
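
    # Toy example (illustrative):
    #   pad_tokens(torch.tensor([5, 6]), torch.tensor([1, 1]),
    #              max_length=4, pad_token_id=0)
    # returns {"input_ids": tensor([0, 0, 5, 6]),
    #          "attention_mask": tensor([0., 0., 1., 1.])},
    # since the default direction is "left".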

    @staticmethod
    def encode(tokenizer, text: str, max_length: int, truncation_side: str) -> Dict:
        encodings = tokenizer(text, return_tensors="pt", add_special_tokens=False)
        encodings["input_ids"] = encodings["input_ids"][0]
        encodings["attention_mask"] = encodings["attention_mask"][0]
        if truncation_side == "right":
            encodings["input_ids"] = encodings["input_ids"][:max_length]
            encodings["attention_mask"] = encodings["attention_mask"][:max_length]
        else:
            encodings["input_ids"] = encodings["input_ids"][-max_length:]
            encodings["attention_mask"] = encodings["attention_mask"][-max_length:]
        return encodings
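
# Minimal usage sketch (hypothetical; `cfg` must be an H2O LLM Studio config
# object providing the cfg.dataset.* and cfg.tokenizer.* fields used above,
# and "train.csv" stands in for any frame with prompt/answer columns):
#
#   df = pd.read_csv("train.csv")
#   df = CustomDataset.preprocess_dataframe(df, cfg, mode="train")
#   CustomDataset.sanity_check(df, cfg, mode="train")
#   dataset = CustomDataset(df, cfg, mode="train")
#   sample = dataset[0]  # dict of fixed-length tensors, see __getitem__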