from collections import defaultdict
from typing import Dict, Optional

import torch
import transformers
from torch.utils.data import Dataset

from colossalai.logging import get_dist_logger

from .utils import jload


class PromptDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        max_datasets_size: Optional[int] = None,
        max_length: int = 96,
    ):
        super().__init__()
        self.keyed_prompt = defaultdict(list)
        self.logger = get_dist_logger()
        self.logger.info("Loading data...")
        list_data_dict = jload(data_path)
        self.logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        # Tokenize every instruction up front, pad/truncate to a fixed length,
        # move the whole batch to the current CUDA device, and unbind it into
        # per-example tensors keyed by field (input_ids, attention_mask, ...).
        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
        tokens = tokenizer(
            instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
        )
        for k, tensor in tokens.items():
            self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()

    def __len__(self):
        return len(self.keyed_prompt["input_ids"])

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return {k: v[i] for k, v in self.keyed_prompt.items()}
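
# Usage sketch (illustrative, not part of the upstream file): wires the class
# up with a Hugging Face tokenizer and a DataLoader. "prompts.json" is a
# hypothetical path to a JSON list of records, each with an "instruction"
# field; a CUDA device is required, since __init__ moves the tokenized
# tensors to torch.cuda.current_device().
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    dataset = PromptDataset("prompts.json", tokenizer, max_datasets_size=1000)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))  # dict of stacked tensors: input_ids, attention_mask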