colossalai

Форк
0
200 строк · 7.6 Кб
1
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2
#
3
#    Licensed under the Apache License, Version 2.0 (the "License");
4
#    you may not use this file except in compliance with the License.
5
#    You may obtain a copy of the License at
6
#
7
#        http://www.apache.org/licenses/LICENSE-2.0
8
#
9
#    Unless required by applicable law or agreed to in writing, software
10
#    distributed under the License is distributed on an "AS IS" BASIS,
11
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
#    See the License for the specific language governing permissions and
13
#    limitations under the License.
14

15
import copy
16
from typing import Dict, Optional, Sequence, Tuple
17

18
import torch
19
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
20
from torch.utils.data import Dataset
21
from tqdm import tqdm
22
from transformers import PreTrainedTokenizer
23

24
from colossalai.logging import get_dist_logger
25

26
from .utils import is_rank_0, jload
27

28
# Module-level logger from colossalai.logging.get_dist_logger; the dataset classes in
# this module use it to report loading/tokenizing progress.
logger = get_dist_logger()
29

30
# Label value written at positions that must not be supervised (prompt and padding).
# -100 is the default `ignore_index` of torch.nn.CrossEntropyLoss, so presumably the
# training loss skips these positions — confirm against the trainer's loss.
IGNORE_INDEX = -100

# Alpaca-style prompt templates, filled with str.format_map from an example record.
# "prompt_input" needs the "instruction" and "input" fields; "prompt_no_input" needs
# only "instruction".
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
43

44

45
def _preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: PreTrainedTokenizer,
    max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize prompt/completion pairs and build supervised labels.

    Every example is ``source + target`` terminated by a single ``tokenizer.eos_token``,
    tokenized without extra special tokens and padded/truncated to ``max_length``.
    Label positions covering the prompt and the padding are set to ``IGNORE_INDEX`` so
    that only the completion (including its EOS) is supervised.

    Args:
        sources: prompt strings.
        targets: completion strings, aligned with ``sources``. A trailing
            ``tokenizer.eos_token`` is optional: it is appended only when missing, so a
            caller that already appends it does not get a duplicated EOS.
        tokenizer: tokenizer whose ``padding_side`` is ``"right"`` or ``"left"``.
        max_length: number of tokens in every returned row.

    Returns:
        ``(input_ids, labels, attention_mask)`` as 2-D tensors with one row per example.

    Raises:
        RuntimeError: if ``tokenizer.padding_side`` is neither ``"right"`` nor ``"left"``.
    """
    eos = tokenizer.eos_token
    # Both dataset classes in this module already append the EOS text to each target;
    # appending it unconditionally would end every sequence with two EOS tokens.
    sequences = [s + (t if t.endswith(eos) else t + eos) for s, t in zip(sources, targets)]

    tokenize_kwargs = dict(
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        add_special_tokens=False,
    )
    sequences_token = tokenizer(sequences, **tokenize_kwargs)
    # Tokenizing the prompts on their own gives the number of leading tokens of each
    # sequence that belong to the prompt.
    sources_token = tokenizer(sources, **tokenize_kwargs)

    assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
    labels = copy.deepcopy(sequences_token["input_ids"])
    padding_side = tokenizer.padding_side
    for i in range(labels.shape[0]):
        source_len = sources_token["attention_mask"][i].sum().item()
        pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
        if padding_side == "right":
            # |prompt|completion|eos|pad|
            labels[i][:source_len] = IGNORE_INDEX
            # Guard: with pad_len == 0, labels[i][-0:] would select the whole row.
            if pad_len > 0:
                labels[i][-pad_len:] = IGNORE_INDEX
        elif padding_side == "left":
            # |pad|prompt|completion|eos|
            labels[i][: pad_len + source_len] = IGNORE_INDEX
        else:
            raise RuntimeError(f"Unsupported tokenizer.padding_side {padding_side!r}; expected 'right' or 'left'.")

    return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
78

79

80
def _preprocess_chatglm(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: PreTrainedTokenizer,
    max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Preprocess the data by tokenizing, for ChatGLM tokenizers.
    None for attention mask, ChatGLM will calculate attention mask according to input ids

    Each pair is encoded with ``tokenizer.build_inputs_with_special_tokens``. Over-long
    sequences are truncated from the left. The truncation logic assumes the special pair
    ``[gmask_token_id, bos_token_id]`` directly follows the prompt ids (TODO confirm
    against ``ChatGLMTokenizer.build_inputs_with_special_tokens``). Labels mask every
    position before the BOS token with ``IGNORE_INDEX``. Rows are padded with
    ``pad_token_id`` (inputs) and ``IGNORE_INDEX`` (labels).

    Args:
        sources: prompt strings.
        targets: completion strings, aligned with ``sources``.
        tokenizer: ChatGLM tokenizer exposing ``gmask_token_id``, ``bos_token_id`` and
            ``pad_token_id``.
        max_length: number of tokens per returned row.

    Returns:
        ``(input_ids, labels, None)``.
    """
    sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
    # Tokens left for the end of the sequence once the special pair is re-inserted.
    tail_budget = max_length - len(sp_token_list)

    labels = []
    input_ids = []
    for source, target in zip(sources, targets):
        source_id = tokenizer.encode(text=source, add_special_tokens=False)
        target_id = tokenizer.encode(text=target, add_special_tokens=False)
        full_input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id)

        # Left-truncate to max_length tokens.
        truncate_length = max(0, len(full_input_id) - max_length)
        if truncate_length > len(source_id):
            # Truncation cut into the special pair after the prompt: re-insert it and keep
            # exactly the last `tail_budget` tokens so the row is max_length long. (Keeping
            # one token too many when truncate_length == len(source_id) + 1 would produce a
            # max_length + 1 row, and torch.tensor below cannot stack ragged rows.)
            input_id = sp_token_list + full_input_id[len(full_input_id) - tail_budget :]
        else:
            input_id = full_input_id[truncate_length:]

        # Everything before BOS is prompt context and is not supervised.
        context_length = input_id.index(tokenizer.bos_token_id)
        label = [IGNORE_INDEX] * context_length + input_id[context_length:]

        pad_len = max_length - len(input_id)
        input_ids.append(input_id + [tokenizer.pad_token_id] * pad_len)
        labels.append(label + [IGNORE_INDEX] * pad_len)
    return torch.tensor(input_ids), torch.tensor(labels), None
