# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Dict, Optional, Sequence, Tuple

import torch
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from colossalai.logging import get_dist_logger

from .utils import is_rank_0, jload

logger = get_dist_logger()

IGNORE_INDEX = -100
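# -100 is the default `ignore_index` of torch.nn.CrossEntropyLoss, so positions labeled
# with IGNORE_INDEX (prompt and padding tokens below) do not contribute to the loss.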
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
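# For illustration only: PROMPT_DICT["prompt_no_input"] rendered with the made-up
# instruction "Give three tips for staying healthy." produces:
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Give three tips for staying healthy.
#
#   ### Response: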


def _preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: PreTrainedTokenizer,
    max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Preprocess the data by tokenizing."""
    # `targets` already end with the eos token (appended by the callers below),
    # so it is not appended again here.
    sequences = [s + t for s, t in zip(sources, targets)]
    sequences_token = tokenizer(
        sequences,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        add_special_tokens=False,
    )

    sources_token = tokenizer(
        sources,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        add_special_tokens=False,
    )

    assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
    labels = copy.deepcopy(sequences_token["input_ids"])
    for i in range(labels.shape[0]):
        source_len = sources_token["attention_mask"][i].sum().item()
        pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
        if tokenizer.padding_side == "right":
            # |prompt|completion|eos|pad|
            labels[i][:source_len] = IGNORE_INDEX
            if pad_len > 0:
                labels[i][-pad_len:] = IGNORE_INDEX
        elif tokenizer.padding_side == "left":
            # |pad|prompt|completion|eos|
            labels[i][: pad_len + source_len] = IGNORE_INDEX
        else:
            raise RuntimeError(f"Unsupported padding side: {tokenizer.padding_side}")

    return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
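# Shape sketch of what _preprocess above returns for a right-padded tokenizer
# (token counts are illustrative, not taken from real data):
#
#   input_ids: | prompt ids      | completion ids | eos | pad  ... pad  |
#   labels:    | -100   ... -100 | completion ids | eos | -100 ... -100 |
#
# so the cross-entropy loss is computed only on the completion and eos positions.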


def _preprocess_chatglm(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: PreTrainedTokenizer,
    max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Preprocess the data by tokenizing.
    Returns None for the attention mask; ChatGLM computes it from the input ids.
    """
    labels = []
    input_ids = []
    for source, target in zip(sources, targets):
        source_id = tokenizer.encode(text=source, add_special_tokens=False)
        target_id = tokenizer.encode(text=target, add_special_tokens=False)
        input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id)
        # truncate from the left if the sequence exceeds max_length
        sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
        truncate_length = max(0, len(input_id) - max_length)
        input_id = input_id[truncate_length:]
        if truncate_length == len(source_id) + 1:
            input_id = sp_token_list + input_id[1:]
        elif truncate_length > len(source_id) + 1:
            input_id = sp_token_list + input_id[2:]

        context_length = input_id.index(tokenizer.bos_token_id)
        mask_position = context_length - 1
        label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :]

        pad_len = max_length - len(input_id)
        input_id = input_id + [tokenizer.pad_token_id] * pad_len
        input_ids.append(input_id)
        labels.append(label + [IGNORE_INDEX] * pad_len)
    return torch.tensor(input_ids), torch.tensor(labels), None
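# Rough sketch of the token layout assumed by _preprocess_chatglm above (assuming
# build_inputs_with_special_tokens appends [gMASK][BOS] after the prompt and EOS after
# the completion, as in the ChatGLM-6B tokenizer):
#
#   input_ids: | prompt ids      | gMASK | BOS | completion ids | EOS | pad  ... |
#   labels:    | -100   ... -100 | -100  | BOS | completion ids | EOS | -100 ... |
#
# The two branches after truncation re-insert [gMASK][BOS] when left-truncation would
# otherwise cut into those special tokens.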


class SFTDataset(Dataset):
    """
    Dataset for SFT (supervised fine-tuning) models.

    Args:
        dataset: a sequence of dicts, each with "prompt" and "completion" fields
        tokenizer: tokenizer used to encode the prompts and completions
        max_length: maximum length of the tokenized input
    """

    def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
        super().__init__()
        self.input_ids = []

        sources = [data["prompt"] for data in dataset]
        targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]

        logger.info("Tokenizing inputs... This may take some time...")
        if isinstance(tokenizer, ChatGLMTokenizer):
            self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
                sources, targets, tokenizer, max_length
            )
        else:
            self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

        logger.info("Loaded dataset.")

    def __len__(self):
        length = self.input_ids.shape[0]
        return length

    def __getitem__(self, idx):
        if self.attention_mask is not None:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
        else:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
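
# SFTDataset above consumes an in-memory list such as
#   [{"prompt": "<formatted prompt>", "completion": "<target text>"}, ...]
# whereas SupervisedDataset below reads an Alpaca-style JSON file from disk via jload().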


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        max_datasets_size: Optional[int] = None,
        max_length: int = 512,
    ):
        super().__init__()
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        logger.info("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        # examples with an empty "input" field use the no-input prompt (Alpaca convention)
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]

        logger.info("Tokenizing inputs... This may take some time...")
        if isinstance(tokenizer, ChatGLMTokenizer):
            self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
                sources, targets, tokenizer, max_length
            )
        else:
            self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

        logger.info("Loaded dataset.")

    def __len__(self):
        length = self.input_ids.shape[0]
        return length

    def __getitem__(self, idx):
        if self.attention_mask is not None:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
        else:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
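

# Minimal usage sketch (not part of the original module): shows how SupervisedDataset
# could be fed to a DataLoader. The tokenizer name and data path are placeholders only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")  # placeholder model
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # _preprocess pads every sample to max_length

    dataset = SupervisedDataset(
        data_path="alpaca_data.json",  # placeholder path to Alpaca-style JSON (instruction/input/output)
        tokenizer=tokenizer,
        max_datasets_size=128,
        max_length=512,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["labels"].shape)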