from typing import Optional

import torch
import torch.nn as nn

from .utils import masked_mean


class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss
    """

    def __init__(self):
        super().__init__()
        # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
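
# Usage sketch for GPTLMLoss (illustrative only; `model` and the tensor shapes
# below are assumptions, not part of this module):
#   loss_fn = GPTLMLoss()
#   logits = model(input_ids).logits      # (batch, seq_len, vocab_size)
#   loss = loss_fn(logits, input_ids)     # forward() shifts logits/labels by one token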


class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return loss
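
# PolicyLoss is PPO's clipped surrogate objective (https://arxiv.org/abs/1707.06347):
#   L = -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)], with r = exp(log_probs - old_log_probs).
# Usage sketch (tensor shapes are illustrative assumptions, not from this file):
#   actor_loss_fn = PolicyLoss(clip_eps=0.2)
#   loss = actor_loss_fn(log_probs, old_log_probs, advantages, action_mask)
#   # log_probs, old_log_probs, advantages, action_mask: all (batch, num_actions)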


class ValueLoss(nn.Module):
    """
    Value Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(
        self,
        values: torch.Tensor,
        old_values: torch.Tensor,
        reward: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward) ** 2
        surr2 = (values - reward) ** 2
        loss = torch.max(surr1, surr2)
        # Mask out padded action positions before averaging, mirroring PolicyLoss.
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return 0.5 * loss
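
# ValueLoss is the clipped value-function loss used in many PPO implementations:
# the new value estimate is clipped to within clip_eps of the old one, and the
# worse (larger) of the clipped/unclipped squared errors is averaged and halved.
# Usage sketch (illustrative):
#   critic_loss_fn = ValueLoss(clip_eps=0.4)
#   loss = critic_loss_fn(values, old_values, reward, action_mask)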


class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(chosen_reward - reject_reward)
        log_probs = torch.log(probs)
        loss = -log_probs.mean()
        return loss
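
# LogSigLoss is the InstructGPT-style ranking loss: -log(sigmoid(r_chosen - r_reject)).
# A mathematically equivalent, numerically more stable formulation (an alternative,
# not what forward() above computes literally) would be:
#   loss = -torch.nn.functional.logsigmoid(chosen_reward - reject_reward).mean()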


class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
        return loss
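
# LogExpLoss computes log(1 + exp(-(r_chosen - r_reject))), i.e. softplus of the
# negated reward margin, which equals LogSigLoss since log(1 + exp(-x)) == -log(sigmoid(x)).
# An equivalent, numerically safer alternative (an assumption, not the original code):
#   loss = torch.nn.functional.softplus(reject_reward - chosen_reward).mean()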