colossalai

Форк
0
37 строк · 1.1 Кб
1
from typing import Optional
2

3
import torch.nn as nn
4
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
5
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
6

7
from ..base import Critic
8

9

10
class GPTCritic(Critic):
11
    """
12
    GPT Critic model.
13

14
    Args:
15
        pretrained (str): Pretrained model name or path.
16
        config (GPT2Config): Model config.
17
        lora_rank (int): Rank of the LO-RA decomposition.
18
        lora_train_bias (str): LoRA bias training mode.
19
    """
20

21
    def __init__(
22
        self,
23
        pretrained: Optional[str] = None,
24
        config: Optional[GPT2Config] = None,
25
        lora_rank: int = 0,
26
        lora_train_bias: str = "none",
27
        **kwargs,
28
    ) -> None:
29
        if pretrained is not None:
30
            model = GPT2Model.from_pretrained(pretrained)
31
        elif config is not None:
32
            model = GPT2Model(config)
33
        else:
34
            model = GPT2Model(GPT2Config())
35

36
        value_head = nn.Linear(model.config.n_embd, 1)
37
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
38

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.