stanford_alpaca

Форк
0
/
weight_diff.py 
158 строк · 6.0 Кб
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

15
from typing import Optional
16

17
import fire
18
import torch
19
import tqdm
20
import transformers
21
from train import smart_tokenizer_and_embedding_resize
22

23

24
@torch.inference_mode()
def make_diff(
    path_raw: str, path_tuned: str, path_diff: str, device="cpu",  # "cuda" or "cpu"
):
    """Make the weight diff.

    This function is given to present full transparency of how the weight diff was created.

    Both checkpoints are loaded in float32 onto ``device``. If the raw tokenizer has no pad token,
    a "[PAD]" pad token is added through ``smart_tokenizer_and_embedding_resize`` (which, judging by
    its name, also resizes the raw model's embeddings so its tensors match the tuned model's).
    The raw weights are then subtracted in place from the tuned model's weights, and the resulting
    diff is saved to ``path_diff`` together with the tuned tokenizer.

    Args:
        path_raw: Path to the raw (base) checkpoint in Hugging Face format.
        path_tuned: Path to the fine-tuned checkpoint in Hugging Face format.
        path_diff: Output directory for the weight diff and the tuned tokenizer.
        device: Device both models are loaded onto, e.g. "cpu" or "cuda".

    Run:
        python weight_diff.py make_diff --path_raw <your_path_raw> --path_tuned <your_path_tuned> --path_diff <your_path_diff>
    """
    model_tuned: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_tuned,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_raw,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_tuned
    )
    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_raw
    )
    if tokenizer_raw.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token="[PAD]"),
            model=model_raw,
            tokenizer=tokenizer_raw,
        )

    # diff = tuned - raw, written in place into the tuned model's tensors.
    state_dict_tuned = model_tuned.state_dict()
    state_dict_raw = model_raw.state_dict()
    for key in tqdm.tqdm(state_dict_tuned):
        # `sub_` subtracts in place without materializing a negated copy of the raw tensor;
        # the result is numerically identical to `add_(-raw)`.
        state_dict_tuned[key].sub_(state_dict_raw[key])

    # The raw weights are no longer needed; drop the references before serialization
    # so they can be freed and do not add to peak memory while saving.
    del state_dict_raw, model_raw

    model_tuned.save_pretrained(path_diff)
    tokenizer_tuned.save_pretrained(path_diff)
68

69

70
@torch.inference_mode()
def recover(
    path_raw,
    path_diff,
    path_tuned: Optional[str] = None,
    device="cpu",
    test_inference=True,
    check_integrity_naively=True,
):
    """Recover the original weights from the released weight diff.

    This function is given for you to run.

    Things to do before running this:
        1. Convert Meta's released weights into huggingface format. Follow this guide:
            https://huggingface.co/docs/transformers/main/model_doc/llama
        2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at:
            https://huggingface.co/tatsu-lab/alpaca-7b/tree/main
        3. Run this function with the correct paths. E.g.,
            python weight_diff.py recover --path_raw <path_to_step_1_dir> --path_diff <path_to_step_2_dir>

    Additional notes:
        - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`.
        - If you want to save the recovered weights, set `--path_tuned <your_path_tuned>`.
            Next time you can load the recovered weights directly from `<your_path_tuned>`.

    Args:
        path_raw: Path to the raw (base) checkpoint in Hugging Face format.
        path_diff: Path to the weight diff (model weights plus tokenizer).
        path_tuned: Optional output directory; when given, the recovered model and tokenizer
            are saved there.
        device: Device both models are loaded onto, e.g. "cpu" or "cuda".
        test_inference: If True, generate up to 100 new tokens for a sample instruction and
            print the prompt and completion.
        check_integrity_naively: If True, compare the sum of all recovered weights against a
            hard-coded reference value and raise if it does not match.

    Returns:
        A ``(model, tokenizer)`` tuple: the recovered model and the tokenizer loaded from ``path_diff``.

    Raises:
        AssertionError: If ``check_integrity_naively`` is set and the weight sum does not match.
    """
    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_raw,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model_recovered: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_diff,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_raw
    )
    if tokenizer_raw.pad_token is None:
        # Mirrors the pad-token handling applied to the raw model when the diff was made
        # (presumably so the embedding shapes match the diff's; verify against make_diff).
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token="[PAD]"),
            model=model_raw,
            tokenizer=tokenizer_raw,
        )
    tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_diff
    )

    # recovered = diff + raw, accumulated in place into the diff model's tensors.
    state_dict_recovered = model_recovered.state_dict()
    state_dict_raw = model_raw.state_dict()
    for key in tqdm.tqdm(state_dict_recovered):
        state_dict_recovered[key].add_(state_dict_raw[key])

    # The raw weights are no longer needed; drop the references so they can be freed
    # before the integrity check, saving, and test generation.
    del state_dict_raw, model_raw

    if check_integrity_naively:
        # This is not a rigorous, cryptographically strong integrity check :)
        # The reference sum is hard-coded (presumably for the released alpaca-7b diff).
        # An explicit raise (instead of `assert`) keeps the check active under `python -O`.
        allsum = sum(state_dict_recovered[key].sum() for key in state_dict_recovered)
        if not torch.allclose(
            allsum, torch.full_like(allsum, fill_value=50637.1836), atol=1e-2, rtol=0
        ):
            raise AssertionError(
                "Naive integrity check failed. This could imply that some of the checkpoint files are corrupted."
            )

    if path_tuned is not None:
        model_recovered.save_pretrained(path_tuned)
        tokenizer_recovered.save_pretrained(path_tuned)

    if test_inference:
        input_text = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\r\n\r\n"
            "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:"
        )
        # The prompt must live on the model's device; otherwise generation fails with --device "cuda".
        encoded = tokenizer_recovered(input_text, return_tensors="pt").to(model_recovered.device)
        out = model_recovered.generate(
            inputs=encoded.input_ids,
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=100,
        )
        # Decode only the newly generated tokens. Slicing the decoded text by len(input_text)
        # would assume decode(encode(prompt)) reproduces the prompt character for character.
        prompt_length = encoded.input_ids.shape[-1]
        output_text = tokenizer_recovered.batch_decode(out[:, prompt_length:], skip_special_tokens=True)[0]
        print(f"Input: {input_text}\nCompletion: {output_text}")

    return model_recovered, tokenizer_recovered
151

152

153
def main(task, **kwargs):
    """Run the command-line task named ``task``.

    Args:
        task: Name of the task to run; one of "make_diff" or "recover".
        **kwargs: Keyword arguments forwarded unchanged to the selected task.

    Raises:
        KeyError: If ``task`` is not one of the supported task names.
    """
    # Only the public entry points are dispatchable. A bare `globals()[task]` lookup would also
    # reach imported modules and helpers (or `main` itself) and call them with the CLI arguments.
    supported_tasks = ("make_diff", "recover")
    if task not in supported_tasks:
        raise KeyError(f"Unknown task {task!r}; expected one of: {', '.join(supported_tasks)}.")
    globals()[task](**kwargs)
155

156

157
if __name__ == "__main__":
    # Expose `main` as a command-line interface: the first positional CLI argument selects the
    # task, and `--flag value` pairs become keyword arguments for it.
    fire.Fire(component=main)
159

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.