LLM-FineTuning-Large-Language-Models
/
Mistral_FineTuning_with_PEFT_and_QLORA.ipynb
1 строка · 17.8 Кб
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check out my [Twitter (@rohanpaul_ai)](https://twitter.com/rohanpaul_ai) for daily LLM bits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-Tuning Mistral-7B with PEFT and QLoRA\n",
    "\n",
    "This notebook fine-tunes `mistralai/Mistral-7B-Instruct-v0.1` on the `iamtarun/python_code_instructions_18k_alpaca` dataset. The base model is loaded in 4-bit NF4 precision (QLoRA), and only LoRA adapters on the attention projections are trained, using TRL's `SFTTrainer`.\n",
    "\n",
    "### [Link to my YouTube video explaining this whole notebook](https://www.youtube.com/watch?v=6DGYj1EEWOw&list=PLxqBkZuBynVTzqUQCQFgetR97y1X_1uCI&index=13&ab_channel=Rohan-Paul-AI)\n",
    "\n",
    "[![Imgur](https://imgur.com/MJ22PMV.png)](https://www.youtube.com/watch?v=6DGYj1EEWOw&list=PLxqBkZuBynVTzqUQCQFgetR97y1X_1uCI&index=13&ab_channel=Rohan-Paul-AI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](assets/2023-11-25-00-08-01.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Install the dependencies, import everything the notebook uses, and fix the random seed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constrain versions to the APIs this notebook was written against (Nov 2023). Later releases\n",
    "# deprecated and then removed keyword arguments used below (`evaluation_strategy` in\n",
    "# `TrainingArguments`; `tokenizer`, `dataset_text_field`, `max_seq_length`, `packing` in\n",
    "# `SFTTrainer`), so an unpinned `--upgrade` eventually breaks the training cells.\n",
    "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
    "%pip install -q \"transformers>=4.34,<4.41\" \"trl>=0.7,<0.9\" \"peft>=0.6,<0.11\" \"accelerate>=0.24,<0.30\" \"bitsandbytes>=0.41.1\" \"datasets>=2.14,<3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import itertools\n",
    "import os\n",
    "from dataclasses import dataclass\n",
    "from typing import Optional\n",
    "\n",
    "import torch\n",
    "from datasets import Dataset, load_dataset\n",
    "from peft import AutoPeftModelForCausalLM, LoraConfig\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    BitsAndBytesConfig,\n",
    "    TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "from trl import SFTTrainer\n",
    "\n",
    "# Seeds Python's `random`, NumPy and PyTorch (CPU and CUDA) in one call.\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "`ScriptArguments` collects every tunable setting; the cell after it overrides the defaults for this run.\n",
    "\n",
    "To run this as a script instead, replace that cell with command-line parsing:\n",
    "\n",
    "```python\n",
    "from transformers import HfArgumentParser\n",
    "\n",
    "parser = HfArgumentParser(ScriptArguments)\n",
    "script_args = parser.parse_args_into_dataclasses()[0]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class ScriptArguments:\n",
    "    \"\"\"\n",
    "    All tunable settings for QLoRA fine-tuning.\n",
    "\n",
    "    These arguments vary depending on how many GPUs you have, what their capacity and\n",
    "    features are, and what size model you want to train.\n",
    "    \"\"\"\n",
    "\n",
    "    # --- Distributed setup and batching ---\n",
    "    local_rank: Optional[int] = -1\n",
    "    per_device_train_batch_size: Optional[int] = 4\n",
    "    per_device_eval_batch_size: Optional[int] = 4\n",
    "    gradient_accumulation_steps: Optional[int] = 4\n",
    "    # --- Optimizer ---\n",
    "    learning_rate: Optional[float] = 2e-5\n",
    "    max_grad_norm: Optional[float] = 0.3\n",
    "    weight_decay: Optional[float] = 0.01\n",
    "    # --- LoRA adapter ---\n",
    "    lora_alpha: Optional[int] = 16\n",
    "    lora_dropout: Optional[float] = 0.1\n",
    "    lora_r: Optional[int] = 32\n",
    "    # --- Model and data ---\n",
    "    max_seq_length: Optional[int] = 512\n",
    "    # Alternative checkpoint: \"bn22/Mistral-7B-Instruct-v0.1-sharded\"\n",
    "    model_name: Optional[str] = \"mistralai/Mistral-7B-Instruct-v0.1\"\n",
    "    dataset_name: Optional[str] = \"iamtarun/python_code_instructions_18k_alpaca\"\n",
    "    # --- 4-bit quantization (bitsandbytes) ---\n",
    "    use_4bit: Optional[bool] = True\n",
    "    use_nested_quant: Optional[bool] = False  # passed as bnb_4bit_use_double_quant\n",
    "    bnb_4bit_compute_dtype: Optional[str] = \"float16\"  # name of a torch dtype\n",
    "    bnb_4bit_quant_type: Optional[str] = \"nf4\"\n",
    "    # --- Training schedule and precision ---\n",
    "    num_train_epochs: Optional[int] = 100\n",
    "    fp16: Optional[bool] = False\n",
    "    bf16: Optional[bool] = True\n",
    "    packing: Optional[bool] = False\n",
    "    gradient_checkpointing: Optional[bool] = True\n",
    "    optim: Optional[str] = \"paged_adamw_32bit\"\n",
    "    # The \"constant\" schedule applies no warmup, so `warmup_ratio` only takes effect with a\n",
    "    # schedule that supports it (e.g. \"constant_with_warmup\").\n",
    "    lr_scheduler_type: str = \"constant\"\n",
    "    # A positive `max_steps` overrides `num_train_epochs`; 1_000_000 steps effectively means\n",
    "    # \"train until interrupted\" (a checkpoint is saved every `save_steps` steps).\n",
    "    max_steps: int = 1000000\n",
    "    warmup_ratio: float = 0.03\n",
    "    group_by_length: bool = True\n",
    "    save_steps: int = 50\n",
    "    logging_steps: int = 50\n",
    "    # --- Output ---\n",
    "    # Save and merge the adapters locally after training (nothing is pushed to the Hub).\n",
    "    merge_and_push: Optional[bool] = False\n",
    "    output_dir: str = \"./results_packing\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the values that differ from the `ScriptArguments` defaults are listed here.\n",
    "script_args = ScriptArguments(\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    "    learning_rate=3e-5,\n",
    "    # fp16 matches bnb_4bit_compute_dtype=\"float16\" and also runs on GPUs without\n",
    "    # bfloat16 support (compute capability < 8, e.g. T4).\n",
    "    fp16=True,\n",
    "    bf16=False,\n",
    "    # Set to True to write the merged model that the Inference section loads.\n",
    "    merge_and_push=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preprocessing Utils\n",
    "\n",
    "The first 10,000 rows of the dataset are streamed and rewritten into Mistral's `[INST] ... [/INST]` instruction format: the first 90% are used for training and the remaining 10% for validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first TOTAL_SAMPLES rows of the dataset's \"train\" split are used:\n",
    "# rows [0, TRAIN_LIMIT) for training and [TRAIN_LIMIT, TOTAL_SAMPLES) for validation.\n",
    "TOTAL_SAMPLES = 10_000\n",
    "VAL_PCT = 0.1\n",
    "TRAIN_LIMIT = int(TOTAL_SAMPLES * (1 - VAL_PCT))\n",
    "\n",
    "\n",
    "def _to_mistral_instruct(prompt: str) -> str:\n",
    "    \"\"\"Rewrite one Alpaca-style `prompt` into Mistral's `[INST] ... [/INST]` format.\n",
    "\n",
    "    The \"### Input:\" header and the \"# Python code\" marker are removed, so any input text\n",
    "    becomes part of the instruction; everything after \"### Output:\" is the answer, which\n",
    "    is wrapped in a python code fence.\n",
    "    \"\"\"\n",
    "    cleaned = prompt.replace(\"### Input:\\n\", \"\").replace(\"# Python code\\n\", \"\")\n",
    "    head, _, answer = cleaned.partition(\"### Output:\")\n",
    "    _, _, instruction = head.partition(\"### Instruction:\")\n",
    "    # No leading \"<s>\": Mistral's tokenizer prepends BOS itself when SFTTrainer tokenizes\n",
    "    # this text, so a literal \"<s>\" would yield a duplicated BOS. The trailing \"</s>\" stays\n",
    "    # because the tokenizer does not append EOS and the model must learn where to stop.\n",
    "    # The newline before the closing fence keeps it off the last line of code.\n",
    "    return f\"[INST] {instruction.strip()} [/INST] ```python\\n{answer.strip()}\\n```</s>\"\n",
    "\n",
    "\n",
    "def _iter_formatted(start: int, stop: int):\n",
    "    \"\"\"Stream dataset rows [start, stop) and yield them as `{\"text\": ...}` records.\"\"\"\n",
    "    ds = load_dataset(script_args.dataset_name, streaming=True, split=\"train\")\n",
    "    for sample in itertools.islice(ds, start, stop):\n",
    "        yield {\"text\": _to_mistral_instruct(sample[\"prompt\"])}\n",
    "\n",
    "\n",
    "def gen_batches_train():\n",
    "    \"\"\"Yield the training records (rows 0 .. TRAIN_LIMIT - 1).\"\"\"\n",
    "    yield from _iter_formatted(0, TRAIN_LIMIT)\n",
    "\n",
    "\n",
    "def gen_batches_val():\n",
    "    \"\"\"Yield the validation records (rows TRAIN_LIMIT .. TOTAL_SAMPLES - 1).\"\"\"\n",
    "    yield from _iter_formatted(TRAIN_LIMIT, TOTAL_SAMPLES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model, LoRA and Training Arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_and_prepare_model(args):\n",
    "    \"\"\"Load the 4-bit quantized base model, the LoRA config and the tokenizer.\n",
    "\n",
    "    Args:\n",
    "        args: a `ScriptArguments` instance; every setting is read from it.\n",
    "\n",
    "    Returns:\n",
    "        `(model, peft_config, tokenizer)`. The adapters are not attached here:\n",
    "        `SFTTrainer` applies `peft_config` to the model.\n",
    "    \"\"\"\n",
    "    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)\n",
    "\n",
    "    bnb_config = BitsAndBytesConfig(\n",
    "        load_in_4bit=args.use_4bit,\n",
    "        bnb_4bit_quant_type=args.bnb_4bit_quant_type,\n",
    "        bnb_4bit_compute_dtype=compute_dtype,\n",
    "        bnb_4bit_use_double_quant=args.use_nested_quant,\n",
    "    )\n",
    "\n",
    "    if compute_dtype == torch.float16 and args.use_4bit:\n",
    "        major, _ = torch.cuda.get_device_capability()\n",
    "        if major >= 8:\n",
    "            print(\"=\" * 80)\n",
    "            print(\"Your GPU supports bfloat16: you can accelerate training with bf16=True and bnb_4bit_compute_dtype='bfloat16'\")\n",
    "            print(\"=\" * 80)\n",
    "\n",
    "    # Load the entire model on GPU 0; switch to `device_map=\"auto\"` for multi-GPU.\n",
    "    device_map = {\"\": 0}\n",
    "\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        args.model_name,\n",
    "        quantization_config=bnb_config,\n",
    "        device_map=device_map,\n",
    "    )\n",
    "\n",
    "    peft_config = LoraConfig(\n",
    "        lora_alpha=args.lora_alpha,\n",
    "        lora_dropout=args.lora_dropout,\n",
    "        r=args.lora_r,\n",
    "        bias=\"none\",\n",
    "        task_type=\"CAUSAL_LM\",\n",
    "        # LoRA on the attention projections only (module names as in the model printout).\n",
    "        target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
    "    )\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(args.model_name)\n",
    "    # Pad with <unk> rather than EOS: the language-modeling collator sets labels to -100\n",
    "    # wherever input_ids == pad_token_id, so pad == EOS would drop every \"</s>\" from the\n",
    "    # loss and the model would never learn to stop generating.\n",
    "    tokenizer.pad_token = tokenizer.unk_token\n",
    "    # Right padding for training (works around an overflow issue seen in fp16 training).\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    return model, peft_config, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_arguments = TrainingArguments(\n",
    "    output_dir=script_args.output_dir,\n",
    "    per_device_train_batch_size=script_args.per_device_train_batch_size,\n",
    "    per_device_eval_batch_size=script_args.per_device_eval_batch_size,\n",
    "    gradient_accumulation_steps=script_args.gradient_accumulation_steps,\n",
    "    gradient_checkpointing=script_args.gradient_checkpointing,\n",
    "    optim=script_args.optim,\n",
    "    save_steps=script_args.save_steps,\n",
    "    logging_steps=script_args.logging_steps,\n",
    "    learning_rate=script_args.learning_rate,\n",
    "    weight_decay=script_args.weight_decay,\n",
    "    fp16=script_args.fp16,\n",
    "    bf16=script_args.bf16,\n",
    "    # `eval_steps` is not set, so evaluation runs every `logging_steps` steps.\n",
    "    evaluation_strategy=\"steps\",\n",
    "    max_grad_norm=script_args.max_grad_norm,\n",
    "    num_train_epochs=script_args.num_train_epochs,\n",
    "    max_steps=script_args.max_steps,\n",
    "    warmup_ratio=script_args.warmup_ratio,\n",
    "    group_by_length=script_args.group_by_length,\n",
    "    lr_scheduler_type=script_args.lr_scheduler_type,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, peft_config, tokenizer = create_and_prepare_model(script_args)\n",
    "# The KV cache only helps generation and is incompatible with gradient checkpointing.\n",
    "model.config.use_cache = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_gen = Dataset.from_generator(gen_batches_train)\n",
    "\n",
    "val_gen = Dataset.from_generator(gen_batches_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['text'],\n",
      "    num_rows: 9000\n",
      "})\n",
      "Dataset({\n",
      "    features: ['text'],\n",
      "    num_rows: 1000\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "print(train_gen)\n",
    "\n",
    "print(val_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MistralForCausalLM(\n",
       "  (model): MistralModel(\n",
       "    (embed_tokens): Embedding(32000, 4096)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x MistralDecoderLayer(\n",
       "        (self_attn): MistralAttention(\n",
       "          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)\n",
       "          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)\n",
       "          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
       "          (rotary_emb): MistralRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): MistralMLP(\n",
       "          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)\n",
       "          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)\n",
       "          (act_fn): SiLUActivation()\n",
       "        )\n",
       "        (input_layernorm): MistralRMSNorm()\n",
       "        (post_attention_layernorm): MistralRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): MistralRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    train_dataset=train_gen,\n",
    "    eval_dataset=val_gen,\n",
    "    peft_config=peft_config,\n",
    "    dataset_text_field=\"text\",\n",
    "    max_seq_length=script_args.max_seq_length,\n",
    "    tokenizer=tokenizer,\n",
    "    args=training_arguments,\n",
    "    packing=script_args.packing,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merge the LoRA Adapters (optional)\n",
    "\n",
    "When `script_args.merge_and_push` is `True`, the trained adapters are saved, merged into the base weights, and the merged model and tokenizer are written to `<output_dir>/Final_Model_Checkpoint`, which the Inference section loads. Nothing is pushed to the Hugging Face Hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if script_args.merge_and_push:\n",
    "    output_dir = os.path.join(script_args.output_dir, \"final_checkpoints\")\n",
    "    trainer.model.save_pretrained(output_dir)\n",
    "    tokenizer.save_pretrained(output_dir)\n",
    "\n",
    "    # Free GPU memory before reloading the base model for merging. `trainer` also holds a\n",
    "    # reference to the model, so deleting `model` alone would release nothing.\n",
    "    del model, trainer\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # Reload base model + adapters (not quantized) in the training compute dtype and fold\n",
    "    # the LoRA weights into the base weights.\n",
    "    model = AutoPeftModelForCausalLM.from_pretrained(\n",
    "        output_dir,\n",
    "        device_map=\"auto\",\n",
    "        torch_dtype=getattr(torch, script_args.bnb_4bit_compute_dtype),\n",
    "    )\n",
    "    model = model.merge_and_unload()\n",
    "\n",
    "    output_merged_dir = os.path.join(script_args.output_dir, \"Final_Model_Checkpoint\")\n",
    "    model.save_pretrained(output_merged_dir, safe_serialization=True)\n",
    "    # The Inference section loads the tokenizer from this directory as well.\n",
    "    tokenizer.save_pretrained(output_merged_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the merged model written by the merge cell above (requires merge_and_push=True).\n",
    "model_path = \"./results_packing/Final_Model_Checkpoint\"  # Update this path to your model's location\n",
    "if not os.path.isdir(model_path):\n",
    "    raise FileNotFoundError(\n",
    "        f\"{model_path} not found: run the merge cell with merge_and_push=True first, \"\n",
    "        \"or point model_path at your merged model.\"\n",
    "    )\n",
    "\n",
    "merged_tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "merged_model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path, device_map=\"auto\", torch_dtype=torch.float16\n",
    ")\n",
    "\n",
    "\n",
    "def generate_text(prompt: str, max_new_tokens: int = 256) -> str:\n",
    "    \"\"\"Generate a completion for `prompt` with the fine-tuned model.\n",
    "\n",
    "    The prompt is wrapped in the same `[INST] ... [/INST]` format used for training (the\n",
    "    tokenizer adds BOS itself). Only the newly generated text is returned, without\n",
    "    special tokens; `max_new_tokens` does not count the prompt tokens.\n",
    "    \"\"\"\n",
    "    inputs = merged_tokenizer(f\"[INST] {prompt} [/INST]\", return_tensors=\"pt\").to(merged_model.device)\n",
    "    output = merged_model.generate(**inputs, max_new_tokens=max_new_tokens, num_return_sequences=1)\n",
    "    new_tokens = output[0, inputs[\"input_ids\"].shape[1]:]\n",
    "    return merged_tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
    "\n",
    "\n",
    "# Example usage\n",
    "prompt = \"Write a Python function that returns the factorial of n.\"  # Replace with your input prompt\n",
    "generated_text = generate_text(prompt)\n",
    "print(generated_text)"
   ]
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "gpu",
   "dataSources": [],
   "dockerImageVersionId": 30588,
   "isGpuEnabled": true,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
2