llm-adapters / finetune.py
import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset
from typing import List, Optional, Union

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""
sys.path.append(os.path.join(os.getcwd(), "peft/src/"))
from peft import (  # noqa: E402
    LoraConfig,
    BottleneckConfig,
    PrefixTuningConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, AutoModel  # noqa: F402


def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./lora-alpaca",
    adapter_name: str = "lora",
    load_8bit: bool = False,
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    use_gradient_checkpointing: bool = False,
    eval_step: int = 200,
    save_step: int = 200,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = None,
    # bottleneck adapter hyperparams
    bottleneck_size: int = 256,
    non_linearity: str = "tanh",
    adapter_dropout: float = 0.0,
    use_parallel_adapter: bool = False,
    use_adapterp: bool = False,
    target_modules: List[str] = None,
    scaling: Union[float, str] = 1.0,
    # prefix tuning hyperparams
    num_virtual_tokens: int = 30,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
):
    print(
        f"Finetuning model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"use_gradient_checkpointing: {use_gradient_checkpointing}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"bottleneck_size: {bottleneck_size}\n"
        f"non_linearity: {non_linearity}\n"
        f"adapter_dropout: {adapter_dropout}\n"
        f"use_parallel_adapter: {use_parallel_adapter}\n"
        f"use_adapterp: {use_adapterp}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"scaling: {scaling}\n"
        f"adapter_name: {adapter_name}\n"
        f"target_modules: {target_modules}\n"
        f"group_by_length: {group_by_length}\n"
        f"wandb_project: {wandb_project}\n"
        f"wandb_run_name: {wandb_run_name}\n"
        f"wandb_watch: {wandb_watch}\n"
        f"wandb_log_model: {wandb_log_model}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
    )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
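
    # Load the base model either in int8 via bitsandbytes (--load_8bit) or in fp16 pinned to this
    # process's LOCAL_RANK device.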
    if load_8bit:
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map=device_map,
            trust_remote_code=True,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            device_map={"": int(os.environ.get("LOCAL_RANK") or 0)},
            trust_remote_code=True,
        )

    if model.config.model_type == "llama":
        # Due to the name of transformers' LlamaTokenizer, we have to do this
        tokenizer = LlamaTokenizer.from_pretrained(base_model)
    else:
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference
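
    # Tokenize one prompt, truncating to cutoff_len and appending an EOS token when it still fits,
    # so the model learns where a completion ends.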
    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            if "chatglm" not in base_model:
                result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        if "chatglm" in base_model:
            return {"input_ids": result["input_ids"], "labels": result["labels"]}
        else:
            return result
    def generate_and_tokenize_prompt(data_point):
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt
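
    # Prepare the (optionally int8) model for training and wrap it with the selected PEFT adapter:
    # LoRA, a bottleneck adapter (optionally parallel), or prefix tuning.
    # Note: any other --adapter_name value would leave `config` undefined below.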
    model = prepare_model_for_int8_training(model, use_gradient_checkpointing=use_gradient_checkpointing)
    if adapter_name == "lora":
        config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
    elif adapter_name == "bottleneck":
        config = BottleneckConfig(
            bottleneck_size=bottleneck_size,
            non_linearity=non_linearity,
            adapter_dropout=adapter_dropout,
            use_parallel_adapter=use_parallel_adapter,
            use_adapterp=use_adapterp,
            target_modules=target_modules,
            scaling=scaling,
            bias="none",
            task_type="CAUSAL_LM",
        )
    elif adapter_name == "prefix-tuning":
        config = PrefixTuningConfig(
            num_virtual_tokens=num_virtual_tokens,
            task_type="CAUSAL_LM",
        )
    model = get_peft_model(model, config)
    if adapter_name == "prefix-tuning":
        model.to('cuda')
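
    # The dataset is expected to provide "instruction", "input", and "output" fields,
    # which generate_prompt() below assembles into the Alpaca prompt template.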
    if data_path.endswith(".json"):  # todo: support jsonl
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            model = set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
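
    # Hold out val_set_size examples for periodic evaluation when requested; otherwise train on the
    # full "train" split with no eval set.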
    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True
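
    # Standard HF Trainer setup: fp16 training, AdamW, step-based eval/save, and a seq2seq collator
    # that pads inputs and labels to a multiple of 8 (inputs are left-padded per the tokenizer).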
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=eval_step if val_set_size > 0 else None,
            save_steps=save_step,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False
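
    # Monkey-patch state_dict so that checkpoints saved by the Trainer contain only the PEFT adapter
    # weights (via get_peft_model_state_dict) instead of the full base model.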
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


def generate_prompt(data_point):
    # sorry about the formatting disaster gotta move fast
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:
{data_point["output"]}"""  # noqa: E501
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:
{data_point["output"]}"""  # noqa: E501


if __name__ == "__main__":
    fire.Fire(train)
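
# Illustrative invocation (flags map 1:1 onto the train() arguments via fire.Fire); adjust the
# model, data path, and hyperparameters to your setup:
#
#   python finetune.py \
#       --base_model 'decapoda-research/llama-7b-hf' \
#       --data_path 'yahma/alpaca-cleaned' \
#       --output_dir './lora-alpaca' \
#       --adapter_name 'lora' \
#       --batch_size 128 \
#       --micro_batch_size 4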