aurora / trainer.py · 103 lines · 4.3 KB
import os
import json
import torch
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from transformers import Trainer

from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers.trainer import PredictionOutput
    from transformers.modeling_utils import PreTrainedModel


logger = get_logger(__name__)


class PairwiseTrainer(Trainer):
    r"""
    Inherits Trainer to compute pairwise loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.can_return_loss = True  # override property to return eval_loss

    def compute_loss(
        self,
        model: "PreTrainedModel",
        inputs: Dict[str, torch.Tensor],
        return_outputs: Optional[bool] = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        Subclass and override to inject custom behavior.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
        """
        # Compute rewards
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)

        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
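        # ChatGLM's value head is assumed to emit values as (seq_len, batch);
        # transpose to (batch, seq_len) so the per-example indexing below works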
        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
            values = torch.transpose(values, 0, 1)

        # Split the inputs and rewards into two parts, chosen and rejected
        batch_size = inputs["input_ids"].size(0) // 2
        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
        chosen_scores, rejected_scores = [], []

        # Compute pairwise loss. Only backprop on the diverging tokens before padding
        # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
        loss = 0
        for i in range(batch_size):
            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
            rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
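
            # If the pair is identical, fall back to scoring only the last
            # non-pad token; otherwise compare rewards from the first
            # diverging token up to the end of the longer sequence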
            if len(check_divergence) == 0:
                end_index = chosen_length
                div_index = end_index - 1
            else:
                end_index = max(chosen_length, rejected_length)
                div_index = check_divergence[0]

            assert div_index > 0
            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
            if return_outputs:  # use the score on the last non-pad token for inference
                chosen_scores.append(chosen_rewards[i, chosen_length - 1])
                rejected_scores.append(rejected_rewards[i, rejected_length - 1])
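            # Pairwise (Bradley-Terry-style) objective on the truncated span:
            # -log sigmoid(r_chosen - r_rejected), averaged over those tokens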
            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()

        loss = loss / batch_size
        if return_outputs:
            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
            return loss, [loss, chosen_scores, rejected_scores]

        return loss

    def save_predictions(
        self,
        predict_results: "PredictionOutput"
    ) -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")
        chosen_scores, rejected_scores = predict_results.predictions
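
        # Each output line is a JSON object with one rounded score pair,
        # e.g. {"chosen": 1.25, "rejected": -0.87}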
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for c_score, r_score in zip(chosen_scores, rejected_scores):
                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
            writer.write("\n".join(res))
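
For orientation, a minimal usage sketch (not part of the file): `model`, `tokenizer`, `pairwise_dataset`, and `pairwise_collator` are hypothetical placeholders for objects that llmtuner's own loaders normally build, where the model's forward returns a 3-tuple whose last element holds per-token values (matching `_, _, values = model(...)` above) and the collator stacks each chosen/rejected pair into one 2n-row batch.

# Hypothetical wiring; in llmtuner these objects come from its config loaders.
from transformers import TrainingArguments

trainer = PairwiseTrainer(
    model=model,                     # forward must yield a 3-tuple ending in per-token values
    args=TrainingArguments(output_dir="reward_model"),
    tokenizer=tokenizer,             # compute_loss reads tokenizer.pad_token_id
    train_dataset=pairwise_dataset,  # n chosen rows first, n rejected rows last per batch
    data_collator=pairwise_collator,
)
trainer.train()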
