aurora

Форк
0
80 строк · 2.2 Кб
1
import gradio as gr
2
from typing import TYPE_CHECKING, Dict, Generator, List
3

4
from llmtuner.train import export_model
5
from llmtuner.webui.common import get_save_dir
6
from llmtuner.webui.locales import ALERTS
7

8
if TYPE_CHECKING:
9
    from gradio.components import Component
10
    from llmtuner.webui.engine import Engine
11

12

13
def save_model(
14
    lang: str,
15
    model_name: str,
16
    model_path: str,
17
    checkpoints: List[str],
18
    finetuning_type: str,
19
    template: str,
20
    max_shard_size: int,
21
    export_dir: str
22
) -> Generator[str, None, None]:
23
    error = ""
24
    if not model_name:
25
        error = ALERTS["err_no_model"][lang]
26
    elif not model_path:
27
        error = ALERTS["err_no_path"][lang]
28
    elif not checkpoints:
29
        error = ALERTS["err_no_checkpoint"][lang]
30
    elif not export_dir:
31
        error = ALERTS["err_no_export_dir"][lang]
32

33
    if error:
34
        gr.Warning(error)
35
        yield error
36
        return
37

38
    args = dict(
39
        model_name_or_path=model_path,
40
        checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
41
        finetuning_type=finetuning_type,
42
        template=template,
43
        export_dir=export_dir,
44
        export_size=max_shard_size
45
    )
46

47
    yield ALERTS["info_exporting"][lang]
48
    export_model(args)
49
    yield ALERTS["info_exported"][lang]
50

51

52
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
53
    with gr.Row():
54
        export_dir = gr.Textbox()
55
        max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
56

57
    export_btn = gr.Button()
58
    info_box = gr.Textbox(show_label=False, interactive=False)
59

60
    export_btn.click(
61
        save_model,
62
        [
63
            engine.manager.get_elem_by_name("top.lang"),
64
            engine.manager.get_elem_by_name("top.model_name"),
65
            engine.manager.get_elem_by_name("top.model_path"),
66
            engine.manager.get_elem_by_name("top.checkpoints"),
67
            engine.manager.get_elem_by_name("top.finetuning_type"),
68
            engine.manager.get_elem_by_name("top.template"),
69
            max_shard_size,
70
            export_dir
71
        ],
72
        [info_box]
73
    )
74

75
    return dict(
76
        export_dir=export_dir,
77
        max_shard_size=max_shard_size,
78
        export_btn=export_btn,
79
        info_box=info_box
80
    )
81

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.