colossalai

Форк
0
45 строк · 1.5 Кб
1
from collections import defaultdict
2
from typing import Dict
3

4
import torch
5
import transformers
6
from torch.utils.data import Dataset
7

8
from colossalai.logging import get_dist_logger
9

10
from .utils import jload
11

12

13
class PromptDataset(Dataset):
14
    """Dataset for supervised fine-tuning."""
15

16
    def __init__(
17
        self,
18
        data_path: str,
19
        tokenizer: transformers.PreTrainedTokenizer,
20
        max_datasets_size: int = None,
21
        max_length: int = 96,
22
    ):
23
        super(PromptDataset, self).__init__()
24
        self.keyed_prompt = defaultdict(list)
25
        self.logger = get_dist_logger()
26
        self.logger.info("Loading data...")
27
        list_data_dict = jload(data_path)
28
        self.logger.info(f"Loaded {len(list_data_dict)} examples.")
29

30
        if max_datasets_size is not None:
31
            self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
32
            list_data_dict = list_data_dict[:max_datasets_size]
33

34
        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
35
        tokens = tokenizer(
36
            instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
37
        )
38
        for k, tensor in tokens.items():
39
            self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()
40

41
    def __len__(self):
42
        return len(self.keyed_prompt["input_ids"])
43

44
    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
45
        return {k: v[i] for k, v in self.keyed_prompt.items()}
46

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.