trainer.py
import os
import sys
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits

from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback
    from trl import AutoModelForCausalLMWithValueHead
    from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments


logger = get_logger(__name__)

class CustomPPOTrainer(PPOTrainer, Trainer):
    r"""
    Inherits PPOTrainer.
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
        callbacks: List["TrainerCallback"],
        reward_model: "AutoModelForCausalLMWithValueHead",
        **kwargs
    ):
        PPOTrainer.__init__(self, **kwargs)

        self.args = training_args
        self.model_args = model_args
        self.finetuning_args = finetuning_args
        self.reward_model = reward_model

        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            **generating_args.to_dict()
        )

        self.state = TrainerState()
        self.control = TrainerControl()
        self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
            self.accelerator.state, "deepspeed_plugin"
        )
        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)

        if self.args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        if finetuning_args.reward_model_type == "full":
            if self.is_deepspeed_enabled:
                if not (
                    getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
                    or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
                ): # quantized models are already set on the correct device
                    self.reward_model = self._prepare_deepspeed(self.reward_model)
            else:
                self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
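    # Illustration only (hypothetical values, not taken from this repository): if
    # generating_args.to_dict() returned {"do_sample": True, "temperature": 0.95, "top_p": 0.7,
    # "max_new_tokens": 512}, the GenerationConfig built above would sample with those settings
    # while padding with pad_token_id and stopping on eos_token_id or any additional special token.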

    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        r"""
        Implements training loop for the PPO stage, like _inner_training_loop() in Hugging Face's Trainer.
        """
        if resume_from_checkpoint is not None:
            raise ValueError("`resume_from_checkpoint` will be supported in the future version.")

        total_train_batch_size = (
            self.args.per_device_train_batch_size
            * self.args.gradient_accumulation_steps
            * self.finetuning_args.ppo_buffer_size
            * self.args.world_size
        )
        if self.args.max_steps > 0:
            num_examples = total_train_batch_size * self.args.max_steps
            num_train_epochs = sys.maxsize
            max_steps = self.args.max_steps
            steps_in_epoch = self.args.max_steps
        else:
            len_dataloader = len(self.dataloader)
            num_examples = len(self.dataset)
            num_train_epochs = self.args.num_train_epochs
            max_steps = math.ceil(num_train_epochs * len_dataloader)
            steps_in_epoch = len_dataloader

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        if self.is_world_process_zero():
            logger.info("***** Running training *****")
            logger.info("  Num examples = {}".format(num_examples))
            logger.info("  Num Epochs = {}".format(num_train_epochs))
            logger.info("  Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
            logger.info("  Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
                total_train_batch_size
            ))
            logger.info("  Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
            logger.info("  Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
            logger.info("  Total training steps = {}".format(max_steps))
            logger.info("  Number of trainable parameters = {}".format(count_parameters(self.model)[0]))

        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
        dataiter = iter(self.dataloader)
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()
        self.log_callback.on_train_begin(self.args, self.state, self.control)

        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
            try:
                batch = next(dataiter)
            except StopIteration:
                dataiter = iter(self.dataloader)
                batch = next(dataiter)

            # Cast to inference mode
            unwrapped_model.gradient_checkpointing_disable()
            unwrapped_model.config.use_cache = True
            self.model.eval()

            # Get inputs
            self.tokenizer.padding_side = "right" # change padding side
            queries, responses, rewards = [], [], []
            for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
                mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
                queries.extend(mini_batch_queries)
                responses.extend(mini_batch_responses)
                rewards.extend(mini_batch_rewards)

            # Cast to training mode
            unwrapped_model.gradient_checkpointing_enable()
            unwrapped_model.config.use_cache = False
            self.model.train()

            # Run PPO step
            stats = self.step(queries, responses, rewards)
            self.tokenizer.padding_side = "left" # restore padding side
            loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
            reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

            if self.config.log_with is not None:
                try:
                    batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
                    batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                    self.log_stats(stats, batch, rewards)
                except Exception:
                    logger.warning("Failed to save stats due to unknown errors.")

            self.state.global_step += 1
            self.log_callback.on_step_end(self.args, self.state, self.control)

            if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
                logs = dict(
                    loss=round(loss_meter.avg, 4),
                    reward=round(reward_meter.avg, 4),
                    learning_rate=stats["ppo/learning_rate"],
                    epoch=round(step / steps_in_epoch, 2)
                )
                tqdm.write(str(logs))
                logs["step"] = step
                self.state.log_history.append(logs)
                self.log_callback.on_log(self.args, self.state, self.control)
                loss_meter.reset()
                reward_meter.reset()

            if (step+1) % self.args.save_steps == 0: # save checkpoint
                self.save_model(os.path.join(
                    self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
                ))
                self.save_callback.on_save(
                    self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
                )

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        self.log_callback.on_train_end(self.args, self.state, self.control)
        self.save_callback.on_train_end(
            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
        )
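    # Worked example with made-up values (none of these come from a real config): with
    # per_device_train_batch_size=1, gradient_accumulation_steps=8, ppo_buffer_size=1 and
    # world_size=2, total_train_batch_size = 1 * 8 * 1 * 2 = 16. If the dataloader yields 500
    # batches per epoch and num_train_epochs=1, then max_steps = ceil(1 * 500) = 500 and the
    # `epoch` field logged above advances by 1 / 500 per PPO step.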

    @torch.no_grad()
    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        r"""
        Generates model's responses given queries.
        """
        if self.finetuning_args.upcast_layernorm:
            layernorm_params = dump_layernorm(self.model)

        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
        generate_output: torch.Tensor = unwrapped_model.generate(
            generation_config=self.generation_config,
            logits_processor=get_logits_processor(),
            **batch
        )

        if self.finetuning_args.upcast_layernorm:
            restore_layernorm(self.model, layernorm_params)

        query = batch["input_ids"].detach().cpu()
        response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
        queries, responses = [], []
        for i in range(len(query)):
            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()

            if len(response_index) == 0:
                response_length = 1 # allow empty response
            else:
                response_length = response_index[-1].item() + 1

            queries.append(query[i, query_length:]) # remove padding from left
            responses.append(response[i, :response_length]) # remove padding from right

        return queries, responses
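    # Illustration with toy token ids (pad_token_id assumed to be 0, values made up):
    #   query    = [0, 0, 5, 6] -> first non-pad index is 2, so the stored query is [5, 6]
    #   response = [7, 8, 0, 0] -> last non-pad index is 1, so the stored response is [7, 8]
    # An all-padding response keeps a single (pad) token, so downstream code never sees an
    # empty tensor.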

    @torch.no_grad()
    def get_rewards(
        self,
        queries: List[torch.Tensor],
        responses: List[torch.Tensor],
        unwrapped_model: "AutoModelForCausalLMWithValueHead"
    ) -> List[torch.Tensor]:
        r"""
        Computes scores using the given reward model.

        Both inputs and outputs are put on CPU.
        """
        if self.finetuning_args.reward_model_type == "api":
            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
            return get_rewards_from_server(self.reward_model, messages)

        if self.finetuning_args.reward_model_type == "lora":
            replace_model(unwrapped_model, target="reward")
            reward_model = self.model
        else:
            reward_model = self.reward_model

        batch = self.prepare_model_inputs(queries, responses)

        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)

        if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
            values = torch.transpose(values, 0, 1)

        rewards = []
        for i in range(values.size(0)):
            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
            end_index = end_indexes[-1].item() if len(end_indexes) else 0
            rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type

        if self.finetuning_args.reward_model_type == "lora":
            replace_model(unwrapped_model, target="default")

        return rewards
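    # Illustration with toy values (made up, pad_token_id assumed to be 0): for a right-padded
    # row input_ids = [3, 4, 7, 0, 0], the last non-pad position is index 2, so the reward is
    # the value-head output values[i, 2], cast to fp32 and moved to the CPU. If the whole row
    # were padding, end_index falls back to 0.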

    @PPODecorators.empty_device_cache()
    def batched_forward_pass(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        queries: torch.Tensor,
        responses: torch.Tensor,
        model_inputs: dict,
        return_logits: Optional[bool] = False,
        response_masks: Optional[torch.Tensor] = None
    ):
        r"""
        Calculates model outputs in multiple batches.

        Subclass and override to inject custom behavior.
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

            with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
                logits, _, values = model(**input_kwargs)

            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
            if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
                values = torch.transpose(values, 0, 1)

            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                start = len(query_batch[j]) - 1
                if attention_mask[j, 0] == 0: # offset left padding
                    start += attention_mask[j, :].nonzero()[0].item()
                end = start + len(response_batch[j])

                if response_masks is not None:
                    response_masks_batch = torch.cat(
                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                    )[1:]

                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks is not None:
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

            if return_logits:
                all_logits.append(logits)
            else:
                del logits

            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )
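    # Illustration of the mask arithmetic above with toy lengths (made up): for a query of 4
    # tokens and a response of 3 tokens with no left padding, start = 4 - 1 = 3 and
    # end = 3 + 3 = 6, so only positions 3, 4 and 5 of the shifted attention mask stay non-zero;
    # these are exactly the positions whose logprobs correspond to predicting the response tokens.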

    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves model checkpoint.

        Subclass and override to inject custom behavior.
        """
        if self.args.should_save:
            try:
                self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
            except ValueError:
                logger.warning(
                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                    " use zero_to_fp32.py to recover weights"
                )
                self._save(output_dir, state_dict={})
                for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: # remove dummy checkpoint
                    file = os.path.join(output_dir, filename)
                    if os.path.isfile(file):
                        os.remove(file)

                self.model.save_checkpoint(output_dir) # wrapped model
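
# The block below is not part of the original trainer. It is a minimal, self-contained sketch
# (runnable with PyTorch alone) of the padding-trimming and reward-indexing conventions used in
# get_inputs() and get_rewards() above; all token ids and the pad id are arbitrary toy values,
# and the random tensor merely stands in for the reward model's value head.
if __name__ == "__main__":
    pad_token_id = 0

    # A left-padded query and a right-padded response, as produced during generation.
    query = torch.tensor([0, 0, 5, 6])
    response = torch.tensor([7, 8, 0, 0])

    # Trim left padding from the query (same indexing as get_inputs).
    query_start = (query != pad_token_id).nonzero()[0].item()
    trimmed_query = query[query_start:]

    # Trim right padding from the response, keeping at least one token.
    response_index = (response != pad_token_id).nonzero()
    response_length = response_index[-1].item() + 1 if len(response_index) else 1
    trimmed_response = response[:response_length]

    # Take the value-head output at the last non-pad position, as get_rewards does.
    input_ids = torch.cat((trimmed_query, trimmed_response))
    values = torch.randn(input_ids.size(0))  # stand-in for per-token value-head outputs
    end_index = (input_ids != pad_token_id).nonzero()[-1].item()
    reward = values[end_index].float()

    print(trimmed_query.tolist(), trimmed_response.tolist(), round(reward.item(), 4))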