h2o-llmstudio
77 lines · 2.8 KB
1import logging
2from typing import Any, Dict, List
3
4import numpy as np
5import pandas as pd
6import torch
7
8from llm_studio.src.datasets.text_causal_language_modeling_ds import (
9CustomDataset as CausalLMCustomDataset,
10)
11from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR
12
13logger = logging.getLogger(__name__)
14
15
class CustomDataset(CausalLMCustomDataset):
    """Causal-LM dataset specialized for RLHF.

    Extends the causal LM dataset by attaching the reward-model prompt text
    to each sample, dropping supervised labels during training, and removing
    the final ground-truth answer so it can be generated by the policy.
    """

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """Validate RLHF-specific config constraints, then defer to the parent.

        Args:
            df: Source dataframe with conversation data.
            cfg: Experiment configuration object.
            mode: One of "train" / validation-style modes.
        """
        # RLHF places hard constraints on the dataset configuration; fail
        # fast here rather than deep inside training.
        assert cfg.dataset.system_column == "None", (
            "RLHF is not compatible with system column."
        )
        assert cfg.dataset.limit_chained_samples is False, (
            "RLHF is not compatible with limit_chained_samples."
        )
        assert cfg.dataset.mask_prompt_labels is True, (
            "RLHF is not compatible with mask_prompt_labels."
        )
        super().__init__(df, cfg, mode)

    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation."""
        item = super().__getitem__(idx)
        # The reward model scores generations against the full chained prompt.
        prompt_parts = self.get_chained_prompt_text_list(idx)
        item["reward_model_prompt_text"] = TEXT_SEPARATOR.join(prompt_parts)
        return item

    def get_labels(self, prompt_encodings, answer_encodings):
        """Return labels; empty in train mode since RLHF needs no targets."""
        if self.mode != "train":
            return super().get_labels(prompt_encodings, answer_encodings)
        # No labels required for RLHF during training.
        return dict()

    def get_encodings(self, input_text_dict):
        """Encode the conversation, blanking the last ground-truth answer."""
        sys_enc, prompt_encs, answer_encs = super().get_encodings(input_text_dict)
        # Remove the last ground-truth answer, as RLHF will generate the
        # answer from the prompt.
        answer_encs[-1] = torch.empty(0)
        return sys_enc, prompt_encs, answer_encs

    def postprocess_batch_predictions(self, output: Dict) -> Dict:
        """Decode generated token ids into text and detach the id tensor."""
        if "predicted_answer_ids" in output:
            decoded = []
            for ids in output["predicted_answer_ids"]:
                text = self.tokenizer.decode(ids, skip_special_tokens=True)
                decoded.append(text.strip())
            output["predicted_text"] = np.array(decoded)
            output["predicted_answer_ids"] = output["predicted_answer_ids"].detach()
        return output

    def augment_data(self, encodings):
        """No-op: data augmentation is disabled for RLHF."""
        return encodings

    def get_chained_prompt_text_list(self, idx: int) -> List[str]:
        """Build the chained prompt text and return it split on the separator.

        Concatenates the system text, all prior prompt/answer turns, and the
        final prompt, then splits the result back on TEXT_SEPARATOR.
        """
        convo = self.conversation_chain_handler[idx]
        history_chunks = [
            prompt + TEXT_SEPARATOR + answer + TEXT_SEPARATOR
            for prompt, answer in zip(convo["prompts"][:-1], convo["answers"][:-1])
        ]
        full_text = convo["systems"][0] + "".join(history_chunks) + convo["prompts"][-1]
        return full_text.split(TEXT_SEPARATOR)
78