from typing import Dict, List, Optional

from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from colossalai.accelerator import get_accelerator

from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import GeminiStrategy, Strategy
from .utils import CycledDataLoader, is_rank_0, to_device

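# NOTE: generation inside the experience maker is driven through `generate_kwargs`; the helper
# below fills in the HuggingFace hooks (`prepare_inputs_for_generation` /
# `_update_model_kwargs_for_generation`) from the unwrapped base model when the caller did not
# provide them.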
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
    unwrapped_model = strategy.unwrap_model(actor)
    hf_model = get_base_model(unwrapped_model)
    new_kwargs = {**generate_kwargs}
    # use huggingface models method directly
    if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation

    if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation

    return new_kwargs

class PPOTrainer(OnPolicyTrainer):
    """
    Trainer for the PPO algorithm.

    Args:
        strategy (Strategy): the strategy to use for training
        actor (Actor): the actor model in the ppo algorithm
        critic (Critic): the critic model in the ppo algorithm
        reward_model (RewardModel): the reward model in the rlhf algorithm, used to score generated sequences
        initial_model (Actor): the initial model in the rlhf algorithm, used to produce reference logits that limit the update of the actor
        actor_optim (Optimizer): the optimizer to use for the actor model
        critic_optim (Optimizer): the optimizer to use for the critic model
        tokenizer (PreTrainedTokenizerBase): the tokenizer passed to the experience maker
        kl_coef (float, defaults to 0.1): the coefficient of the kl divergence loss
        ptx_coef (float, defaults to 0.9): the coefficient of the ptx (pretraining) loss
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limit of the experience buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload the experience buffer to cpu
        eps_clip (float, defaults to 0.2): the clip coefficient of the policy loss
        vf_coef (float, defaults to 1.0): the coefficient of the value loss
        value_clip (float, defaults to 0.4): the clip coefficient of the value loss
        sample_buffer (bool, defaults to False): whether to sample experiences directly from the buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        offload_inference_models (bool, defaults to True): whether to offload the inference models to cpu during training
        callbacks (List[Callback], defaults to []): the callbacks to call during training
        generate_kwargs (dict, optional): the kwargs passed to the actor model when generating
    """

    def __init__(
        self,
        strategy: Strategy,
        actor: Actor,
        critic: Critic,
        reward_model: RewardModel,
        initial_model: Actor,
        actor_optim: Optimizer,
        critic_optim: Optimizer,
        tokenizer: PreTrainedTokenizerBase,
        kl_coef: float = 0.1,
        ptx_coef: float = 0.9,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        buffer_cpu_offload: bool = True,
        eps_clip: float = 0.2,
        vf_coef: float = 1.0,
        value_clip: float = 0.4,
        sample_buffer: bool = False,
        dataloader_pin_memory: bool = True,
        offload_inference_models: bool = True,
        callbacks: List[Callback] = [],
        **generate_kwargs,
    ) -> None:
        if isinstance(strategy, GeminiStrategy):
            assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"

        data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)

        self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)

        self.actor = actor
        self.critic = critic
        self.tokenizer = tokenizer

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.vf_coef = vf_coef
        self.ptx_loss_fn = GPTLMLoss()
        self.ptx_coef = ptx_coef
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim

        self.offload_inference_models = offload_inference_models
        self.device = get_accelerator().get_current_device()

    def _before_fit(
        self,
        prompt_dataloader: DataLoader,
        pretrain_dataloader: DataLoader,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            prompt_dataloader (DataLoader): the dataloader to use for prompt data
            pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
            log_dir (str, optional): the directory to write tensorboard logs to
            use_wandb (bool, defaults to False): whether to log to wandb (requires log_dir)
        """
        self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
        self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)

        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            wandb.init(project="Coati-ppo", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "ppo")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)

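    # NOTE: _make_experience and _learn are the two hooks driven by the OnPolicyTrainer training
    # loop: the first collects rollouts, the second runs PPO updates on the collected buffer.
    # Collection pulls one batch of prompts, moves the frozen initial/reward models onto the
    # accelerator if they were offloaded, and delegates generation and scoring to the
    # experience maker.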
    def _make_experience(self, collect_step: int) -> Experience:
        prompts = self.prompt_dataloader.next()
        if self.offload_inference_models:
            # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
            self.experience_maker.initial_model.to(self.device)
            self.experience_maker.reward_model.to(self.device)
        assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
        return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)

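    # NOTE: a single training step below combines three objectives:
    #   1. the PPO clipped-surrogate policy loss (PolicyLoss, clip range eps_clip),
    #   2. an optional language-modeling loss on pretrain data (ptx loss), and
    #   3. a clipped value-function loss for the critic (ValueLoss, scaled by vf_coef).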
    def _training_step(self, experience: Experience):
        self.actor.train()
        self.critic.train()
        # policy loss
        num_actions = experience.action_log_probs.size(1)
        actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
        action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
        actor_loss = self.actor_loss_fn(
            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
        )
        actor_loss = (1 - self.ptx_coef) * actor_loss
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)

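        # The ptx term mixes a plain language-modeling loss on pretraining data into the actor
        # update, weighted by ptx_coef (the policy loss above is scaled by 1 - ptx_coef
        # accordingly); this follows the PPO-ptx style objective used to limit regression on the
        # pretraining distribution.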
        # ptx loss
        if self.ptx_coef != 0:
            batch = self.pretrain_dataloader.next()
            batch = to_device(batch, self.device)
            ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
            ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
            self.strategy.backward(ptx_loss, self.actor, self.actor_optim)

        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

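        # The critic re-scores the full sequences; ValueLoss compares the new values with the
        # stored values and rewards from the experience (clipping the update by value_clip),
        # and the resulting loss is scaled by vf_coef before the critic optimizer step.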
        # value loss
        values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
        critic_loss = critic_loss * self.vf_coef
        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()

    def _learn(self, update_step: int):
        if self.offload_inference_models:
            self.experience_maker.initial_model.to("cpu")
            self.experience_maker.reward_model.to("cpu")

        # buffer may be empty at first, we should rebuild at each training
        if self.sample_buffer:
            experience = self.data_buffer.sample()
            self._on_learn_batch_start()
            experience.to_device(self.device)
            self._training_step(experience)
            self._on_learn_batch_end(experience)
        else:
            if isinstance(self.dataloader.sampler, DistributedSampler):
                self.dataloader.sampler.set_epoch(update_step)
            pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(self.device)
                self._training_step(experience)
                self._on_learn_batch_end(experience)
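
# ----------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): how the pieces above are typically wired together.
# Model/strategy construction and the exact arguments of the inherited `fit()` method depend on
# the surrounding coati version, so the lines below are an assumption, not a drop-in example.
#
#   strategy = ...                                  # a Strategy instance (e.g. DDP or Gemini)
#   actor, critic = ..., ...                        # Actor / Critic wrappers around a causal LM
#   reward_model, initial_model = ..., ...          # frozen reward model and reference policy
#   actor_optim, critic_optim = ..., ...            # optimizers prepared by the strategy
#   tokenizer = ...                                 # a PreTrainedTokenizerBase
#
#   trainer = PPOTrainer(
#       strategy, actor, critic, reward_model, initial_model,
#       actor_optim, critic_optim, tokenizer,
#       kl_coef=0.1, ptx_coef=0.9, train_batch_size=8,
#       max_length=512, do_sample=True,             # extra kwargs are forwarded to generation
#   )
#   trainer.fit(prompt_dataloader=..., pretrain_dataloader=..., ...)  # fit() comes from OnPolicyTrainer
# ----------------------------------------------------------------------------------------------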