callbacks.py
import os
import json
import time
from typing import TYPE_CHECKING
from datetime import timedelta

from transformers import TrainerCallback
from transformers.modeling_utils import custom_object_save, unwrap_model
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR

from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import TrainingArguments, TrainerState, TrainerControl
    from trl import AutoModelForCausalLMWithValueHead


logger = get_logger(__name__)


def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
    r"""
    Saves the config, generation config and backbone weights of a model wrapped with a value head.
    """
    model.pretrained_model.config.save_pretrained(output_dir)
    if model.pretrained_model.can_generate():
        model.pretrained_model.generation_config.save_pretrained(output_dir)
    if getattr(model, "is_peft_model", False):
        model.pretrained_model.save_pretrained(output_dir)
    elif getattr(model.pretrained_model, "_auto_class", None):  # must not be a peft model
        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
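
# Note: trl's AutoModelForCausalLMWithValueHead keeps the backbone under
# `pretrained_model`, so the trainer's default save logic would serialize the
# value-head wrapper instead of the backbone. A minimal call sketch (the
# checkpoint path below is illustrative):
#
#   _save_model_with_valuehead(unwrap_model(trainer.model), "outputs/checkpoint-500")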


class SavePeftModelCallback(TrainerCallback):

    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
        """
        if args.should_save:
            _save_model_with_valuehead(
                model=unwrap_model(kwargs.pop("model")),
                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
            )

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if args.should_save:
            _save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
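
# A minimal usage sketch, assuming a standard transformers.Trainer setup; the
# `trainer`, `model` and `training_args` names below are illustrative:
#
#   trainer = Trainer(
#       model=model,                          # an AutoModelForCausalLMWithValueHead
#       args=training_args,
#       train_dataset=train_dataset,
#       callbacks=[SavePeftModelCallback()],  # persists the backbone on every checkpoint
#   )
#   trainer.train()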


class LogCallback(TrainerCallback):

    def __init__(self, runner=None):
        self.runner = runner
        self.in_training = False
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""

    def timing(self):
        r"""
        Updates the elapsed time and the estimated remaining time from the current step count.
        """
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
        remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))
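
    # Worked example of the estimate above: 100 of 1000 steps done after 250 s gives
    # avg_time_per_step = 2.5 s, so remaining_time = 900 * 2.5 = 2250 s, which
    # timedelta formats as "0:37:30".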

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if state.is_local_process_zero:
            self.in_training = True
            self.start_time = time.time()
            self.max_steps = state.max_steps
            if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
                logger.warning("Previous log file in this folder will be deleted.")
                os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if state.is_local_process_zero:
            self.in_training = False
            self.cur_steps = 0
            self.max_steps = 0

    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a substep during gradient accumulation.
        """
        if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
        """
        if state.is_local_process_zero:
            self.cur_steps = state.global_step
            self.timing()
            if self.runner is not None and self.runner.aborted:
                control.should_epoch_stop = True
                control.should_training_stop = True

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.
        """
        if state.is_local_process_zero and not self.in_training:
            self.cur_steps = 0
            self.max_steps = 0

    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
        r"""
        Event called after a successful prediction.
        """
        if state.is_local_process_zero and not self.in_training:
            self.cur_steps = 0
            self.max_steps = 0
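
    # The `runner` hook lets an external driver (e.g. the llmtuner web UI) request a
    # graceful stop: once `runner.aborted` is set, the next (sub)step flips the
    # TrainerControl flags above instead of killing the process.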

    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
        r"""
        Event called after logging the last logs.
        """
        if not state.is_local_process_zero:
            return

        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            reward=state.log_history[-1].get("reward", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time
        )
        if self.runner is not None:
            logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
                logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
            ))

        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")
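
    # Each call appends one JSON object per line, so a consumer can tail the file
    # incrementally; presumably the hard-coded name matches the LOG_FILE_NAME that
    # on_train_begin removes. A minimal reader sketch:
    #
    #   with open(os.path.join(output_dir, "trainer_log.jsonl"), encoding="utf-8") as f:
    #       records = [json.loads(line) for line in f]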

    def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a prediction step.
        """
        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
            if self.max_steps == 0:
                self.max_steps = len(eval_dataloader)
            self.cur_steps += 1
            self.timing()
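
    # Outside training, the step counters are repurposed to track progress over the
    # eval dataloader, so timing() yields elapsed/remaining estimates for standalone
    # evaluation and prediction runs as well.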