llava

Форк
0
/
merge_lora_weights.py 
22 строки · 767.0 Байт
1
import argparse
2
from llava.model.builder import load_pretrained_model
3
from llava.mm_utils import get_model_name_from_path
4

5

6
def merge_lora(args):
7
    model_name = get_model_name_from_path(args.model_path)
8
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
9

10
    model.save_pretrained(args.save_model_path)
11
    tokenizer.save_pretrained(args.save_model_path)
12

13

14
if __name__ == "__main__":
15
    parser = argparse.ArgumentParser()
16
    parser.add_argument("--model-path", type=str, required=True)
17
    parser.add_argument("--model-base", type=str, required=True)
18
    parser.add_argument("--save-model-path", type=str, required=True)
19

20
    args = parser.parse_args()
21

22
    merge_lora(args)
23

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.