from typing import Callable, Optional

import torch
import tqdm
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .base import SLTrainer
from .strategies import Strategy
from .utils import is_rank_0


class RewardModelTrainer(SLTrainer):
    """
    Trainer to use while training a reward model.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        lr_scheduler (_LRScheduler): the lr scheduler to use for training
        loss_fn (Callable): the loss function to use for training
        max_epochs (int, defaults to 1): the number of epochs to train
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        loss_fn: Callable,
        max_epochs: int = 1,
    ) -> None:
        super().__init__(strategy, max_epochs, model, optim)

        self.loss_fn = loss_fn
        self.scheduler = lr_scheduler

        self.num_train_step = 0

    def _eval(self, epoch):
        if self.eval_dataloader is not None:
            self.model.eval()
            dist, num_correct, num_samples = 0, 0, 0
            with torch.no_grad():
                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
                    chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
                    c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
                    reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
                    r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
                    num_samples += chosen_ids.size(0)
                    num_correct += (chosen_reward > reject_reward).sum().item()
                    dist += (chosen_reward - reject_reward).mean().item()
                # average reward margin per batch and pairwise ranking accuracy
                self.dist = dist / len(self.eval_dataloader)
                self.acc = num_correct / num_samples

            if self.writer:
                self.writer.add_scalar("eval/dist", self.dist, epoch)
                self.writer.add_scalar("eval/acc", self.acc, epoch)

    def _train(self, epoch):
        self.model.train()
        step_bar = tqdm.trange(
            len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
        )
        for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
            chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
            c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
            reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
            r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
            chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
            reject_reward = self.model(reject_ids, attention_mask=r_mask)
            # pairwise ranking loss over the (chosen, rejected) reward pair
            loss = self.loss_fn(chosen_reward, reject_reward)
            self.strategy.backward(loss, self.model, self.optimizer)
            self.strategy.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()
            if self.writer:
                self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
                self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
                self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
                self.writer.add_scalar(
                    "train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
                )
            self.num_train_step += 1
            # step the lr scheduler once every 100 optimizer steps
            if self.num_train_step % 100 == 0:
                self.scheduler.step()
            step_bar.update()
        step_bar.close()

    def _before_fit(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            train_dataloader (DataLoader): the dataloader to use for training
            eval_dataloader (DataLoader): the dataloader to use for evaluation
            log_dir (Optional[str]): the directory to write TensorBoard logs to
            use_wandb (bool): whether to mirror the TensorBoard logs to Weights & Biases
        """
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            wandb.init(project="Coati-rm", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "rm")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)
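

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the trainer above
# only assumes that ``loss_fn`` is a callable applied as
# ``loss_fn(chosen_reward, reject_reward)``. A common choice for reward
# modelling is a pairwise log-sigmoid ranking loss; the function below shows
# that shape. The name ``pairwise_logsigmoid_loss`` is hypothetical, relies on
# the ``torch`` import at the top of this module, and the loss actually passed
# in by callers may differ.
# ---------------------------------------------------------------------------
def pairwise_logsigmoid_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    """Example ``loss_fn``: push chosen rewards above rejected rewards."""
    # -log(sigmoid(r_chosen - r_rejected)), averaged over the batch
    return -torch.nn.functional.logsigmoid(chosen_reward - reject_reward).mean()


# Hypothetical wiring sketch (the exact ``fit`` entry point is defined by
# ``SLTrainer`` in ``.base`` and is assumed here, not verified):
#
#     trainer = RewardModelTrainer(model, strategy, optim, lr_scheduler,
#                                  loss_fn=pairwise_logsigmoid_loss, max_epochs=1)
#     trainer.fit(train_dataloader=train_loader, eval_dataloader=eval_loader)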