aurora

Форк
0
56 строк · 2.3 Кб
1
from typing import TYPE_CHECKING, Any, Dict, List, Optional
2

3
from llmtuner.extras.callbacks import LogCallback
4
from llmtuner.extras.logging import get_logger
5
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
6
from llmtuner.train.pt import run_pt
7
from llmtuner.train.sft import run_sft
8
from llmtuner.train.rm import run_rm
9
from llmtuner.train.ppo import run_ppo
10
from llmtuner.train.dpo import run_dpo
11

12
if TYPE_CHECKING:
13
    from transformers import TrainerCallback
14

15

16
logger = get_logger(__name__)
17

18

19
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
20
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
21
    callbacks = [LogCallback()] if callbacks is None else callbacks
22

23
    if finetuning_args.stage == "pt":
24
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
25
    elif finetuning_args.stage == "sft":
26
        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
27
    elif finetuning_args.stage == "rm":
28
        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
29
    elif finetuning_args.stage == "ppo":
30
        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
31
    elif finetuning_args.stage == "dpo":
32
        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
33
    else:
34
        raise ValueError("Unknown task.")
35

36

37
def export_model(args: Optional[Dict[str, Any]] = None):
38
    model_args, _, finetuning_args, _ = get_infer_args(args)
39
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
40

41
    if getattr(model, "quantization_method", None) in ["gptq", "awq"]:
42
        raise ValueError("Cannot export a GPTQ or AWQ quantized model.")
43

44
    model.config.use_cache = True
45
    model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size))
46

47
    try:
48
        tokenizer.padding_side = "left" # restore padding side
49
        tokenizer.init_kwargs["padding_side"] = "left"
50
        tokenizer.save_pretrained(finetuning_args.export_dir)
51
    except:
52
        logger.warning("Cannot save tokenizer, please copy the files manually.")
53

54

55
if __name__ == "__main__":
56
    run_exp()
57

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.