from typing import Optional

import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel

from ..base import Critic


class OPTCritic(Critic):
    """
    OPT Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(
        self,
        pretrained: Optional[str] = None,
        config: Optional[OPTConfig] = None,
        lora_rank: int = 0,
        lora_train_bias: str = "none",
        **kwargs,
    ) -> None:
        if pretrained is not None:
            model = OPTModel.from_pretrained(pretrained)
        elif config is not None:
            model = OPTModel(config)
        else:
            model = OPTModel(OPTConfig())

        value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
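

# A minimal construction sketch: per the __init__ above, OPTCritic can be
# built from a pretrained checkpoint name, an explicit OPTConfig, or (by
# default) a fresh OPTConfig. The checkpoint name mentioned below is
# illustrative, not taken from this file.
if __name__ == "__main__":
    # Randomly initialized critic from an explicit config; lora_rank=0 is the
    # default here and leaves the backbone without LoRA adapters.
    critic = OPTCritic(config=OPTConfig(), lora_rank=0)

    # Note the value head is Linear(word_embed_proj_dim, 1) rather than
    # Linear(hidden_size, 1): OPT projects its final hidden states down to
    # word_embed_proj_dim when the two dimensions differ, so per-token values
    # must be computed from that width.
    print(critic)

    # Alternatively, load pretrained weights (downloads from the Hub), e.g.:
    # critic = OPTCritic(pretrained="facebook/opt-125m")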