import os
import json
import torch
import numpy as np
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from transformers import Seq2SeqTrainer

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers.trainer import PredictionOutput


logger = get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
    """
    def prediction_step(
        self,
        model: torch.nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        r"""
        Removes the prompt part in the generated tokens.

        Subclass and override to inject custom behavior.
        """
        labels = inputs["labels"].detach().clone() if "labels" in inputs else None  # back up labels before truncation
        if self.args.predict_with_generate:
            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
            if prompt_len > label_len:  # left-pad the labels to the prompt length
                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
            if label_len > prompt_len:  # truncate the labels instead of padding the inputs
                inputs["labels"] = inputs["labels"][:, :prompt_len]
        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
        )
        if generated_tokens is not None and self.args.predict_with_generate:
            generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id  # mask out the prompt part
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels
    def _pad_tensors_to_target_len(
        self,
        src_tensor: torch.Tensor,
        tgt_tensor: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Pads the tensor to the same length as the target tensor.
        """
        assert self.tokenizer.pad_token_id is not None, "Pad token is required."
        padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
        padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
        return padded_tensor.contiguous()  # in contiguous memory
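
    # Illustration (hypothetical values, not taken from the original code):
    # with pad_token_id = 0, src_tensor = [[5, 6]] (labels) and
    # tgt_tensor = [[1, 2, 3, 4]] (prompt), the method returns [[0, 0, 5, 6]],
    # i.e. the labels are left-padded to the prompt length so they line up
    # with the left-padded generation inputs.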
    def save_predictions(
        self,
        predict_results: "PredictionOutput"
    ) -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")
        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
        preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)

        for i in range(len(preds)):
            pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]  # indices of non-pad tokens
            if len(pad_len):
                preds[i] = np.concatenate((preds[i][pad_len[0]:], preds[i][:pad_len[0]]), axis=-1)  # move pad tokens to the end

        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for label, pred in zip(decoded_labels, decoded_preds):
                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
            writer.write("\n".join(res))
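

# Usage sketch (illustrative only, not part of the original module). It shows one
# plausible way to drive this trainer during a predict run; the `model`, `tokenizer`,
# `data_collator` and `eval_dataset` objects are assumed to exist, and the argument
# values are placeholders. Only `predict_with_generate=True` and left padding are
# requirements implied by the code above.
#
#   from transformers import Seq2SeqTrainingArguments
#
#   training_args = Seq2SeqTrainingArguments(
#       output_dir="outputs/predict",
#       per_device_eval_batch_size=4,
#       predict_with_generate=True,     # enable generation-based evaluation
#   )
#   tokenizer.padding_side = "left"     # prediction_step() asserts left padding
#   trainer = CustomSeq2SeqTrainer(
#       model=model,
#       args=training_args,
#       tokenizer=tokenizer,
#       data_collator=data_collator,
#   )
#   predict_results = trainer.predict(eval_dataset, metric_key_prefix="predict")
#   trainer.save_predictions(predict_results)  # writes generated_predictions.jsonl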