aurora

Форк
0
89 строк · 2.8 Кб
1
import os
2
import json
3
import gradio as gr
4
from typing import TYPE_CHECKING, Any, Dict
5
from datetime import datetime
6

7
from llmtuner.extras.packages import is_matplotlib_available
8
from llmtuner.extras.ploting import smooth
9
from llmtuner.webui.common import get_save_dir
10

11
if TYPE_CHECKING:
12
    from llmtuner.extras.callbacks import LogCallback
13

14
if is_matplotlib_available():
15
    import matplotlib.figure
16
    import matplotlib.pyplot as plt
17

18

19
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
20
    if not callback.max_steps:
21
        return gr.update(visible=False)
22

23
    percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
24
    label = "Running {:d}/{:d}: {} < {}".format(
25
        callback.cur_steps,
26
        callback.max_steps,
27
        callback.elapsed_time,
28
        callback.remaining_time
29
    )
30
    return gr.update(label=label, value=percentage, visible=True)
31

32

33
def get_time() -> str:
34
    return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
35

36

37
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
38
    if finetuning_type != "lora":
39
        return gr.update(value="None", interactive=False)
40
    else:
41
        return gr.update(interactive=True)
42

43

44
def gen_cmd(args: Dict[str, Any]) -> str:
45
    args.pop("disable_tqdm", None)
46
    args["plot_loss"] = args.get("do_train", None)
47
    current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
48
    cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
49
    for k, v in args.items():
50
        if v is not None and v != "":
51
            cmd_lines.append("    --{} {} ".format(k, str(v)))
52
    cmd_text = "\\\n".join(cmd_lines)
53
    cmd_text = "```bash\n{}\n```".format(cmd_text)
54
    return cmd_text
55

56

57
def get_eval_results(path: os.PathLike) -> str:
58
    with open(path, "r", encoding="utf-8") as f:
59
        result = json.dumps(json.load(f), indent=4)
60
    return "```json\n{}\n```\n".format(result)
61

62

63
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
64
    if not base_model:
65
        return
66
    log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
67
    if not os.path.isfile(log_file):
68
        return
69

70
    plt.close("all")
71
    fig = plt.figure()
72
    ax = fig.add_subplot(111)
73
    steps, losses = [], []
74
    with open(log_file, "r", encoding="utf-8") as f:
75
        for line in f:
76
            log_info = json.loads(line)
77
            if log_info.get("loss", None):
78
                steps.append(log_info["current_steps"])
79
                losses.append(log_info["loss"])
80

81
    if len(losses) == 0:
82
        return None
83

84
    ax.plot(steps, losses, alpha=0.4, label="original")
85
    ax.plot(steps, smooth(losses), label="smoothed")
86
    ax.legend()
87
    ax.set_xlabel("step")
88
    ax.set_ylabel("loss")
89
    return fig
90

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.