from typing import Callable

from torch.utils.data import Dataset
from tqdm import tqdm

from .utils import is_rank_0


# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """
    Dataset of (chosen, rejected) sequence pairs for reward model training.

    Each sample is the prompt concatenated with the preferred ("chosen") and
    dispreferred ("rejected") responses, tokenized to a fixed length.

    Args:
        dataset: raw dataset; each record must provide "prompt", "chosen"
            and "rejected" fields
        tokenizer: callable tokenizer for the reward model
        max_length: maximum token length; every sequence is padded and
            truncated to this length
        special_token: token appended to the end of each sequence
            (defaults to the tokenizer's EOS token)
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Build the full "chosen" sequences (prompt + preferred response + end token);
        # the progress bar is shown only on rank 0 in distributed runs.
        chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        chosen_token = tokenizer(
            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

        # Same for the "rejected" sequences (prompt + dispreferred response + end token).
        reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        reject_token = tokenizer(
            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        # One sample is the (chosen, rejected) tensor pair for the same prompt.
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )
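
# A hedged usage sketch (not part of the original file): the Dahoas/rm-static
# corpus named in the comment above can be loaded with the `datasets` package,
# and a Hugging Face tokenizer fits the `tokenizer` callable, since __init__
# reads `tokenizer.eos_token` and passes HF-style kwargs. GPT-2 is only an
# assumed example model; its tokenizer has no pad token by default, so one
# must be set before padding to max_length.
#
#     from datasets import load_dataset
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     tokenizer.pad_token = tokenizer.eos_token
#     data = load_dataset("Dahoas/rm-static", split="train")
#     rm_dataset = RmStaticDataset(data, tokenizer, max_length=512)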


# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
    """
    Dataset of (chosen, rejected) sequence pairs for reward model training.

    Unlike RmStaticDataset, no separate prompt field is concatenated: the
    "chosen" and "rejected" texts are tokenized as-is.

    Args:
        dataset: raw dataset; each record must provide "chosen" and
            "rejected" fields
        tokenizer: callable tokenizer for the reward model
        max_length: maximum token length; every sequence is padded and
            truncated to this length
        special_token: token appended to the end of each sequence
            (defaults to the tokenizer's EOS token)
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Tokenize the preferred texts with fixed-length padding; the progress
        # bar is shown only on rank 0 in distributed runs.
        chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        chosen_token = tokenizer(
            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

        # Same for the dispreferred texts.
        reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        reject_token = tokenizer(
            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        # One sample is the (chosen, rejected) tensor pair for the same record.
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )

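
# A hedged end-to-end sketch (not part of the original file), assuming the
# `datasets` and `transformers` packages are installed. The relative import of
# `is_rank_0` above means the module must be run with its package context
# (e.g. `python -m <package>.<module>`) for this demo to execute.
if __name__ == "__main__":
    from datasets import load_dataset
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token

    # Small slice of Anthropic/hh-rlhf for a quick smoke test.
    raw = load_dataset("Anthropic/hh-rlhf", split="train[:64]")
    hh_dataset = HhRlhfDataset(raw, tokenizer, max_length=512)

    # The default collate function stacks the 4-tuples returned by __getitem__.
    loader = DataLoader(hh_dataset, batch_size=8, shuffle=True)
    chosen_ids, chosen_mask, reject_ids, reject_mask = next(iter(loader))
    print(chosen_ids.shape, reject_ids.shape)  # torch.Size([8, 512]) twice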