colossalai
37 строк · 1.1 Кб
from typing import Optional

import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

from ..base import Critic
9
class GPTCritic(Critic):
    """GPT-2 based critic that scores sequences via a scalar value head.

    Args:
        pretrained (str, optional): Pretrained model name or path; takes
            precedence over ``config`` when both are given.
        config (GPT2Config, optional): Model config used when ``pretrained``
            is not supplied.
        lora_rank (int): Rank of the LoRA decomposition (0 disables LoRA).
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(
        self,
        pretrained: Optional[str] = None,
        config: Optional[GPT2Config] = None,
        lora_rank: int = 0,
        lora_train_bias: str = "none",
        **kwargs,
    ) -> None:
        # Resolve the backbone: pretrained weights win over an explicit
        # config; otherwise fall back to the default GPT-2 configuration.
        if pretrained is not None:
            backbone = GPT2Model.from_pretrained(pretrained)
        else:
            backbone = GPT2Model(config if config is not None else GPT2Config())
        # Scalar head mapping the final hidden state to a single value score.
        value_head = nn.Linear(backbone.config.n_embd, 1)
        super().__init__(backbone, value_head, lora_rank, lora_train_bias, **kwargs)