llm-adapters
/
commonsense_evaluate.py
300 строк · 9.0 Кб
1import copy2import json3import os4import re5import sys6import argparse7
8import fire9
10import torch11
12sys.path.append(os.path.join(os.getcwd(), "peft/src/"))13from peft import PeftModel14from tqdm import tqdm15from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer16
# Pick the best available torch device: CUDA > MPS > CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Probe for Apple-silicon MPS: on older torch builds `torch.backends.mps`
# does not exist, so the attribute lookup itself can raise.
try:
    if torch.backends.mps.is_available():
        device = "mps"
except Exception:  # noqa: E722 -- best-effort probe; was a bare `except:`
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
    # are no longer swallowed; any backend error still falls through.
    pass
28
def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "tloen/alpaca-lora-7b",
    share_gradio: bool = False,
):
    # NOTE(review): these fire-style parameters are never read -- all
    # configuration actually comes from parse_args() below. Confirm whether
    # the fire entry point is still needed before removing them.
    """Run batched commonsense evaluation and dump predictions to JSON.

    Loads the dataset and tuned model selected via CLI flags, generates an
    answer for every example, scores it against the gold label with
    extract_answer(), and writes the accumulated per-example records to
    experiment/<model>-<adapter>-<dataset>.json after every batch.
    """
    args = parse_args()

    def evaluate(
        instructions,
        input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=32,
        **kwargs,
    ):
        """Generate one response per instruction (closure over tokenizer/model).

        Builds Alpaca-style prompts, runs beam-search generation, and returns
        only the text after the "### Response:" marker for each prompt.
        """
        prompts = [generate_prompt(instruction, input) for instruction in instructions]
        inputs = tokenizer(prompts, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        # Inference only -- no gradients needed.
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        s = generation_output.sequences
        outputs = tokenizer.batch_decode(s, skip_special_tokens=True)
        # Keep only the model's answer, dropping the echoed prompt.
        # NOTE(review): raises IndexError if the marker is absent -- verify
        # generation always echoes the prompt template.
        outputs = [o.split("### Response:")[1].strip() for o in outputs]
        print(outputs)
        return outputs

    save_file = f'experiment/{args.model}-{args.adapter}-{args.dataset}.json'
    create_dir('experiment/')

    dataset = load_data(args)
    batches = create_batch(dataset, args.batch_size)
    # NOTE: tokenizer/model must be bound before evaluate() is first called;
    # the closure above references them by name.
    tokenizer, model = load_model(args)
    total = len(batches)
    correct = 0
    current = 0
    output_data = []
    pbar = tqdm(total=total)
    for idx, batch in enumerate(batches):
        current += len(batch)
        instructions = [data.get('instruction') for data in batch]

        outputs = evaluate(instructions)

        for data, output in zip(batch, outputs):
            label = data.get('answer')
            flag = False
            predict = extract_answer(args, output)
            if label == predict:
                correct += 1
                flag = True
            # Record the example plus prediction metadata without mutating
            # the original dataset entry.
            new_data = copy.deepcopy(data)
            new_data['output_pred'] = output
            new_data['pred'] = predict
            new_data['flag'] = flag
            output_data.append(new_data)
            print(data["instruction"])
            print(output)
            print('prediction:', predict)
            print('label:', label)
            print('---------------')
        print(f'\rtest:{idx + 1}/{total} | accuracy {correct} {correct / current}')
        print('---------------')
        # Re-written in full after every batch so partial progress survives
        # an interruption.
        with open(save_file, 'w+') as f:
            json.dump(output_data, f, indent=4)
        pbar.update(1)
    pbar.close()
    print('\n')
    print('test finished')
114
115def create_dir(dir_path):116if not os.path.exists(dir_path):117os.mkdir(dir_path)118return119
120
def generate_prompt(instruction, input=None):
    """Format an Alpaca-style prompt for one instruction.

    When ``input`` is truthy the longer template with an ``### Input:``
    section is used; otherwise the instruction-only template.
    """
    if not input:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""  # noqa: E501
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
"""  # noqa: E501
142
def load_data(args) -> list:
    """Read the test split of the selected dataset.

    Args:
        args: parsed CLI namespace; only ``args.dataset`` is used.

    Returns:
        The parsed content of ``dataset/<name>/test.json`` (a list of
        example dicts with at least ``instruction`` and ``answer`` keys,
        per the consumers in main()).

    Raises:
        FileNotFoundError: if the dataset file does not exist.
    """
    file_path = f'dataset/{args.dataset}/test.json'
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"can not find dataset file : {file_path}")
    # Context manager closes the handle; the original json.load(open(...))
    # leaked the file descriptor.
    with open(file_path, 'r') as f:
        return json.load(f)
158def create_batch(dataset, batch_size):159batches = []160num_batch = len(dataset)//batch_size if len(dataset) % batch_size == 0 else len(dataset)//batch_size + 1161for i in range(num_batch):162batch = dataset[i*batch_size: min((i+1)*batch_size, len(dataset))]163batches.append(batch)164return batches165
166
def parse_args():
    """Build and parse the evaluation CLI.

    Returns:
        argparse.Namespace with ``dataset``, ``model``, ``adapter``,
        ``base_model``, ``lora_weights``, ``batch_size`` and ``load_8bit``.
    """
    parser = argparse.ArgumentParser()
    # Dataset choices mirror the keys handled by extract_answer().
    parser.add_argument('--dataset', choices=["boolq", "piqa", "social_i_qa", "hellaswag", "winogrande", "ARC-Challenge", "ARC-Easy", "openbookqa"],
                        required=True)
    # load_model() keys on the substring "LLaMA" to pick the tokenizer class.
    parser.add_argument('--model', choices=['LLaMA-7B', "LLaMA-13B", 'BLOOM-7B', 'GPT-j-6B'], required=True)
    parser.add_argument('--adapter', choices=['LoRA', 'AdapterP', 'AdapterH', 'Parallel'],
                        required=True)
    # HF hub id or local path of the base language model.
    parser.add_argument('--base_model', required=True)
    # Path/id of the tuned adapter (LoRA) weights.
    parser.add_argument('--lora_weights', required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    # Optional 8-bit loading on CUDA (see load_model()).
    parser.add_argument('--load_8bit', action='store_true', default=False)

    return parser.parse_args()
181
def load_model(args) -> tuple:
    """Load the tokenizer, base causal LM, and its adapter (PEFT) weights.

    Args:
        args: parsed CLI namespace; uses ``base_model``, ``lora_weights``,
            ``model`` (architecture family name) and ``load_8bit``.

    Returns:
        tuple(tokenizer, model) -- model is a PeftModel in eval mode.

    Raises:
        ValueError: if ``base_model`` or ``lora_weights`` is empty.
    """
    base_model = args.base_model
    if not base_model:
        raise ValueError(f'can not find base model name by the value: {args.model}')
    lora_weights = args.lora_weights
    if not lora_weights:
        raise ValueError(f'can not find lora weight, the value is: {lora_weights}')

    load_8bit = args.load_8bit
    # LLaMA checkpoints use the dedicated (slow) tokenizer class; everything
    # else goes through AutoTokenizer.
    if "LLaMA" in args.model:
        tokenizer = LlamaTokenizer.from_pretrained(base_model)
    else:
        tokenizer = AutoTokenizer.from_pretrained(base_model)
    # Left-pad so batched generation continues directly from the prompt end.
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    # Three loading paths: CUDA (optionally 8-bit), Apple MPS, plain CPU.
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )  # fix zwq
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
            device_map={"": 0}
        )
    elif device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
        )

    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()
    # NOTE(review): lexicographic comparison on torch.__version__ -- works
    # for the current "1.x"/"2.x" schemes but is not a real version compare.
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    return tokenizer, model
257
258def load_instruction(args) -> str:259instruction = ''260if not instruction:261raise ValueError('instruct not initialized')262return instruction263
264
265def extract_answer(args, sentence: str) -> float:266dataset = args.dataset267if dataset == 'boolq':268sentence_ = sentence.strip()269pred_answers = re.findall(r'true|false', sentence_)270if not pred_answers:271return ""272return pred_answers[0]273elif dataset == 'piqa':274sentence_ = sentence.strip()275pred_answers = re.findall(r'solution1|solution2', sentence_)276if not pred_answers:277return ""278return pred_answers[0]279elif dataset in ['social_i_qa', 'ARC-Challenge', 'ARC-Easy', 'openbookqa']:280sentence_ = sentence.strip()281pred_answers = re.findall(r'answer1|answer2|answer3|answer4|answer5', sentence_)282if not pred_answers:283return ""284return pred_answers[0]285elif dataset == 'hellaswag':286sentence_ = sentence.strip()287pred_answers = re.findall(r'ending1|ending2|ending3|ending4', sentence_)288if not pred_answers:289return ""290return pred_answers[0]291elif dataset == 'winogrande':292sentence_ = sentence.strip()293pred_answers = re.findall(r'option1|option2', sentence_)294if not pred_answers:295return ""296return pred_answers[0]297
298
299if __name__ == "__main__":300main()301