# aurora
1import torch
2from dataclasses import dataclass
3from typing import Any, Dict, Sequence
4from transformers import DataCollatorWithPadding
5
6
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
    r"""
    Collator that flattens pairwise preference data into a single padded batch.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Builds 2 * n examples from n pairwise features and pads them to the
        longest sequence in the batch.

        The first n entries of the batch are the chosen responses and the last
        n entries are the rejected responses, each prefixed with its prompt.
        """
        flattened: list = []
        # Emit all chosen examples first, then all rejected ones, preserving
        # the original feature order within each group.
        for response_key in ("chosen_ids", "rejected_ids"):
            for feature in features:
                token_ids = feature["prompt_ids"] + feature[response_key]
                flattened.append({
                    "input_ids": token_ids,
                    # Every token is real (no padding yet), so the mask is all ones;
                    # padding is handled by the parent collator below.
                    "attention_mask": [1] * len(token_ids),
                })
        # Delegate tokenizer-aware padding/tensorization to the base collator.
        return super().__call__(flattened)
28