import os
import time
import tracemalloc
from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import ray
import torch
from coati.experience_buffer.utils import split_experience_batch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.strategies import Strategy
from torch import Tensor
from tqdm import tqdm

from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to


@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
    """
    Args:
        detached_trainer_name_list: str list used to get ray actor handles of the detached trainers
        strategy_fn: a function that returns a Strategy
        model_fn: a function that returns (actor, critic, reward_model, initial_model)
        kl_coef: the coefficient of the KL divergence loss
        sync_models_from_trainers: whether to sync models from trainers. If True, you must call
            sync_models_to_remote_makers() in trainers to sync models.
    """

    def __init__(
        self,
        detached_trainer_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        # a function that returns (actor, critic, reward_model, initial_model)
        model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
        env_info: Optional[Dict[str, str]] = None,
        sync_models_from_trainers: bool = False,
        buffer_cpu_offload: bool = True,
        kl_coef: float = 0.1,
        callbacks: List[MakerCallback] = [],
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
        **generate_kwargs,
    ):
        # set environment variables
        if env_info:
            set_dist_env(env_info=env_info)
        self.target_trainer_list = []
        assert len(detached_trainer_name_list) > 0
        self._detached_trainer_name_list = detached_trainer_name_list
        self.strategy = strategy_fn()
        self.buffer_cpu_offload = buffer_cpu_offload
        self.kl_coef = kl_coef
        # init models
        with self.strategy.model_init_context():
            actor, critic, reward_model, initial_model = model_fn()
        self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor)
        if eval_performance:
            actor_numel = get_model_numel(actor)
            critic_numel = get_model_numel(critic)
            initial_model_numel = get_model_numel(initial_model)
            reward_model_numel = get_model_numel(reward_model)
            evaluator = ExperienceMakerPerformanceEvaluator(
                actor_numel, critic_numel, initial_model_numel, reward_model_numel
            )
            callbacks = callbacks + [evaluator]

        actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
        self.callbacks = callbacks

        self._model_visit_lock = Lock()

        self._is_fully_initialized = not sync_models_from_trainers

        self._debug = debug
        self._update_lora_weights = update_lora_weights
        if self._update_lora_weights:
            self.actor_lora_constructor = LoRAConstructor()
            self.critic_lora_constructor = LoRAConstructor()

        self.target_auto_balance = False

        self._target_idx = 0

        if self._debug:
            print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
            if not self._is_fully_initialized:
                print(f"[maker{get_rank()}] Waiting for INIT")

    def _get_ready(self):
        while not self._fully_initialized():
            time.sleep(1.0)

    def _fully_initialized(self):
        return self._is_fully_initialized

    def _init_target_trainer_list(self):
        if len(self.target_trainer_list) > 0:
            return
        for name in self._detached_trainer_name_list:
            self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))

    # copy from ../trainer/base.py
    @ray.method(concurrency_group="compute")
    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

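    # Note: `inputs` may be a bare tensor of prompt token ids or a dict of keyword
    # arguments unpacked into the experience maker, e.g. (a sketch; the dict keys
    # must match the parameters of NaiveExperienceMaker.make_experience):
    #
    #   self._make_experience(input_ids)
    #   self._make_experience({"input_ids": input_ids})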
    @ray.method(concurrency_group="experience_io")
    def _send_items(self, experience: Experience) -> None:
        self._init_target_trainer_list()
        items = split_experience_batch(experience)
        items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
        # distribute the items round-robin across the target trainers
        for item in items:
            items_per_trainer[self._target_idx].append(item)
            self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
        for i, target_trainer in enumerate(self.target_trainer_list):
            if len(items_per_trainer[i]) > 0:
                target_trainer.buffer_extend.remote(items_per_trainer[i])

    def _inference_step(self, batch) -> None:
        self._on_batch_start()
        # the lock only guards model access; sending happens outside it so that
        # model updates are not blocked by experience transfer
        with self._model_visit_lock:
            self._on_make_experience_start()
            experience = self._make_experience(batch)
            self._on_make_experience_end(experience)
        self._on_send_start()
        if self.buffer_cpu_offload:
            experience.to_device("cpu")
        self._send_items(experience)
        self._on_send_end()
        self._on_batch_end()

    def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0):
        """Working loop of the experience maker.

        Args:
            dataloader_fn (Callable[[], Iterable]): A function that returns a dataloader.
            num_epochs (int, optional): Iterate over the dataloader for this number of epochs. Defaults to 1.
            num_steps (int, optional): Iterate over the dataloader for this number of steps. If this value > 0, num_epochs is ignored. Defaults to 0.
        """
        self._get_ready()
        self._on_loop_start()
        dataloader = dataloader_fn()
        if num_steps > 0:
            # ignore num_epochs
            it = iter(dataloader)
            for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
                try:
                    batch = next(it)
                except StopIteration:
                    it = iter(dataloader)
                    batch = next(it)
                self._inference_step(batch)
        else:
            with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
                for _ in range(num_epochs):
                    for batch in dataloader:
                        self._inference_step(batch)
                        pbar.update()
        self._on_loop_end()
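
    # Sketch of a `dataloader_fn` a driver might pass in (assumes a torch
    # DataLoader over pre-tokenized prompts; names are illustrative):
    #
    #   def dataloader_fn():
    #       return DataLoader(prompt_dataset, batch_size=8, shuffle=True)
    #
    #   maker.workingloop.remote(dataloader_fn, num_steps=1000)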

    @ray.method(concurrency_group="model_io")
    def update_experience_maker(
        self,
        new_actor_state_dict: Optional[Dict[str, Any]] = None,
        new_actor_lora_config_dict: Optional[Dict[str, Any]] = None,
        new_critic_state_dict: Optional[Dict[str, Any]] = None,
        new_critic_lora_config_dict: Optional[Dict[str, Any]] = None,
        fully_update: bool = False,
        chunk_start: bool = False,
        chunk_end: bool = False,
    ):
        """
        Called by the trainer.
        chunk_start: set True on the first call, before any state_dict chunks are sent.
        chunk_end: set True on the last call, after all state_dict chunks have been sent.
        fully_update: set True to sync full models during initialization.

        TODO: integrate load_state_dict with the model-sharding strategy
        """
        _watch_memory = self._debug
        if chunk_start:
            if self._debug:
                print("[maker] UPDATE ")
            if _watch_memory:
                tracemalloc.start()
            self._model_visit_lock.acquire()

        with torch.no_grad():
            if new_actor_state_dict is not None:
                if not self._update_lora_weights or fully_update:
                    self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
                else:
                    new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
                    state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
                        new_actor_state_dict, new_actor_lora_config_dict
                    )
                    self.actor_lora_constructor.load_state_dict_increase(
                        self.experience_maker.actor.model, state_dict_increase
                    )
            if new_critic_state_dict is not None:
                if not self._update_lora_weights or fully_update:
                    self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
                else:
                    new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
                    state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
                        new_critic_state_dict, new_critic_lora_config_dict
                    )
                    self.critic_lora_constructor.load_state_dict_increase(
                        self.experience_maker.critic, state_dict_increase
                    )

        # the lock must be released only after both the actor and the critic have been updated
        if chunk_end:
            self._model_visit_lock.release()
            if _watch_memory:
                current, peak = tracemalloc.get_traced_memory()
                print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
                tracemalloc.stop()
            if fully_update:
                self._is_fully_initialized = True
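
    # Sketch of the trainer-side call sequence this method expects when a state
    # dict is streamed in chunks (`iter_state_dict_chunks` is a hypothetical
    # helper, not part of this module):
    #
    #   chunks = list(iter_state_dict_chunks(actor_state_dict))
    #   for i, chunk in enumerate(chunks):
    #       maker.update_experience_maker.remote(
    #           new_actor_state_dict=chunk,
    #           chunk_start=(i == 0),
    #           chunk_end=(i == len(chunks) - 1),
    #       )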

    def _on_make_experience_start(self) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_loop_start(self) -> None:
        for callback in self.callbacks:
            callback.on_loop_start()

    def _on_loop_end(self) -> None:
        for callback in self.callbacks:
            callback.on_loop_end()

    def _on_send_start(self) -> None:
        for callback in self.callbacks:
            callback.on_send_start()

    def _on_send_end(self) -> None:
        for callback in self.callbacks:
            callback.on_send_end()

    def _on_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_start()

    def _on_batch_end(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_end()


def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> dict:
    origin_model = actor.model
    new_kwargs = {**generate_kwargs}
    # use the huggingface model's methods directly
    if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation

    if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation

    return new_kwargs
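
# Callers can override either hook explicitly via **generate_kwargs, e.g.
# (a sketch; `my_prepare_inputs_fn` is hypothetical):
#
#   maker = ExperienceMakerHolder.remote(..., prepare_inputs_fn=my_prepare_inputs_fn)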