h2o-llmstudio

text_rlhf_modeling_ds.py 
import logging
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import torch

from llm_studio.src.datasets.text_causal_language_modeling_ds import (
    CustomDataset as CausalLMCustomDataset,
)
from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR

logger = logging.getLogger(__name__)


class CustomDataset(CausalLMCustomDataset):
    """Dataset for RLHF modeling, built on the causal language modeling dataset."""

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        assert (
            cfg.dataset.system_column == "None"
        ), "RLHF is not compatible with system column."
        assert (
            cfg.dataset.limit_chained_samples is False
        ), "RLHF is not compatible with limit_chained_samples."
        assert (
            cfg.dataset.mask_prompt_labels is True
        ), "RLHF is not compatible with mask_prompt_labels."
        super().__init__(df, cfg, mode)

    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation."""
        sample = super().__getitem__(idx)
        # Expose the full chained prompt so the reward model can score the
        # generated answer in its conversation context.
        sample["reward_model_prompt_text"] = TEXT_SEPARATOR.join(
            self.get_chained_prompt_text_list(idx)
        )
        return sample

    def get_labels(self, prompt_encodings, answer_encodings):
        if self.mode == "train":  # no labels required for RLHF during training
            return dict()
        else:
            return super().get_labels(prompt_encodings, answer_encodings)
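
    # Note: during RLHF training the answer is generated by the policy model
    # and scored by the reward model, so no ground-truth labels are taken
    # from the data; validation still uses the parent class labels.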

    def get_encodings(self, input_text_dict):
        system_encoding, prompt_encodings, answer_encodings = super().get_encodings(
            input_text_dict
        )
        # Remove the last ground-truth answer,
        # as RLHF will generate the answer from the prompt.
        answer_encodings[-1] = torch.empty(0)
        return system_encoding, prompt_encodings, answer_encodings
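
    # Illustration: for a chained conversation the parent class returns one
    # encoding per answer; after the override above, the final entry is an
    # empty tensor, so the last answer is left for the model to generate.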

    def postprocess_batch_predictions(self, output: Dict) -> Dict:
        if "predicted_answer_ids" in output:
            # Decode generated token ids into plain text for downstream use.
            predicted_text = [
                self.tokenizer.decode(ids, skip_special_tokens=True).strip()
                for ids in output["predicted_answer_ids"]
            ]

            output["predicted_text"] = np.array(predicted_text)
            output["predicted_answer_ids"] = output["predicted_answer_ids"].detach()
        return output

    def augment_data(self, encodings):
        # No augmentation for RLHF; return the encodings unchanged.
        return encodings

    def get_chained_prompt_text_list(self, idx: int) -> List[str]:
        """Builds the full conversation text (system prompt, previous turns,
        final prompt) and returns it split on TEXT_SEPARATOR."""
        text_dict = self.conversation_chain_handler[idx]
        chat_history = "".join(
            [
                prompt + TEXT_SEPARATOR + answer + TEXT_SEPARATOR
                for prompt, answer in zip(
                    text_dict["prompts"][:-1], text_dict["answers"][:-1]
                )
            ]
        )
        prompt_text = text_dict["systems"][0] + chat_history + text_dict["prompts"][-1]
        return prompt_text.split(TEXT_SEPARATOR)
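
For reference, the chained-prompt reconstruction in get_chained_prompt_text_list can be exercised on its own. The snippet below is a minimal standalone sketch: the text_dict contents are made up, and the TEXT_SEPARATOR value is a stand-in (the real constant is defined in llm_studio.src.datasets.text_utils).

TEXT_SEPARATOR = "<TEXT_SEPARATOR>"  # assumed stand-in for the real constant

# Hypothetical two-turn conversation, shaped like conversation_chain_handler output
text_dict = {
    "systems": ["You are a helpful assistant."],
    "prompts": ["Hi!", "What is RLHF?"],
    "answers": ["Hello! How can I help?", "ground-truth answer (unused here)"],
}

# Same construction as in the method above: all turns except the final answer
chat_history = "".join(
    prompt + TEXT_SEPARATOR + answer + TEXT_SEPARATOR
    for prompt, answer in zip(text_dict["prompts"][:-1], text_dict["answers"][:-1])
)
prompt_text = text_dict["systems"][0] + chat_history + text_dict["prompts"][-1]

print(prompt_text.split(TEXT_SEPARATOR))
# ['You are a helpful assistant.Hi!', 'Hello! How can I help?', 'What is RLHF?']

Joining this list back with TEXT_SEPARATOR is exactly what __getitem__ stores under "reward_model_prompt_text".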
