# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import fire
import torch
import tqdm
import transformers
from train import smart_tokenizer_and_embedding_resize

@torch.inference_mode()
def make_diff(
    path_raw: str, path_tuned: str, path_diff: str, device="cpu",  # "cuda" or "cpu"
):
    """Make the weight diff.

    This function is given to present full transparency of how the weight diff was created.

    Run:
        python weight_diff.py make_diff --path_raw <your_path_raw> --path_tuned <your_path_tuned> --path_diff <your_path_diff>
    """
    model_tuned: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_tuned,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_raw,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_tuned
    )
    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_raw
    )
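    # Fine-tuning added a [PAD] token that the raw LLaMA tokenizer lacks; resize
    # the raw model's embeddings the same way so both state dicts have matching
    # shapes before subtracting.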
    if tokenizer_raw.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token="[PAD]"),
            model=model_raw,
            tokenizer=tokenizer_raw,
        )

    state_dict_tuned = model_tuned.state_dict()
    state_dict_raw = model_raw.state_dict()
    for key in tqdm.tqdm(state_dict_tuned):
        state_dict_tuned[key].add_(-state_dict_raw[key])
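    # The subtraction happens in place, so model_tuned's tensors now hold
    # (tuned - raw); what save_pretrained writes below is the diff itself,
    # not directly usable weights.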
    model_tuned.save_pretrained(path_diff)
    tokenizer_tuned.save_pretrained(path_diff)


@torch.inference_mode()
def recover(
    path_raw,
    path_diff,
    path_tuned: Optional[str] = None,
    device="cpu",
    test_inference=True,
    check_integrity_naively=True,
):
    """Recover the original weights from the released weight diff.

    This function is given for you to run.

    Things to do before running this:
        1. Convert Meta's released weights into huggingface format. Follow this guide:
            https://huggingface.co/docs/transformers/main/model_doc/llama
        2. Make sure you cloned the released weight diff onto your local machine. The weight diff is located at:
            https://huggingface.co/tatsu-lab/alpaca-7b/tree/main
        3. Run this function with the correct paths. E.g.,
            python weight_diff.py recover --path_raw <path_to_step_1_dir> --path_diff <path_to_step_2_dir>

    Additional notes:
        - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`.
        - If you want to save the recovered weights, set `--path_tuned <your_path_tuned>`.
            Next time you can load the recovered weights directly from `<your_path_tuned>`.
    """
    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_raw,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model_recovered: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_diff,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_raw
    )
    if tokenizer_raw.pad_token is None:
        # Mirror the embedding resize done in make_diff so the raw and diff
        # state dicts have matching shapes.
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token="[PAD]"),
            model=model_raw,
            tokenizer=tokenizer_raw,
        )
    tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_diff
    )

    state_dict_recovered = model_recovered.state_dict()
    state_dict_raw = model_raw.state_dict()
    for key in tqdm.tqdm(state_dict_recovered):
        state_dict_recovered[key].add_(state_dict_raw[key])
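    # The in-place addition inverts make_diff: model_recovered's tensors now
    # hold diff + raw, i.e. the fine-tuned weights.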
    if check_integrity_naively:
        # This is not a rigorous, cryptographically strong integrity check :)
        allsum = sum(state_dict_recovered[key].sum() for key in state_dict_recovered)
        assert torch.allclose(
            allsum, torch.full_like(allsum, fill_value=50637.1836), atol=1e-2, rtol=0
        ), "Naive integrity check failed. This could imply that some of the checkpoint files are corrupted."
    if path_tuned is not None:
        model_recovered.save_pretrained(path_tuned)
        tokenizer_recovered.save_pretrained(path_tuned)

    if test_inference:
        input_text = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\r\n\r\n"
            "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:"
        )
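        # The prompt mirrors the instruction template Alpaca was fine-tuned on,
        # so a correctly recovered model should produce a sensible completion.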
        inputs = tokenizer_recovered(input_text, return_tensors="pt")
        out = model_recovered.generate(inputs=inputs.input_ids, max_new_tokens=100)
        output_text = tokenizer_recovered.batch_decode(out, skip_special_tokens=True)[0]
        output_text = output_text[len(input_text) :]
        print(f"Input: {input_text}\nCompletion: {output_text}")

    return model_recovered, tokenizer_recovered


def main(task, **kwargs):
    # Dispatch to make_diff or recover by name; fire forwards CLI flags as kwargs.
    globals()[task](**kwargs)


if __name__ == "__main__":
    fire.Fire(main)