3
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
5
from llmtuner.extras.packages import is_requests_available
8
from transformers import PreTrainedModel
9
from trl import AutoModelForCausalLMWithValueHead
11
if is_requests_available():
15
def get_rewards_from_server(
    server_url: str,
    messages: List[str],
    timeout: Optional[float] = None,
) -> torch.Tensor:
    r"""
    POST ``messages`` to a reward server and return the scores it sends back.

    Args:
        server_url: URL the JSON payload is POSTed to.
        messages: Items sent verbatim under the ``"messages"`` key of the JSON payload.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) waits
            indefinitely, which matches the behavior before this parameter existed.

    Returns:
        A 1-D float32 tensor built from the ``"scores"`` field of the JSON response
        (a single tensor, not a list of tensors).

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
    """
    headers = {"Content-Type": "application/json"}
    payload = {"model": "model", "messages": messages}
    response = requests.post(server_url, json=payload, headers=headers, timeout=timeout)
    # Fail with the HTTP error itself rather than a confusing JSON/KeyError further down.
    response.raise_for_status()
    rewards = response.json()["scores"]
    # Unlike the legacy ``torch.Tensor(...)`` constructor, ``torch.tensor`` never interprets
    # an integer as a size, so a scalar score cannot become an uninitialized tensor.
    return torch.tensor(rewards, dtype=torch.float32)
23
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
    r"""
    Activate the ``target`` adapter on ``model`` and load the matching value-head weights.

    When ``target == "reward"``, the value head's current ``summary.weight`` and
    ``summary.bias`` are first cloned into ``model.default_head_weight`` and
    ``model.default_head_bias``. This lets a later call with ``target="default"`` load
    them back. The adapter named ``target`` is then activated through
    ``model.pretrained_model.set_adapter``. Finally, the value head is overwritten with
    clones of the buffers ``"<target>_head_weight"`` and ``"<target>_head_bias"``.

    NOTE(review): ``get_buffer`` only returns registered buffers, and assigning a tensor
    to an ``nn.Module`` attribute only updates a buffer that is already registered under
    that name. This code therefore presumably relies on the
    ``{default,reward}_head_{weight,bias}`` buffers being registered on ``model``
    elsewhere. Confirm this at model-load time.

    Args:
        model: Causal LM wrapped with a value head (``v_head``) and a PEFT-style
            ``pretrained_model`` exposing ``set_adapter``.
        target: Which adapter / value-head pair to switch to.
    """
    if target == "reward":
        # Keep a copy of the head that is about to be overwritten, so it can be restored later.
        current_head: Dict[str, torch.Tensor] = model.v_head.state_dict()
        model.default_head_weight = current_head["summary.weight"].detach().clone()
        model.default_head_bias = current_head["summary.bias"].detach().clone()

    model.pretrained_model.set_adapter(target)

    incoming_head = {
        f"summary.{part}": model.get_buffer(f"{target}_head_{part}").detach().clone()
        for part in ("weight", "bias")
    }
    model.v_head.load_state_dict(incoming_head)
36
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
    r"""
    Save a copy of every float32 parameter of ``model``, then downcast it in place.

    Despite the name, no check is made on parameter names. *Every* parameter whose data
    is ``torch.float32`` is cloned into the result and then cast to
    ``model.config.torch_dtype``. Presumably these are the layer norms that were upcast
    to float32 (TODO confirm against the loader). Parameters of any other dtype are
    skipped. ``restore_layernorm`` can re-attach the returned tensors afterwards.

    Args:
        model: Model whose float32 parameters are saved and downcast.

    Returns:
        Mapping from ``named_parameters`` name to a detached float32 clone of that
        parameter's data as it was before the cast.
    """
    snapshot: Dict[str, torch.Tensor] = {}
    for param_name, p in model.named_parameters():
        if p.data.dtype != torch.float32:
            continue
        snapshot[param_name] = p.data.detach().clone()
        # The config is read only when a float32 parameter is found.
        p.data = p.data.to(model.config.torch_dtype)
    return snapshot
46
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
    r"""
    Re-attach saved tensors (e.g. the result of ``dump_layernorm``) to ``model``'s parameters.

    For each parameter of ``model`` whose name is a key of ``layernorm_params``, the
    parameter's ``.data`` is replaced by the stored tensor. The tensor object itself is
    attached, not a copy. Parameters without an entry are left untouched, and keys that
    match no parameter are ignored.

    Args:
        model: Model whose parameters are restored.
        layernorm_params: Mapping from parameter name to the tensor to re-attach.
            ``None`` (the default) or an empty mapping restores nothing.
    """
    if not layernorm_params:
        # The signature allows ``None``; without this guard ``name in None`` raises TypeError.
        return

    for name, param in model.named_parameters():
        if name in layernorm_params:
            param.data = layernorm_params[name]