import argparse
import glob
import os
import sys
from typing import Any, ContextManager, cast

import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Function to determine if file is a SafeTensor file
def is_safetensor_file(file_path):
    """Return True when *file_path* names a SafeTensors file (by extension)."""
    suffix = '.safetensors'
    return file_path.endswith(suffix)
# Unified loading function
def load_model(file_path):
    """Load a checkpoint from either a SafeTensors or a PyTorch file.

    Returns a tuple ``(state_dict, file_type)`` where ``file_type`` is
    ``'safetensor'`` or ``'pytorch'``, so callers can write the file back
    in its original format via ``save_model``.
    """
    if is_safetensor_file(file_path):
        tensors = {}
        with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
            for key in f.keys():
                # clone() detaches the tensor from the memory-mapped file
                tensors[key] = f.get_tensor(key).clone()
                print(f"{key} : {tensors[key].shape}")
        return tensors, 'safetensor'
    # Plain PyTorch checkpoint; force everything onto the CPU.
    return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'
# Unified saving function
def save_model(model, file_path, file_type):
    """Save *model* (a dict of tensors) to *file_path*.

    ``file_type == 'safetensor'`` writes via safetensors' ``save_file``;
    any other value falls back to ``torch.save``.
    """
    if file_type == 'safetensor':
        # safe_save(model, file_path)
        save_file(model, file_path)
    else:
        # Without this else, torch.save would always run and clobber
        # a freshly written safetensors file.
        torch.save(model, file_path)
# Adapted function to clean vision tower from checkpoint
def clean_vision_tower_from_checkpoint(checkpoint_path):
    """Extract vision-tower tensors from a checkpoint into ``llava.clip``.

    Tensors whose name starts with ``model.vision_tower`` or ``vit.`` are
    copied into ``llava.clip`` (created next to the checkpoint, merged if it
    already exists), removed from the checkpoint, and the checkpoint is
    rewritten in its original format.

    Returns True when any vision-tower tensor was found, False otherwise.
    """
    checkpoint, file_type = load_model(checkpoint_path)
    model_path = os.path.dirname(checkpoint_path)
    print(f"Searching for vision tower tensors in {checkpoint_path}")
    clip_tensors = [k for k in checkpoint.keys() if (k.startswith("model.vision_tower") or k.startswith("vit."))]

    if len(clip_tensors) > 0:
        print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
        # llava.clip always lives next to the checkpoint shards
        clip_path = os.path.join(model_path, "llava.clip")

        # Merge into an existing llava.clip: sharded checkpoints may each
        # contribute a slice of the vision tower.
        if os.path.exists(clip_path):
            print(f"Loading existing llava.clip from {clip_path}")
            existing_clip, _ = load_model(clip_path)
        else:
            print(f"Creating new llava.clip at {clip_path}")
            existing_clip = {}
        # Update existing_clip with new tensors, avoid duplicates
        for name in clip_tensors:
            # Strip everything before 'vision_model.' so keys match CLIP naming.
            simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
            print(f"Adding {simple_name} to llava.clip")
            if simple_name not in existing_clip:
                existing_clip[simple_name] = checkpoint[name]

        # Save the updated clip tensors back to llava.clip
        save_model(existing_clip, clip_path, 'pytorch')

        # Remove the extracted tensors from the original checkpoint and
        # rewrite it in the format it was loaded from.
        for name in clip_tensors:
            del checkpoint[name]
        save_model(checkpoint, checkpoint_path, file_type)
        return True

    return False
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
    """Locate the checkpoints holding the newline and projector tensors.

    Walks *checkpoint_paths* in order, loading each checkpoint; the first
    path whose tensors satisfy *newline_criteria* and the last path whose
    tensors satisfy *projector* are returned as a
    ``(newline_path, projector_path)`` tuple (either may be None).
    """
    newline_path = None
    projector_path = None

    for candidate in checkpoint_paths:
        state_dict, _ = load_model(candidate)
        if newline_criteria(state_dict) and newline_path is None:
            newline_path = candidate
        if projector(state_dict):
            projector_path = candidate

    return newline_path, projector_path
def newline_criteria(checkpoint):
    """Return True when the checkpoint contains a ``model.image_newline`` tensor."""
    prefix = "model.image_newline"
    for key in checkpoint.keys():
        if key.startswith(prefix):
            return True
    return False
def proj_criteria(checkpoint):
    """Return True when the checkpoint contains multimodal projector tensors."""
    # startswith accepts a tuple, covering both known projector prefixes.
    prefixes = ("model.mm_projector", "vision_proj.")
    return any(key.startswith(prefixes) for key in checkpoint.keys())
# Command-line interface setup
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files")
args = ap.parse_args()

if args.clean_vision_tower:
    # Generalized to handle both PyTorch and SafeTensors models.
    # Newest files first so later shards are processed before earlier ones.
    model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
    # Accept pytorch_model-*.bin and model-*.safetensors shard names
    # (match on the basename only, for both '/' and '\\' separators).
    checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
    for projector_checkpoint_path in checkpoint_paths:
        print(f"Cleaning {projector_checkpoint_path}")
        if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
            print(f"No vision tower found in {projector_checkpoint_path}")
            # we break once none is found, so far all models append them at the end
    print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")

# Now we look for the projector in the last checkpoint
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)

print(f"Taking projector from {projector_checkpoint_path}")

# The image_newline tensor (if present) comes from its own checkpoint.
first_mm_tensors = []
first_checkpoint = None
if newline_checkpoint_path is not None:
    print(f"Taking newline from {newline_checkpoint_path}")
    first_checkpoint, file_type = load_model(newline_checkpoint_path)
    first_mm_tensors = [k for k in first_checkpoint.keys() if k.startswith("model.image_newline")]

# Load the checkpoint that holds the multimodal projector.
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
    last_checkpoint, file_type = load_model(projector_checkpoint_path)
    mm_tensors = [k for k in last_checkpoint.keys() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]

if len(mm_tensors) == 0:
    # Dump the available keys to help diagnose non-LLaVA checkpoints.
    if last_checkpoint is not None:
        for k in last_checkpoint.keys():
            print(k)
    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
    print("No tensors found. Is this a LLaVA model?")
    sys.exit(1)

print(f"Found {len(mm_tensors)} tensors to extract.")
print(f"Found additional {len(first_mm_tensors)} tensors to extract.")

# Collect the projector (and newline) tensors, promoted to float32.
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
projector = {}
for name in mm_tensors:
    assert last_checkpoint is not None
    projector[name] = last_checkpoint[name].float()
for name in first_mm_tensors:
    assert first_checkpoint is not None
    projector[name] = first_checkpoint[name].float()

if len(projector) > 0:
    save_model(projector, f"{args.model}/llava.projector", 'pytorch')

print("Done!")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")