# colossalai
# (file viewer metadata: 88 lines · 3.2 KB)
1from typing import Callable
2
3from torch.utils.data import Dataset
4from tqdm import tqdm
5
6from .utils import is_rank_0
7
8
# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """
    Reward-model dataset for the Dahoas/rm-static format.

    Each source record supplies a shared "prompt" plus a preferred ("chosen")
    and a rejected ("rejected") completion; both are tokenized to fixed-length
    tensors so they align index-for-index.

    Args:
        dataset: iterable of dicts with "prompt", "chosen" and "rejected" keys
        tokenizer: HuggingFace-style tokenizer — callable on a list of strings
            and exposing an ``eos_token`` attribute (used as the default
            end-of-sentence marker)
        max_length: inputs are padded/truncated to this token length
        special_token: optional override for the end-of-sentence token
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # NOTE(review): the dataset is iterated twice below — assumes it is
        # re-iterable (e.g. a HuggingFace Dataset), not a one-shot generator.
        # tqdm progress bars are shown only on rank 0 under distributed runs.
        self.chosen = self._encode(
            [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())],
            tokenizer,
            max_length,
        )
        self.reject = self._encode(
            [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())],
            tokenizer,
            max_length,
        )

    @staticmethod
    def _encode(texts, tokenizer: Callable, max_length: int):
        """Tokenize *texts* to fixed-length tensors; keep only ids and mask."""
        tokens = tokenizer(texts, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
        return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}

    def __len__(self):
        # Number of (chosen, rejected) pairs.
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        # One pair: (chosen_ids, chosen_mask, reject_ids, reject_mask).
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )
48
49
# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
    """
    Reward-model dataset for the Anthropic/hh-rlhf format.

    Each source record supplies a full "chosen" and "rejected" conversation
    text (no separate prompt field); both are tokenized to fixed-length
    tensors so they align index-for-index.

    Args:
        dataset: iterable of dicts with "chosen" and "rejected" keys
        tokenizer: HuggingFace-style tokenizer — callable on a list of strings
            and exposing an ``eos_token`` attribute (used as the default
            end-of-sentence marker)
        max_length: inputs are padded/truncated to this token length
        special_token: optional override for the end-of-sentence token
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # NOTE(review): the dataset is iterated twice below — assumes it is
        # re-iterable (e.g. a HuggingFace Dataset), not a one-shot generator.
        # tqdm progress bars are shown only on rank 0 under distributed runs.
        self.chosen = self._encode(
            [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())],
            tokenizer,
            max_length,
        )
        self.reject = self._encode(
            [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())],
            tokenizer,
            max_length,
        )

    @staticmethod
    def _encode(texts, tokenizer: Callable, max_length: int):
        """Tokenize *texts* to fixed-length tensors; keep only ids and mask."""
        tokens = tokenizer(texts, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
        return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}

    def __len__(self):
        # Number of (chosen, rejected) pairs.
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        # One pair: (chosen_ids, chosen_mask, reject_ids, reject_mask).
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )