import torch
from typing import TYPE_CHECKING, Optional, Union

from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, Trainer
    from transformers.modeling_utils import PreTrainedModel
    from trl import AutoModelForCausalLMWithValueHead
    from llmtuner.hparams import DataArguments


logger = get_logger(__name__)


def create_modelcard_and_push(
    trainer: "Trainer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments"
) -> None:
    if training_args.do_train:
        if training_args.push_to_hub:
            trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
            return
        try:
            trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
        except Exception as err:
            logger.warning("Failed to create model card: {}".format(str(err)))


def create_ref_model(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    add_valuehead: Optional[bool] = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
    r"""
    Creates a reference model for PPO/DPO training. Evaluation mode is not supported.

    The value head parameter is randomly initialized since it is not used for PPO training.
    """
    if finetuning_args.ref_model is not None:
        ref_model_args_dict = model_args.to_dict()
        ref_model_args_dict.update(dict(
            model_name_or_path=finetuning_args.ref_model,
            checkpoint_dir=finetuning_args.ref_model_checkpoint,
            quantization_bit=finetuning_args.ref_model_quantization_bit
        ))
        ref_model_args = ModelArguments(**ref_model_args_dict)
        ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
        ref_model, _ = load_model_and_tokenizer(
            ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
        )
        logger.info("Created reference model from {}".format(finetuning_args.ref_model))
    else:
        if finetuning_args.finetuning_type == "lora":
            ref_model = None
        else:
            ref_model, _ = load_model_and_tokenizer(
                model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
            )
            logger.info("Created reference model from the model itself.")

    return ref_model
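

# --- Illustrative usage sketch: how `create_ref_model` might be wired into a
# DPO/PPO setup. The argument objects are assumed to be built by the caller;
# `_example_setup_ref_model` is a hypothetical helper and is not used elsewhere
# in this module.
def _example_setup_ref_model(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments"
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
    ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=False)
    if ref_model is None:
        # LoRA fine-tuning without an explicit `ref_model`: no separate frozen
        # copy is loaded.
        logger.info("No standalone reference model was created.")
    return ref_model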


def create_reward_model(
    model: "AutoModelForCausalLMWithValueHead",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments"
) -> Optional[Union[str, "AutoModelForCausalLMWithValueHead"]]:
    r"""
    Creates a reward model for PPO training.

    Returns the server URL for API-based reward models, None when the reward model
    is a LoRA adapter merged into `model`, and a standalone model otherwise.
    """
    if finetuning_args.reward_model_type == "api":
        assert finetuning_args.reward_model.startswith("http"), "Please provide a full URL."
        logger.info("Using reward server {}".format(finetuning_args.reward_model))
        return finetuning_args.reward_model
    elif finetuning_args.reward_model_type == "lora":
        model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
        for name, param in model.named_parameters():  # https://github.com/huggingface/peft/issues/1090
            if "default" in name:
                param.data = param.data.to(torch.float32)  # trainable params should be in fp32
        vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
        assert vhead_params is not None, "Reward model is not correctly loaded."
        model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
        model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
        model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
        model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
        logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
        return None
    else:
        reward_model_args_dict = model_args.to_dict()
        reward_model_args_dict.update(dict(
            model_name_or_path=finetuning_args.reward_model,
            checkpoint_dir=finetuning_args.reward_model_checkpoint,
            quantization_bit=finetuning_args.reward_model_quantization_bit
        ))
        reward_model_args = ModelArguments(**reward_model_args_dict)
        reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
        reward_model, _ = load_model_and_tokenizer(
            reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
        )
        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
        logger.warning("Please ensure the PPO model and reward model share the SAME tokenizer and vocabulary.")
        return reward_model
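

# --- Illustrative usage sketch: how the three `reward_model_type` modes of
# `create_reward_model` could be handled during PPO setup. The policy model and
# argument objects are assumed to be created by the caller;
# `_example_setup_reward_model` is a hypothetical helper and is not used
# elsewhere in this module.
def _example_setup_reward_model(
    model: "AutoModelForCausalLMWithValueHead",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments"
) -> Optional[Union[str, "AutoModelForCausalLMWithValueHead"]]:
    reward_model = create_reward_model(model, model_args, finetuning_args)
    if finetuning_args.reward_model_type == "api":
        # `reward_model` is the server URL; scores are fetched remotely.
        logger.info("Reward scores will be requested from {}".format(reward_model))
    elif finetuning_args.reward_model_type == "lora":
        # The reward adapter and value-head buffers were attached to `model`
        # itself, so no separate object is returned.
        assert reward_model is None
    else:
        # A standalone reward model with its own value head was loaded.
        assert reward_model is not None
    return reward_model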