import torch
from typing import TYPE_CHECKING, Optional, Union

from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params

if TYPE_CHECKING:  # imports used only in type annotations
    from transformers import Seq2SeqTrainingArguments, Trainer
    from transformers.modeling_utils import PreTrainedModel
    from trl import AutoModelForCausalLMWithValueHead
    from llmtuner.hparams import DataArguments

logger = get_logger(__name__)


def create_modelcard_and_push(
    trainer: "Trainer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments"
) -> None:
    if training_args.do_train:
        if training_args.push_to_hub:
            trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
            return  # the Hub generates the model card, no need to create it locally
        try:
            trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
        except Exception as err:
            logger.warning("Failed to create model card: {}".format(str(err)))


def create_ref_model(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    add_valuehead: Optional[bool] = False
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
    r"""
    Creates a reference model for PPO/DPO training. Evaluation mode is not supported.

    The valuehead parameter is randomly initialized since it is useless for PPO training.
    """
    if finetuning_args.ref_model is not None:  # a dedicated reference model is specified
        ref_model_args_dict = model_args.to_dict()
        ref_model_args_dict.update(dict(
            model_name_or_path=finetuning_args.ref_model,
            checkpoint_dir=finetuning_args.ref_model_checkpoint,
            quantization_bit=finetuning_args.ref_model_quantization_bit
        ))
        ref_model_args = ModelArguments(**ref_model_args_dict)
        ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
        ref_model, _ = load_model_and_tokenizer(
            ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
        )
        logger.info("Created reference model from {}".format(finetuning_args.ref_model))
    else:
        if finetuning_args.finetuning_type == "lora":
            ref_model = None  # the policy model with adapters disabled serves as the reference
        else:
            ref_model, _ = load_model_and_tokenizer(
                model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
            )
            logger.info("Created reference model from the model itself.")

    return ref_model


def create_reward_model(
    model: "AutoModelForCausalLMWithValueHead",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
    r"""
    Creates a reward model for PPO training.
    """
    if finetuning_args.reward_model_type == "api":
        assert finetuning_args.reward_model.startswith("http"), "Please provide a full URL."
        logger.info("Use reward server {}".format(finetuning_args.reward_model))
        return finetuning_args.reward_model
    elif finetuning_args.reward_model_type == "lora":
        model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
        for name, param in model.named_parameters():  # https://github.com/huggingface/peft/issues/1090
            if "default" in name:
                param.data = param.data.to(torch.float32)  # trainable params should be in fp32
        vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
        assert vhead_params is not None, "Reward model is not correctly loaded."
        model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
        model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
        model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
        model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
        logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
        return None  # the policy model itself acts as the reward model via the "reward" adapter
    else:
        reward_model_args_dict = model_args.to_dict()
        reward_model_args_dict.update(dict(
            model_name_or_path=finetuning_args.reward_model,
            checkpoint_dir=finetuning_args.reward_model_checkpoint,
            quantization_bit=finetuning_args.reward_model_quantization_bit
        ))
        reward_model_args = ModelArguments(**reward_model_args_dict)
        reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
        reward_model, _ = load_model_and_tokenizer(
            reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
        )
        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
        logger.warning("Please ensure the ppo model and reward model share the SAME tokenizer and vocabulary.")
        return reward_model