llm-adapters / commonsense_evaluate.py (300 lines, 9.0 KB)
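
"""Evaluate a LoRA/adapter-tuned LLM on commonsense reasoning benchmarks.

Loads a base checkpoint plus PEFT adapter weights, batch-generates answers
for the chosen dataset's test split, extracts each predicted option with a
dataset-specific regex, and writes per-example predictions and a running
accuracy to experiment/<model>-<adapter>-<dataset>.json.
"""
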
import copy
import json
import os
import re
import sys
import argparse

import torch

# put the repo's bundled peft fork ahead of any installed copy
sys.path.append(os.path.join(os.getcwd(), "peft/src/"))
from peft import PeftModel
from tqdm import tqdm
from transformers import GenerationConfig, LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer

# prefer CUDA, then Apple Silicon MPS, and fall back to CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except:  # noqa: E722
    pass


def main(
        load_8bit: bool = False,
        base_model: str = "",
        lora_weights: str = "tloen/alpaca-lora-7b",
        share_gradio: bool = False,
):
    # NOTE: these keyword arguments appear to be leftovers from the
    # alpaca-lora entry point; all real configuration comes from parse_args()
    args = parse_args()

    def evaluate(
            instructions,
            input=None,
            temperature=0.1,
            top_p=0.75,
            top_k=40,
            num_beams=4,
            max_new_tokens=32,
            **kwargs,
    ):
        prompts = [generate_prompt(instruction, input) for instruction in instructions]
        inputs = tokenizer(prompts, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        s = generation_output.sequences
        outputs = tokenizer.batch_decode(s, skip_special_tokens=True)
        # keep only the model's answer: the text after the prompt template's
        # "### Response:" marker
        outputs = [o.split("### Response:")[1].strip() for o in outputs]
        print(outputs)
        return outputs

    save_file = f'experiment/{args.model}-{args.adapter}-{args.dataset}.json'
    create_dir('experiment/')

    dataset = load_data(args)
    batches = create_batch(dataset, args.batch_size)
    tokenizer, model = load_model(args)
    total = len(batches)
    correct = 0
    current = 0
    output_data = []
    pbar = tqdm(total=total)
    for idx, batch in enumerate(batches):
        current += len(batch)
        instructions = [data.get('instruction') for data in batch]

        outputs = evaluate(instructions)

        for data, output in zip(batch, outputs):
            label = data.get('answer')
            flag = False
            predict = extract_answer(args, output)
            if label == predict:
                correct += 1
                flag = True
            new_data = copy.deepcopy(data)
            new_data['output_pred'] = output
            new_data['pred'] = predict
            new_data['flag'] = flag
            output_data.append(new_data)
            print(data["instruction"])
            print(output)
            print('prediction:', predict)
            print('label:', label)
        print('---------------')
        print(f'\rtest:{idx + 1}/{total} | accuracy {correct}/{current} ({correct / current:.4f})')
        print('---------------')
        # rewrite the full results file after every batch so partial results
        # survive an interrupted run
        with open(save_file, 'w+') as f:
            json.dump(output_data, f, indent=4)
        pbar.update(1)
    pbar.close()
    print('\n')
    print('test finished')


def create_dir(dir_path):
    if not os.path.exists(dir_path):
        os.mkdir(dir_path)
    return


def generate_prompt(instruction, input=None):
    # Alpaca-style template, kept flush-left so the prompt carries the exact
    # "### Instruction:" / "### Response:" markers the adapter was tuned on
    # (and that evaluate() splits on)
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
"""  # noqa: E501
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""  # noqa: E501


def load_data(args) -> list:
    """
    Read the test split of the selected dataset from its JSON file.
    Args:
        args: parsed CLI arguments; only args.dataset is used here.

    Returns:
        list of example dicts loaded from dataset/<dataset>/test.json
    """
    file_path = f'dataset/{args.dataset}/test.json'
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"can not find dataset file: {file_path}")
    with open(file_path, 'r') as f:
        json_data = json.load(f)
    return json_data
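
# Each test.json entry is expected to carry at least the fields main() reads
# (example values here are illustrative, not taken from the datasets):
#   {"instruction": "...question text with options...", "answer": "answer3"}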


def create_batch(dataset, batch_size):
    # split the dataset into contiguous batches; the last batch may be short
    batches = []
    num_batch = len(dataset)//batch_size if len(dataset) % batch_size == 0 else len(dataset)//batch_size + 1
    for i in range(num_batch):
        batch = dataset[i*batch_size: min((i+1)*batch_size, len(dataset))]
        batches.append(batch)
    return batches


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=["boolq", "piqa", "social_i_qa", "hellaswag", "winogrande", "ARC-Challenge", "ARC-Easy", "openbookqa"],
                        required=True)
    parser.add_argument('--model', choices=['LLaMA-7B', 'LLaMA-13B', 'BLOOM-7B', 'GPT-j-6B'], required=True)
    parser.add_argument('--adapter', choices=['LoRA', 'AdapterP', 'AdapterH', 'Parallel'],
                        required=True)
    parser.add_argument('--base_model', required=True)
    parser.add_argument('--lora_weights', required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--load_8bit', action='store_true', default=False)

    return parser.parse_args()
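
# Example invocation (illustrative; the checkpoint and adapter paths are
# assumptions, substitute your own):
#   python commonsense_evaluate.py \
#       --model LLaMA-7B --adapter LoRA --dataset boolq \
#       --base_model yahma/llama-7b-hf \
#       --lora_weights ./trained_models/llama-lora \
#       --batch_size 1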


def load_model(args) -> tuple:
    """
    Load the tuned model: the base checkpoint plus the PEFT adapter weights.
    Args:
        args: parsed CLI arguments (model, base_model, lora_weights, load_8bit).

    Returns:
        tuple(tokenizer, model)
    """
    base_model = args.base_model
    if not base_model:
        raise ValueError(f'can not find base model name by the value: {args.model}')
    lora_weights = args.lora_weights
    if not lora_weights:
        raise ValueError(f'can not find lora weight, the value is: {lora_weights}')

    load_8bit = args.load_8bit
    if "LLaMA" in args.model:
        tokenizer = LlamaTokenizer.from_pretrained(base_model)
    else:
        tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.padding_side = "left"  # left-pad so generation continues from the end of the prompt
    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )  # fix zwq
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
            device_map={"": 0},
        )
    elif device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
        )

    # unwind broken decapoda-research config (applies regardless of device)
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    return tokenizer, model


def load_instruction(args) -> str:
    # placeholder retained from the original script: nothing populates
    # `instruction`, so this always raises, and nothing in this file calls it
    instruction = ''
    if not instruction:
        raise ValueError('instruct not initialized')
    return instruction


def extract_answer(args, sentence: str) -> str:
    # map each dataset to the regex matching its answer options; the first
    # match found in the model's output is taken as the prediction
    option_patterns = {
        'boolq': r'true|false',
        'piqa': r'solution1|solution2',
        'social_i_qa': r'answer1|answer2|answer3|answer4|answer5',
        'ARC-Challenge': r'answer1|answer2|answer3|answer4|answer5',
        'ARC-Easy': r'answer1|answer2|answer3|answer4|answer5',
        'openbookqa': r'answer1|answer2|answer3|answer4|answer5',
        'hellaswag': r'ending1|ending2|ending3|ending4',
        'winogrande': r'option1|option2',
    }
    pred_answers = re.findall(option_patterns[args.dataset], sentence.strip())
    if not pred_answers:
        return ""
    return pred_answers[0]
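
# For example, on a hellaswag output like "the correct answer is ending2"
# extract_answer returns "ending2"; an output containing no recognizable
# option yields "" and is therefore counted as incorrect in main()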


if __name__ == "__main__":
    main()