from typing import List, Optional

from transformers import Seq2SeqTrainingArguments, TrainerCallback

from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import DataArguments, FinetuningArguments, ModelArguments
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
from llmtuner.train.dpo.trainer import CustomDPOTrainer
from llmtuner.train.utils import create_modelcard_and_push, create_ref_model
def run_dpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[List["TrainerCallback"]] = None
):
    """Run the full DPO (Direct Preference Optimization) workflow.

    Loads the pairwise preference dataset and the model/tokenizer, builds the
    DPO data collator and (optionally) a frozen reference model, trains and/or
    evaluates with ``CustomDPOTrainer``, and finally pushes a model card.

    Args:
        model_args: model loading options (checkpoint path, quantization, ...).
        data_args: dataset selection and preprocessing options.
        training_args: HF ``Seq2SeqTrainingArguments`` controlling the run.
        finetuning_args: fine-tuning options (``ref_model``, ``dpo_beta``,
            ``plot_loss``, ...).
        callbacks: optional list of HF ``TrainerCallback`` instances.
    """
    dataset = get_dataset(model_args, data_args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
    # DPO consumes pairwise (chosen/rejected) examples, so reuse the "rm" preprocessing stage.
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = DPODataCollatorWithPadding(
        tokenizer=tokenizer,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )

    # Create the reference model. At inference time (no training) the policy
    # model itself can serve as its own reference; otherwise build a frozen copy.
    if finetuning_args.ref_model is None and (not training_args.do_train):  # use the model itself
        ref_model = model
    else:
        ref_model = create_ref_model(model_args, finetuning_args)

    # Update arguments: the pairwise columns are not model inputs, so the
    # Trainer must not strip "unused" columns from the dataset.
    training_args_dict = training_args.to_dict()
    training_args_dict.update(dict(remove_unused_columns=False))  # important for pairwise dataset
    training_args = Seq2SeqTrainingArguments(**training_args_dict)

    # Initialize our Trainer
    trainer = CustomDPOTrainer(
        beta=finetuning_args.dpo_beta,
        model=model,
        ref_model=ref_model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **split_dataset(dataset, data_args, training_args)
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if id(model) == id(ref_model):  # unable to compute rewards when ref model == policy model
            remove_keys = [key for key in metrics.keys() if "rewards" in key]
            for key in remove_keys:
                metrics.pop(key)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)