import os
import json
import time
from typing import TYPE_CHECKING
from datetime import timedelta

from transformers import TrainerCallback
from transformers.modeling_utils import custom_object_save, unwrap_model
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR

from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import TrainingArguments, TrainerState, TrainerControl
    from trl import AutoModelForCausalLMWithValueHead
19
logger = get_logger(__name__)
22
def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
23
model.pretrained_model.config.save_pretrained(output_dir)
24
if model.pretrained_model.can_generate():
25
model.pretrained_model.generation_config.save_pretrained(output_dir)
26
if getattr(model, "is_peft_model", False):
27
model.pretrained_model.save_pretrained(output_dir)
28
elif getattr(model.pretrained_model, "_auto_class", None):
29
custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
32
class SavePeftModelCallback(TrainerCallback):
34
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
36
Event called after a checkpoint save.
39
_save_model_with_valuehead(
40
model=unwrap_model(kwargs.pop("model")),
41
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
44
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
46
Event called at the end of training.
49
_save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
52
class LogCallback(TrainerCallback):
54
def __init__(self, runner=None):
56
self.in_training = False
57
self.start_time = time.time()
60
self.elapsed_time = ""
61
self.remaining_time = ""
64
cur_time = time.time()
65
elapsed_time = cur_time - self.start_time
66
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
67
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
68
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
69
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
71
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
73
Event called at the beginning of training.
75
if state.is_local_process_zero:
76
self.in_training = True
77
self.start_time = time.time()
78
self.max_steps = state.max_steps
79
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
80
logger.warning("Previous log file in this folder will be deleted.")
81
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
83
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
85
Event called at the end of training.
87
if state.is_local_process_zero:
88
self.in_training = False
92
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
94
Event called at the end of an substep during gradient accumulation.
96
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
97
control.should_epoch_stop = True
98
control.should_training_stop = True
100
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
102
Event called at the end of a training step.
104
if state.is_local_process_zero:
105
self.cur_steps = state.global_step
107
if self.runner is not None and self.runner.aborted:
108
control.should_epoch_stop = True
109
control.should_training_stop = True
111
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
113
Event called after an evaluation phase.
115
if state.is_local_process_zero and not self.in_training:
119
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
121
Event called after a successful prediction.
123
if state.is_local_process_zero and not self.in_training:
127
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
129
Event called after logging the last logs.
131
if not state.is_local_process_zero:
135
current_steps=self.cur_steps,
136
total_steps=self.max_steps,
137
loss=state.log_history[-1].get("loss", None),
138
eval_loss=state.log_history[-1].get("eval_loss", None),
139
predict_loss=state.log_history[-1].get("predict_loss", None),
140
reward=state.log_history[-1].get("reward", None),
141
learning_rate=state.log_history[-1].get("learning_rate", None),
142
epoch=state.log_history[-1].get("epoch", None),
143
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
144
elapsed_time=self.elapsed_time,
145
remaining_time=self.remaining_time
147
if self.runner is not None:
148
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
149
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
152
os.makedirs(args.output_dir, exist_ok=True)
153
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
154
f.write(json.dumps(logs) + "\n")
156
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
158
Event called after a prediction step.
160
eval_dataloader = kwargs.pop("eval_dataloader", None)
161
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
162
if self.max_steps == 0:
163
self.max_steps = len(eval_dataloader)