# llama / llava_surgery_v2.py
# (GitVerse file-view header: "Fork 0", 159 lines, 6.9 KB)
1
import argparse
2
import glob
3
import os
4
import torch
5
from safetensors import safe_open
6
from safetensors.torch import save_file
7
from typing import Any, ContextManager, cast
8

9
def is_safetensor_file(file_path):
    """Return True if *file_path* names a ``.safetensors`` file.

    The check looks only at the (case-sensitive) extension; the file is not
    opened.  Plain strings and ``os.PathLike`` objects are both accepted.

    Args:
        file_path: Path to a checkpoint file.

    Returns:
        bool: ``True`` when the path ends with ``.safetensors``.
    """
    return os.fspath(file_path).endswith('.safetensors')
12

13

14
def load_model(file_path):
    """Load a checkpoint into memory on the CPU.

    Paths ending in ``.safetensors`` (see ``is_safetensor_file``) are read
    tensor by tensor with ``safe_open``.  Each tensor is cloned (presumably so
    it stays valid after the file is closed -- confirm) and its key and shape
    are printed.  Any other path is loaded with ``torch.load`` onto the CPU.

    Args:
        file_path: Path to a ``.safetensors`` file or to a file written by
            ``torch.save`` (e.g. ``pytorch_model.bin`` or ``llava.clip``).

    Returns:
        tuple: ``(checkpoint, file_type)``.  ``file_type`` is ``'safetensor'``
        or ``'pytorch'``.  For safetensors ``checkpoint`` is a ``dict`` of
        name -> tensor; otherwise it is whatever ``torch.load`` returns
        (callers in this script treat it as a state dict).
    """
    if is_safetensor_file(file_path):
        tensors = {}
        # cast() only helps static type checkers; it returns its argument unchanged.
        with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
            for key in f.keys():
                tensors[key] = f.get_tensor(key).clone()
                print(f"{key} : {tensors[key].shape}")
        return tensors, 'safetensor'

    # NOTE(review): torch.load unpickles arbitrary objects -- only run this on
    # trusted checkpoints (consider weights_only=True where torch supports it).
    return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'
26

27

28
def save_model(model, file_path, file_type):
    """Write *model* to *file_path* in the requested format.

    Args:
        model: For ``'safetensor'``, a dict of name -> tensor as accepted by
            ``safetensors.torch.save_file``; otherwise any object that
            ``torch.save`` can serialize.
        file_path: Destination path.
        file_type: ``'safetensor'`` selects ``save_file``; any other value
            falls back to ``torch.save``.
    """
    if file_type == 'safetensor':
        save_file(model, file_path)
    else:
        torch.save(model, file_path)
35

36

37
def clean_vision_tower_from_checkpoint(checkpoint_path):
    """Copy the vision-tower tensors of one checkpoint into ``llava.clip``.

    Tensors whose names start with ``model.vision_tower`` or ``vit.`` are
    collected from *checkpoint_path* and merged into ``llava.clip`` in the same
    directory.  That file is loaded first if it exists, otherwise it is
    created.  The result is written with ``torch.save`` (via ``save_model``).
    Names are shortened to begin at ``vision_model.`` when that substring
    occurs.  If a shortened name is already present in ``llava.clip``, the
    stored tensor is kept and the new one is skipped.

    NOTE(review): the checkpoint file itself is NOT rewritten -- the vision
    tensors remain in *checkpoint_path* on disk.  The previous version only
    deleted them from a local dict that was then discarded.

    Args:
        checkpoint_path: Path to a ``.safetensors`` or PyTorch checkpoint.

    Returns:
        bool: ``True`` if any vision-tower tensor was found (and ``llava.clip``
        was written), ``False`` otherwise.
    """
    checkpoint, _ = load_model(checkpoint_path)
    model_path = os.path.dirname(checkpoint_path)
    print(f"Searching for vision tower tensors in {checkpoint_path}")
    clip_tensors = [k for k in checkpoint.keys() if k.startswith(("model.vision_tower", "vit."))]

    if not clip_tensors:
        return False

    print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
    clip_path = os.path.join(model_path, "llava.clip")

    if os.path.exists(clip_path):
        print(f"Loading existing llava.clip from {clip_path}")
        existing_clip, _ = load_model(clip_path)
    else:
        print(f"Creating new llava.clip at {clip_path}")
        existing_clip = {}

    for name in clip_tensors:
        simple_name = _clip_tensor_name(name)
        if simple_name in existing_clip:
            # Keep what is already stored, e.g. a tensor added while processing
            # an earlier checkpoint file.
            print(f"Skipping {simple_name}: already in llava.clip")
            continue
        print(f"Adding {simple_name} to llava.clip")
        existing_clip[simple_name] = checkpoint[name]

    # Always saved in PyTorch format, whatever the source checkpoint format was.
    save_model(existing_clip, clip_path, 'pytorch')
    return True


def _clip_tensor_name(name):
    """Return *name* from its first ``vision_model.`` onwards, or unchanged if absent."""
    start = name.find('vision_model.')
    return name if start < 0 else name[start:]
