aurora

Форк
0
/
runner.py 
258 строк · 10.9 Кб
1
import os
2
import time
3
import logging
4
import gradio as gr
5
from threading import Thread
6
from gradio.components import Component # cannot use TYPE_CHECKING here
7
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
8

9
import transformers
10
from transformers.trainer import TRAINING_ARGS_NAME
11

12
from llmtuner.extras.callbacks import LogCallback
13
from llmtuner.extras.constants import TRAINING_STAGES
14
from llmtuner.extras.logging import LoggerHandler
15
from llmtuner.extras.misc import torch_gc
16
from llmtuner.train import run_exp
17
from llmtuner.webui.common import get_module, get_save_dir, load_config
18
from llmtuner.webui.locales import ALERTS
19
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
20

21
if TYPE_CHECKING:
22
    from llmtuner.webui.manager import Manager
23

24

25
class Runner:
    """Drive training / evaluation runs launched from the web UI.

    The runner validates the UI form values, converts them into keyword
    arguments for ``run_exp``, runs the experiment in a background thread and
    streams log / progress updates back to the UI through generator methods.
    """

    def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
        """Create a runner bound to ``manager`` and attach the UI log handler.

        Args:
            manager: UI manager used to resolve components from element names.
            demo_mode: when truthy, real launches are refused (previews still work).
        """
        self.manager = manager
        self.demo_mode = demo_mode
        # Resume: state of the currently launched run.
        self.thread: Optional["Thread"] = None
        self.do_train = True
        self.running_data: Optional[Dict["Component", Any]] = None
        # Created by `_initialize`; read by `monitor` to build the progress bar.
        self.trainer_callback: Optional["LogCallback"] = None
        # State
        self.aborted = False
        self.running = False
        # Handler: capture root-logger and transformers logs so they can be shown in the UI.
        self.logger_handler = LoggerHandler()
        self.logger_handler.setLevel(logging.INFO)
        logging.root.addHandler(self.logger_handler)
        transformers.logging.add_handler(self.logger_handler)

    @property
    def alive(self) -> bool:
        """True while a launched run has not been finalized (its thread is still tracked)."""
        return self.thread is not None

    def set_abort(self) -> None:
        """Request an abort of the current run.

        `monitor` and `_finalize` read this flag; the trainer callback receives
        this runner and presumably checks it too (verify in `LogCallback`).
        """
        self.aborted = True

    def _getter(self, data: Dict[Component, Any]):
        """Return a ``get(name)`` function that reads element ``name``'s value from ``data``."""
        def get(name: str) -> Any:
            return data[self.manager.get_elem_by_name(name)]

        return get

    def _checkpoint_dir(self, get) -> Optional[str]:
        """Comma-join the save dirs of the selected checkpoints, or return None if none are selected.

        Args:
            get: element-value lookup function as returned by `_getter`.
        """
        checkpoints = get("top.checkpoints")
        if not checkpoints:
            return None
        return ",".join(
            get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in checkpoints
        )

    def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
        """Validate the form values and reset the per-run state.

        Args:
            data: mapping from UI components to their current values.
            do_train: check ``train.dataset`` when True, ``eval.dataset`` otherwise.
            from_preview: True when only a command preview is requested; in demo
                mode only previews are allowed.

        Returns:
            A localized error message, or ``""`` when the run may proceed.
        """
        get = self._getter(data)
        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
        dataset = get("train.dataset") if do_train else get("eval.dataset")

        if self.running:
            return ALERTS["err_conflict"][lang]

        if not model_name:
            return ALERTS["err_no_model"][lang]

        if not model_path:
            return ALERTS["err_no_path"][lang]

        # `not dataset` covers both an empty selection and a cleared dropdown
        # that may yield None (where `len(None)` would raise TypeError).
        if not dataset:
            return ALERTS["err_no_dataset"][lang]

        if self.demo_mode and not from_preview:
            return ALERTS["err_demo"][lang]

        self.aborted = False
        self.logger_handler.reset()
        self.trainer_callback = LogCallback(self)
        return ""

    def _finalize(self, lang: str, finish_info: str) -> str:
        """Clear the run state, free cached memory and return the final status message.

        Returns:
            The "aborted" alert if an abort was requested, otherwise ``finish_info``.
        """
        self.thread = None
        self.running_data = None
        self.running = False
        torch_gc()
        if self.aborted:
            return ALERTS["info_aborted"][lang]
        else:
            return finish_info

    def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
        """Convert the training form values into keyword arguments for ``run_exp``."""
        get = self._getter(data)
        user_config = load_config()

        args = dict(
            stage=TRAINING_STAGES[get("train.training_stage")],
            model_name_or_path=get("top.model_path"),
            do_train=True,
            cache_dir=user_config.get("cache_dir", None),
            checkpoint_dir=self._checkpoint_dir(get),
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
            flash_attn=get("top.flash_attn"),
            shift_attn=get("top.shift_attn"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            dataset_dir=get("train.dataset_dir"),
            dataset=",".join(get("train.dataset")),
            cutoff_len=get("train.cutoff_len"),
            learning_rate=float(get("train.learning_rate")),
            num_train_epochs=float(get("train.num_train_epochs")),
            max_samples=int(get("train.max_samples")),
            per_device_train_batch_size=get("train.batch_size"),
            gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
            lr_scheduler_type=get("train.lr_scheduler_type"),
            max_grad_norm=float(get("train.max_grad_norm")),
            logging_steps=get("train.logging_steps"),
            save_steps=get("train.save_steps"),
            warmup_steps=get("train.warmup_steps"),
            neftune_noise_alpha=get("train.neftune_alpha"),
            train_on_prompt=get("train.train_on_prompt"),
            upcast_layernorm=get("train.upcast_layernorm"),
            lora_rank=get("train.lora_rank"),
            lora_dropout=get("train.lora_dropout"),
            lora_target=get("train.lora_target") or get_module(get("top.model_name")),
            additional_target=get("train.additional_target") if get("train.additional_target") else None,
            resume_lora_training=get("train.resume_lora_training"),
            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
        )
        # The selected compute type is passed as a boolean flag named after it
        # (presumably e.g. "fp16" / "bf16" — confirm against the UI choices).
        args[get("train.compute_type")] = True
        # Progress is reported through the UI, not a tqdm bar.
        args["disable_tqdm"] = True

        if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
            args["resume_lora_training"] = (args["quantization_bit"] is not None)

        if args["quantization_bit"] is not None:
            args["upcast_layernorm"] = True

        if args["stage"] == "ppo":
            args["reward_model"] = get_save_dir(
                get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
            )
            args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"

        if args["stage"] == "dpo":
            args["dpo_beta"] = get("train.dpo_beta")

        # Hold out a validation split (not for PPO), evaluated at every save step.
        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
            args["val_size"] = get("train.val_size")
            args["evaluation_strategy"] = "steps"
            args["eval_steps"] = get("train.save_steps")
            args["load_best_model_at_end"] = True

        return args

    def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
        """Convert the evaluation form values into keyword arguments for ``run_exp``.

        When ``eval.predict`` is set, ``do_eval`` is replaced by ``do_predict``.
        """
        get = self._getter(data)
        user_config = load_config()

        args = dict(
            stage="sft",
            model_name_or_path=get("top.model_path"),
            do_eval=True,
            predict_with_generate=True,
            cache_dir=user_config.get("cache_dir", None),
            checkpoint_dir=self._checkpoint_dir(get),
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
            flash_attn=get("top.flash_attn"),
            shift_attn=get("top.shift_attn"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            dataset_dir=get("eval.dataset_dir"),
            dataset=",".join(get("eval.dataset")),
            cutoff_len=get("eval.cutoff_len"),
            max_samples=int(get("eval.max_samples")),
            per_device_eval_batch_size=get("eval.batch_size"),
            max_new_tokens=get("eval.max_new_tokens"),
            top_p=get("eval.top_p"),
            temperature=get("eval.temperature"),
            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
        )

        if get("eval.predict"):
            args.pop("do_eval", None)
            args["do_predict"] = True

        return args

    def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Yield once: either the validation error (also shown as a warning) or the equivalent CLI command."""
        error = self._initialize(data, do_train, from_preview=True)
        if error:
            gr.Warning(error)
            yield error, gr.update(visible=False)
        else:
            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
            yield gen_cmd(args), gr.update(visible=False)

    def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Validate, start ``run_exp`` in a background thread and stream its progress via `monitor`."""
        error = self._initialize(data, do_train, from_preview=False)
        if error:
            gr.Warning(error)
            yield error, gr.update(visible=False)
        else:
            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
            run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
            self.do_train, self.running_data = do_train, data
            self.thread = Thread(target=run_exp, kwargs=run_kwargs)
            self.thread.start()
            yield from self.monitor()

    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Preview the command for a training run."""
        yield from self._preview(data, do_train=True)

    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Preview the command for an evaluation run."""
        yield from self._preview(data, do_train=False)

    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Launch a training run and stream its progress."""
        yield from self._launch(data, do_train=True)

    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Launch an evaluation run and stream its progress."""
        yield from self._launch(data, do_train=False)

    def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Poll the worker thread every 2 seconds, streaming logs and progress, then yield the outcome.

        Success is detected from files in the output directory: the training
        args file for training runs, ``all_results.json`` for evaluation runs.
        """
        get = self._getter(self.running_data)
        self.running = True
        lang = get("top.lang")
        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
            "{}.output_dir".format("train" if self.do_train else "eval")
        ))

        while self.thread.is_alive():
            time.sleep(2)
            if self.aborted:
                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
            else:
                yield self.logger_handler.log, update_process_bar(self.trainer_callback)

        if self.do_train:
            if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
                finish_info = ALERTS["info_finished"][lang]
            else:
                finish_info = ALERTS["err_failed"][lang]
        else:
            results_path = os.path.join(output_dir, "all_results.json")
            if os.path.exists(results_path):
                finish_info = get_eval_results(results_path)
            else:
                finish_info = ALERTS["err_failed"][lang]

        yield self._finalize(lang, finish_info), gr.update(visible=False)
259

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.