115

116

117
class SFTDataset(Dataset):
    """
    Dataset for sft model

    Args:
        dataset: records for supervised fine-tuning, each a mapping with string fields
            ``"prompt"`` and ``"completion"``. It is iterated exactly once, so a one-shot
            iterator works as well as a list.
        tokenizer: tokenizer for supervised model; a ``ChatGLMTokenizer`` selects the
            ChatGLM-specific preprocessing, any other tokenizer the generic one
        max_length: max length of input, in tokens
    """

    def __init__(self, dataset: Sequence[Dict], tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
        super().__init__()

        # Build both lists in a single pass. Iterating `dataset` once per list left
        # `targets` empty for one-shot iterators, and zip() then silently dropped every
        # example.
        sources = []
        targets = []
        for data in tqdm(dataset, disable=not is_rank_0()):
            sources.append(data["prompt"])
            targets.append(data["completion"] + tokenizer.eos_token)

        logger.info("Tokenizing inputs... This may take some time...")
        if isinstance(tokenizer, ChatGLMTokenizer):
            self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
                sources, targets, tokenizer, max_length
            )
        else:
            self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

        logger.info("Loaded dataset.")

    def __len__(self):
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        # The ChatGLM path returns no attention mask, so the key is only present when one exists.
        if self.attention_mask is not None:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
        else:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
153

154

155
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning on instruction data.

    Each record is rendered with one of the ``PROMPT_DICT`` templates to form the prompt,
    and its ``"output"`` field (plus the EOS text) forms the completion.

    Args:
        data_path: path handed to ``jload``; presumably a JSON file holding a list of
            records with ``"instruction"``, optional ``"input"`` and ``"output"`` fields
            (TODO confirm against ``jload``).
        tokenizer: tokenizer for the model; a ``ChatGLMTokenizer`` selects the
            ChatGLM-specific preprocessing.
        max_datasets_size: if given, only the first ``max_datasets_size`` records are used.
        max_length: max length of each tokenized example, in tokens.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        max_datasets_size: Optional[int] = None,
        max_length: int = 512,
    ):
        super().__init__()
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        logger.info("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        # Use the with-input template only for a non-empty "input". Records may carry
        # "input": "" (as in the Stanford Alpaca data); testing mere key presence would
        # render an empty "### Input:" section for them.
        sources = [
            prompt_input.format_map(example) if example.get("input") else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]

        logger.info("Tokenizing inputs... This may take some time...")
        if isinstance(tokenizer, ChatGLMTokenizer):
            self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
                sources, targets, tokenizer, max_length
            )
        else:
            self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

        logger.info("Loaded dataset.")

    def __len__(self):
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        # The ChatGLM path returns no attention mask, so the key is only present when one exists.
        if self.attention_mask is not None:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
        else:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
201

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.