73

74
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
    """Locate the checkpoints holding the image-newline and projector tensors.

    Each path is loaded in turn with ``load_model`` and tested with the two
    predicates.

    Args:
        checkpoint_paths: Iterable of checkpoint paths, in search order.
        newline_criteria: Predicate ``checkpoint -> bool``; the FIRST matching
            path is reported.
        projector: Predicate ``checkpoint -> bool``; the LAST matching path is
            reported.

    Returns:
        tuple: ``(newline_checkpoint_path, projector_checkpoint_path)``; an
        entry is ``None`` when no checkpoint matched its predicate.
    """
    newline_checkpoint_path = None
    projector_checkpoint_path = None

    for path in checkpoint_paths:
        checkpoint, _ = load_model(path)
        if newline_checkpoint_path is None and newline_criteria(checkpoint):
            newline_checkpoint_path = path
        if projector(checkpoint):
            projector_checkpoint_path = path
        # Release this checkpoint before the next one is loaded so two full
        # checkpoints are never held in memory at the same time.
        del checkpoint

    return newline_checkpoint_path, projector_checkpoint_path
86

87
def newline_criteria(checkpoint):
    """Return True if any key of *checkpoint* starts with ``model.image_newline``."""
    return any(key.startswith("model.image_newline") for key in checkpoint.keys())
89

90
def proj_criteria(checkpoint):
    """Return True if *checkpoint* has multimodal projector tensors.

    A key qualifies when it starts with ``model.mm_projector`` or
    ``vision_proj.``.
    """
    return any(key.startswith(("model.mm_projector", "vision_proj.")) for key in checkpoint.keys())
92

93

94
# Command-line interface setup
import sys

ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Copy any vision tower from the model files into llava.clip")
args = ap.parse_args()


def _list_checkpoints(model_dir):
    """Return checkpoint files directly inside *model_dir*, newest mtime first.

    A file qualifies when it is a ``.bin`` whose file name contains
    ``pytorch`` or a ``.safetensors`` whose file name contains ``model``.
    """
    # glob.escape() keeps "[", "]", "*" or "?" in the directory name literal.
    model_files = sorted(glob.glob(f"{glob.escape(model_dir)}/*"), key=os.path.getmtime, reverse=True)
    checkpoints = []
    for path in model_files:
        # Split on both separators so POSIX and Windows paths are handled.
        file_name = path.split('/')[-1].split('\\')[-1]
        if (path.endswith('.bin') and 'pytorch' in file_name) or (path.endswith('.safetensors') and 'model' in file_name):
            checkpoints.append(path)
    return checkpoints


if args.clean_vision_tower:
    # Works for both PyTorch (.bin) and SafeTensors checkpoints.
    for checkpoint_path in _list_checkpoints(args.model):
        print(f"Cleaning {checkpoint_path}")
        if not clean_vision_tower_from_checkpoint(checkpoint_path):
            # Keep scanning: any other file may still hold vision tower tensors.
            print(f"No vision tower found in {checkpoint_path}")
    print("Done! All vision tower tensors are copied from the model files into the llava.clip file.")

# Now look for the projector (and the optional image-newline tensor).
checkpoint_paths = _list_checkpoints(args.model)
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)

if projector_checkpoint_path is not None:
    print(f"Taking projector from {projector_checkpoint_path}")

first_mm_tensors = []
first_checkpoint = None
if newline_checkpoint_path is not None:
    print(f"Taking newline from {newline_checkpoint_path}")
    first_checkpoint, _ = load_model(newline_checkpoint_path)
    first_mm_tensors = [k for k in first_checkpoint.keys() if k.startswith("model.image_newline")]

mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
    last_checkpoint, _ = load_model(projector_checkpoint_path)
    mm_tensors = [k for k in last_checkpoint.keys() if k.startswith(("model.mm_projector", "vision_proj."))]

if not mm_tensors:
    if last_checkpoint is not None:
        for k in last_checkpoint.keys():
            print(k)
    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
    print("No tensors found. Is this a LLaVA model?")
    # Non-zero status so calling scripts can detect the failure; unlike the
    # site-provided exit(), sys.exit() is always available.
    sys.exit(1)

print(f"Found {len(mm_tensors)} tensors to extract.")
print(f"Found additional {len(first_mm_tensors)} tensors to extract.")

# Projector tensors first, then the image-newline tensor(s), all converted to float32.
assert last_checkpoint is not None
projector = {name: last_checkpoint[name].float() for name in mm_tensors}
if first_mm_tensors:
    assert first_checkpoint is not None
    projector.update({name: first_checkpoint[name].float() for name in first_mm_tensors})

if projector:
    save_model(projector, f"{args.model}/llava.projector", 'pytorch')

print("Done!")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")
160

# GitVerse page footer (cookie notice, translated from Russian):
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give SberTech JSC consent to process your personal
# data for the purpose of improving our website and the GitVerse Service, and
# making them more convenient to use.
#
# You can disable cookies yourself in your browser settings.