# colossalai
# 36 lines · 980.0 bytes
from typing import Optional

import torch.nn as nn
from transformers import BloomConfig, BloomModel

from ..base import Critic


class BLOOMCritic(Critic):
    """
    BLOOM Critic model.

    Builds a ``BloomModel`` backbone plus a scalar value head
    (``nn.Linear(hidden_size, 1)``) and hands both to ``Critic``.

    Backbone selection, in order of precedence:
      1. ``pretrained`` is given -> ``BloomModel.from_pretrained(pretrained)``
         (``config`` is ignored in this case);
      2. otherwise ``config`` is given -> ``BloomModel(config)``;
      3. otherwise -> ``BloomModel(BloomConfig())`` with default settings.

    Args:
        pretrained (Optional[str]): Pretrained model name or path.
        config (Optional[BloomConfig]): Model config, used only when
            ``pretrained`` is None.
        lora_rank (int): LoRA rank, forwarded to ``Critic``.
        lora_train_bias (str): LoRA bias training mode, forwarded to ``Critic``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``Critic.__init__``.
    """

    def __init__(
        self,
        pretrained: Optional[str] = None,
        config: Optional[BloomConfig] = None,
        lora_rank: int = 0,
        lora_train_bias: str = "none",
        **kwargs,
    ) -> None:
        # `pretrained` wins over `config`; fall back to a default config.
        if pretrained is not None:
            model = BloomModel.from_pretrained(pretrained)
        elif config is not None:
            model = BloomModel(config)
        else:
            model = BloomModel(BloomConfig())

        # One output unit: a single scalar value per hidden-state vector.
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)