import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import ray
import torch
from coati.experience_buffer.utils import BufferItem
from coati.experience_maker import Experience
from torch.utils.data import DataLoader
from tqdm import tqdm

from .callbacks import TrainerCallback
from .detached_replay_buffer import DetachedReplayBuffer
from .utils import is_rank_0


class DetachedTrainer(ABC):
18"""
19Base class for detached rlhf trainers.
20'detach' means that the experience maker is detached compared to a normal Trainer.
21Please set name attribute during init:
22>>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote()
23So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name.
24Args:
25detached_strategy (DetachedStrategy): the strategy to use for training
26detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
27data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
28callbacks (List[Callback], defaults to []): the callbacks to call during training process
29generate_kwargs (dict, optional): the kwargs to use while model generating
30
31"""

    def __init__(
        self,
        experience_maker_holder_name_list: List[str],
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        dataloader_pin_memory: bool = True,
        callbacks: Optional[List[TrainerCallback]] = None,
        debug: bool = False,
    ) -> None:
        super().__init__()
        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
        self.dataloader_pin_memory = dataloader_pin_memory
        # avoid a shared mutable default argument for the callback list
        self.callbacks = callbacks if callbacks is not None else []
        self.target_holder_name_list = experience_maker_holder_name_list
        self.target_holder_list = []
        self._is_target_holder_initialized = False
        self._debug = debug

    def update_target_holder_list(self):
        # the target holder list may legitimately be empty, so track initialization
        # with a boolean flag instead of checking the list's length
        if not self._is_target_holder_initialized:
            for name in self.target_holder_name_list:
                self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
            self._is_target_holder_initialized = True
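
    # A minimal sketch of the wiring that update_target_holder_list relies on; the
    # names "trainer1", "maker1", and the subclass MyDetachedTrainer are illustrative
    # assumptions, not defined here. Both sides must be *named* Ray actors in the
    # namespace given by RAY_NAMESPACE so that each can look the other up:
    #
    #   trainer = MyDetachedTrainer.options(
    #       name="trainer1", namespace=os.environ["RAY_NAMESPACE"]
    #   ).remote(experience_maker_holder_name_list=["maker1"])
    #   # an ExperienceMakerHolder registered as "maker1" can then push experience
    #   # with ray.get_actor("trainer1").buffer_extend.remote(items)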

    @abstractmethod
    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
        pass

    def sync_models_to_remote_makers(self, **kwargs):
        # push the complete model state to every remote experience maker
        self._update_remote_makers(fully_update=True, **kwargs)

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        pass

    def _learn(self, update_steps: int, train_epochs: int) -> None:
        data = []
        # warmup epoch: pull batches from the replay buffer and cache them in `data`
        pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
        self._on_epoch_start(0)
        self._learn_epoch(pbar, data)
        self._on_epoch_end(0)
        # each cached item is already a batch, so use batch_size=1 and unwrap it in collate_fn
        dataloader = DataLoader(
            data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
        )
        # subsequent epochs replay the cached batches
        for epoch in range(1, train_epochs):
            pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
            self._on_epoch_start(epoch)
            self._learn_epoch(pbar, data)
            self._on_epoch_end(epoch)

    def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
        is_warmup = len(data) == 0
        for x in pbar:
            if self._debug:
                print("[trainer] training step")
            # during warmup, sample a fresh batch from the replay buffer;
            # afterwards, reuse the cached batch yielded by the dataloader
            experience = x if not is_warmup else self._buffer_sample()
            experience.to_device(torch.cuda.current_device())
            self._on_batch_start()
            metrics = self.training_step(experience)
            self._on_batch_end(metrics, experience)

            if self._debug:
                print("[trainer] step over")
            experience.to_device("cpu")
            if is_warmup:
                data.append(experience)
            pbar.set_postfix(metrics)

    def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
        self._on_fit_start()
        # each episode trains for update_steps, then pushes the new weights to the remote makers
        for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
            self._on_episode_start(i)
            self._learn(update_steps, train_epochs)
            self._on_update_start()
            self._update_remote_makers()
            self._on_update_end()
            self._on_episode_end(i)
        self._on_fit_end()
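
    # A hedged driver sketch: MyDetachedTrainer and the step counts below are
    # illustrative assumptions, not defined in this file. The trainer alternates
    # between `update_steps` optimization steps and a weight push to the makers:
    #
    #   trainer = MyDetachedTrainer.options(name="trainer1").remote(
    #       experience_maker_holder_name_list=["maker1"]
    #   )
    #   ray.get(trainer.fit.remote(total_steps=1000, update_steps=100, train_epochs=1))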

    @ray.method(concurrency_group="buffer_length")
    def buffer_get_length(self):
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer] telling length")
        return self.detached_replay_buffer.get_length()

    @ray.method(concurrency_group="buffer_append")
    def buffer_append(self, experience: Experience):
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer] receiving exp.")
        self.detached_replay_buffer.append(experience)

    @ray.method(concurrency_group="buffer_append")
    def buffer_extend(self, items: List[BufferItem]):
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer] receiving exp.")
        self.detached_replay_buffer.extend(items)

    @ray.method(concurrency_group="buffer_sample")
    def _buffer_sample(self):
        return self.detached_replay_buffer.sample()
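
    # The concurrency groups referenced by the @ray.method decorators above must be
    # declared on the concrete actor class. A minimal sketch, assuming a hypothetical
    # subclass MyDetachedTrainer and group sizes chosen purely for illustration:
    #
    #   @ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1})
    #   class MyDetachedTrainer(DetachedTrainer):
    #       ...
    #
    # Separate groups let makers append experience concurrently while the training
    # loop samples from the buffer.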

    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_epoch_start(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_start(epoch)

    def _on_epoch_end(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_end(epoch)

    def _on_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_start()

    def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_batch_end(metrics, experience)

    def _on_update_start(self) -> None:
        for callback in self.callbacks:
            callback.on_update_start()

    def _on_update_end(self) -> None:
        for callback in self.callbacks:
            callback.on_update_end()