colossalai

Форк
0
36 строк · 980.0 Байт
1
from typing import Optional
2

3
import torch.nn as nn
4
from transformers import BloomConfig, BloomModel
5

6
from ..base import Critic
7

8

9
class BLOOMCritic(Critic):
10
    """
11
    BLOOM Critic model.
12

13
    Args:
14
        pretrained (str): Pretrained model name or path.
15
        config (BloomConfig): Model config.
16
        lora_rank (int): LoRA rank.
17
        lora_train_bias (str): LoRA bias training mode.
18
    """
19

20
    def __init__(
21
        self,
22
        pretrained: str = None,
23
        config: Optional[BloomConfig] = None,
24
        lora_rank: int = 0,
25
        lora_train_bias: str = "none",
26
        **kwargs,
27
    ) -> None:
28
        if pretrained is not None:
29
            model = BloomModel.from_pretrained(pretrained)
30
        elif config is not None:
31
            model = BloomModel(config)
32
        else:
33
            model = BloomModel(BloomConfig())
34

35
        value_head = nn.Linear(model.config.hidden_size, 1)
36
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
37

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.