colossalai

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import masked_mean


class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss
    """

    def __init__(self):
        super().__init__()
        # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Shift so that the logits at position i predict the token at position i + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
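

# Usage sketch (illustrative; not part of the upstream file). GPTLMLoss expects
# logits of shape (batch, seq_len, vocab_size) and integer labels of shape
# (batch, seq_len); the sizes below are assumptions for demonstration.
def _gpt_lm_loss_example() -> torch.Tensor:
    batch_size, seq_len, vocab_size = 2, 8, 100
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    # Returns a scalar cross entropy over all shifted token positions.
    return GPTLMLoss()(logits, labels)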


class PolicyLoss(nn.Module):
    """
    Clipped surrogate policy loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Importance-sampling ratio pi_new / pi_old, computed in log space for stability.
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        # Pessimistic (clipped) surrogate objective, negated for gradient descent.
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            # masked_mean averages over the action dimension, ignoring masked-out positions.
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return loss
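

# Usage sketch (illustrative; not part of the upstream file). In a PPO update,
# log_probs come from the current policy, old_log_probs from the rollout
# policy, and advantages from an estimator such as GAE; the (batch, num_actions)
# shapes below are assumptions.
def _policy_loss_example() -> torch.Tensor:
    log_probs = torch.randn(4, 16)
    old_log_probs = log_probs.detach() + 0.01 * torch.randn(4, 16)
    advantages = torch.randn(4, 16)
    action_mask = torch.ones(4, 16)
    return PolicyLoss(clip_eps=0.2)(log_probs, old_log_probs, advantages, action_mask)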


class ValueLoss(nn.Module):
    """
    Clipped value-function loss for PPO
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(
        self,
        values: torch.Tensor,
        old_values: torch.Tensor,
        reward: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE: action_mask is accepted for interface symmetry with PolicyLoss
        # but is not used in this implementation.
        # Clip the new value estimate to stay within clip_eps of the old one.
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward) ** 2
        surr2 = (values - reward) ** 2
        # Take the pessimistic (larger) of the clipped and unclipped squared errors.
        loss = torch.max(surr1, surr2)
        loss = loss.mean()
        return 0.5 * loss
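

# Usage sketch (illustrative; not part of the upstream file). Here `reward` is
# the regression target for the critic (e.g. a return estimate); values and
# old_values are assumed to be per-sequence scalars of shape (batch,).
def _value_loss_example() -> torch.Tensor:
    old_values = torch.randn(4)
    values = old_values + 0.1 * torch.randn(4)
    reward = torch.randn(4)
    return ValueLoss(clip_eps=0.4)(values, old_values, reward)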


class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # -log(sigmoid(r_chosen - r_rejected)); F.logsigmoid avoids the underflow
        # that torch.log(torch.sigmoid(x)) hits for large negative margins.
        return -F.logsigmoid(chosen_reward - reject_reward).mean()


class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # log(1 + exp(r_rejected - r_chosen)); F.softplus computes this without
        # overflowing for large positive margins.
        return F.softplus(reject_reward - chosen_reward).mean()
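

# Usage sketch (illustrative; not part of the upstream file). Both pairwise
# losses take reward-model scores for the preferred (chosen) and rejected
# responses, assumed here to have shape (batch,). The two losses are
# mathematically identical, since log(1 + exp(-x)) == -logsigmoid(x).
def _pairwise_loss_example() -> torch.Tensor:
    chosen_reward = torch.randn(8)
    reject_reward = torch.randn(8)
    loss_sig = LogSigLoss()(chosen_reward, reject_reward)
    loss_exp = LogExpLoss()(chosen_reward, reject_reward)
    assert torch.allclose(loss_sig, loss_exp)
    return loss_sig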