import os
import json
import torch
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from transformers import Trainer

from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers.trainer import PredictionOutput
    from transformers.modeling_utils import PreTrainedModel

logger = get_logger(__name__)


class PairwiseTrainer(Trainer):
    r"""
    Inherits Trainer to compute pairwise loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.can_return_loss = True  # override property to return eval_loss

    def compute_loss(
        self,
        model: "PreTrainedModel",
        inputs: Dict[str, torch.Tensor],
        return_outputs: Optional[bool] = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        Subclass and override to inject custom behavior.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
        """
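        # Worked example (illustrative): for each chosen/rejected pair, the loss
        # over the truncated reward span is -logsigmoid(r_chosen - r_rejected).mean().
        # E.g. r_chosen = [1.2, 1.5] and r_rejected = [0.2, 0.1] yield
        # -log(sigmoid(1.0)) and -log(sigmoid(1.4)), averaging to roughly 0.27.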
        # compute reward values
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
            values = torch.transpose(values, 0, 1)  # chatglm puts the sequence dimension first

        # split the inputs and rewards into two parts: chosen and rejected
        batch_size = inputs["input_ids"].size(0) // 2
        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
        chosen_scores, rejected_scores = [], []

        # compute pairwise loss, only backpropagating on the non-pad tokens
        # where the chosen and rejected sequences diverge
        loss = 0
        for i in range(batch_size):
            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
            rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()

            if len(check_divergence) == 0:
                # identical sequences: score only the last non-pad token
                end_index = chosen_length
                div_index = end_index - 1
            else:
                # score the span from the first differing token to the longer end
                end_index = max(chosen_length, rejected_length)
                div_index = check_divergence[0]

            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
            if return_outputs:  # use the reward on the last non-pad token as the sequence score
                chosen_scores.append(chosen_rewards[i, chosen_length - 1])
                rejected_scores.append(rejected_rewards[i, rejected_length - 1])
            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()

        loss = loss / batch_size
        if return_outputs:
            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
            return loss, [loss, chosen_scores, rejected_scores]

        return loss

    def save_predictions(
        self,
        predict_results: "PredictionOutput"
    ) -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
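        # The resulting file holds one JSON object per line, for example:
        #     {"chosen": 1.25, "rejected": -0.31}
        # (scores illustrative; actual values come from the reward head)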
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")
        chosen_scores, rejected_scores = predict_results.predictions

        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for c_score, r_score in zip(chosen_scores, rejected_scores):
                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
            writer.write("\n".join(res))
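

# ---------------------------------------------------------------------------
# Minimal sketch (illustrative, not part of the trainer): reproduces the
# truncated pairwise loss from PairwiseTrainer.compute_loss on hand-made
# tensors, assuming pad_token_id = 0. The helper name and all values below
# are made up for demonstration.
# ---------------------------------------------------------------------------
def _pairwise_loss_sketch() -> torch.Tensor:
    pad_token_id = 0
    # one chosen/rejected pair of 5 tokens each, right-padded with 0
    chosen_ids = torch.tensor([5, 6, 7, 8, 0])
    rejected_ids = torch.tensor([5, 6, 9, 0, 0])
    chosen_rewards = torch.tensor([0.2, 0.5, 0.9, 1.2, 0.0])
    rejected_rewards = torch.tensor([0.1, 0.2, 0.3, 0.0, 0.0])

    chosen_length = (chosen_ids != pad_token_id).nonzero()[-1] + 1      # 4
    rejected_length = (rejected_ids != pad_token_id).nonzero()[-1] + 1  # 3
    check_divergence = (chosen_ids != rejected_ids).nonzero()           # first diff at index 2

    end_index = max(chosen_length, rejected_length)                     # 4
    div_index = check_divergence[0]                                     # 2

    # only tokens from the divergence point up to the longer sequence end count
    diff = chosen_rewards[div_index:end_index] - rejected_rewards[div_index:end_index]
    return -torch.nn.functional.logsigmoid(diff).mean()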