aurora

Форк
0
49 строк · 2.0 Кб
1
import json
2
import torch
3
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
4

5
from llmtuner.extras.packages import is_requests_available
6

7
if TYPE_CHECKING:
8
    from transformers import PreTrainedModel
9
    from trl import AutoModelForCausalLMWithValueHead
10

11
if is_requests_available():
12
    import requests
13

14

15
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
16
    headers = {"Content-Type": "application/json"}
17
    payload = {"model": "model", "messages": messages}
18
    response = requests.post(server_url, json=payload, headers=headers)
19
    rewards = json.loads(response.text)["scores"]
20
    return torch.Tensor(rewards)
21

22

23
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
24
    if target == "reward": # save default head temporarily
25
        valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
26
        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
27
        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
28

29
    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
30
    model.v_head.load_state_dict({
31
        "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
32
        "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
33
    })
34

35

36
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
37
    layer_norm_params = {}
38
    for name, param in model.named_parameters():
39
        if param.data.dtype == torch.float32:
40
            layer_norm_params[name] = param.data.detach().clone()
41
            param.data = param.data.to(model.config.torch_dtype)
42

43
    return layer_norm_params
44

45

46
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
47
    for name, param in model.named_parameters():
48
        if name in layernorm_params:
49
            param.data = layernorm_params[name]
50

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.