from typing import Callable, Optional

import torch
import tqdm
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .base import SLTrainer
from .strategies import Strategy
from .utils import is_rank_0

class RewardModelTrainer(SLTrainer):
    """
    Trainer to use while training a reward model.

    See the usage sketch at the end of this file for an example of wiring
    these arguments together.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        lr_scheduler (_LRScheduler): the lr scheduler to use for training
        loss_fn (callable): the loss function to use for training
        max_epochs (int, defaults to 1): the number of epochs to train
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        loss_fn: Callable,
        max_epochs: int = 1,
    ) -> None:
        super().__init__(strategy, max_epochs, model, optim)

        self.loss_fn = loss_fn
        self.scheduler = lr_scheduler

        self.num_train_step = 0

    def _eval(self, epoch):
        if self.eval_dataloader is not None:
            self.model.eval()
            dist, num_correct, num_samples = 0, 0, 0
            with torch.no_grad():
                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
                    # Each batch holds a (chosen, rejected) response pair; drop the
                    # singleton dim and move the tensors to the current CUDA device.
                    chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
                    c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
                    reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
                    r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
                    num_samples += chosen_ids.size(0)
                    # Accuracy: how often the chosen response out-scores the rejected one.
                    num_correct += (chosen_reward > reject_reward).sum().item()
                    # Dist: mean reward margin between chosen and rejected responses.
                    dist += (chosen_reward - reject_reward).mean().item()
            self.dist = dist / len(self.eval_dataloader)
            self.acc = num_correct / num_samples

            if self.writer:
                self.writer.add_scalar("eval/dist", self.dist, epoch)
                self.writer.add_scalar("eval/acc", self.acc, epoch)

    def _train(self, epoch):
        self.model.train()
        step_bar = tqdm.trange(
            len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
        )
        for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
            chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
            c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
            reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
            r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
            # Score both responses with the reward model and apply the pairwise
            # ranking loss (see the example loss at the end of this file).
            chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
            reject_reward = self.model(reject_ids, attention_mask=r_mask)
            loss = self.loss_fn(chosen_reward, reject_reward)
            self.strategy.backward(loss, self.model, self.optimizer)
            self.strategy.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()
            if self.writer:
                self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
                self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
                self.writer.add_scalar(
                    "train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step
                )
                self.writer.add_scalar(
                    "train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
                )
            self.num_train_step += 1
            # Step the lr scheduler once every 100 training steps.
            if self.num_train_step % 100 == 0:
                self.scheduler.step()
            step_bar.update()
        step_bar.close()

    def _before_fit(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            train_dataloader (DataLoader): the dataloader to use for training
            eval_dataloader (DataLoader): the dataloader to use for evaluation
            log_dir (str, optional): the directory to write TensorBoard logs to
            use_wandb (bool, defaults to False): whether to also log to Weights & Biases
        """
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            # sync_tensorboard mirrors everything written via SummaryWriter to wandb.
            wandb.init(project="Coati-rm", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "rm")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)
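

# ---------------------------------------------------------------------------
# Example: a pairwise ranking loss that could be passed as `loss_fn` above.
# This is a minimal sketch of the standard log-sigmoid (Bradley-Terry style)
# reward-modeling objective, not an import from this repository; the name
# `PairwiseLogSigLoss` is illustrative only.
# ---------------------------------------------------------------------------
import torch.nn.functional as F
from torch import nn


class PairwiseLogSigLoss(nn.Module):
    """Mean of -log(sigmoid(chosen_reward - reject_reward)) over the batch."""

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # logsigmoid is numerically stabler than log(sigmoid(...)) for large margins.
        return -F.logsigmoid(chosen_reward - reject_reward).mean()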
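

# ---------------------------------------------------------------------------
# Usage sketch, with assumptions flagged: this assumes the base
# `SLTrainer.fit` forwards its keyword arguments to `_before_fit` (as the
# `_before_fit` signature above suggests), and that `reward_model`,
# `strategy`, and the two dataloaders are built elsewhere; all of those
# names are placeholders, not part of this file.
# ---------------------------------------------------------------------------
# optim = torch.optim.Adam(reward_model.parameters(), lr=5e-6)
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=1000)
# trainer = RewardModelTrainer(
#     model=reward_model,
#     strategy=strategy,
#     optim=optim,
#     lr_scheduler=lr_scheduler,
#     loss_fn=PairwiseLogSigLoss(),
#     max_epochs=1,
# )
# trainer.fit(
#     train_dataloader=train_dataloader,
#     eval_dataloader=eval_dataloader,
#     log_dir="./logs",
#     use_wandb=False,
# )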