colossalai
36 строк · 1.1 Кб
1from typing import Optional2
3import torch.nn as nn4from transformers import OPTConfig, OPTModel5
6from ..base import RewardModel7
8
class OPTRM(RewardModel):
    """
    Reward model built on top of an OPT backbone.

    The backbone is chosen in this order of precedence: ``pretrained`` if it
    is given, otherwise ``config`` if it is given, otherwise a freshly
    constructed default ``OPTConfig()``.

    Args:
        pretrained (str): Pretrained model name or path, handed to
            ``OPTModel.from_pretrained``.
        config (OPTConfig): Model config, used only when ``pretrained`` is None.
        lora_rank (int): Rank of the low-rank approximation; forwarded
            unchanged to the ``RewardModel`` constructor.
        lora_train_bias (str): LoRA bias training mode; forwarded unchanged
            to the ``RewardModel`` constructor.
    """

    def __init__(
        self,
        pretrained: Optional[str] = None,
        config: Optional[OPTConfig] = None,
        lora_rank: int = 0,
        lora_train_bias: str = "none",
    ) -> None:
        # Build the backbone; an explicit checkpoint wins over a config.
        if pretrained is None:
            backbone = OPTModel(OPTConfig() if config is None else config)
        else:
            backbone = OPTModel.from_pretrained(pretrained)

        # The value head projects one embedding-width vector to one scalar.
        embed_dim = backbone.config.word_embed_proj_dim
        head = nn.Linear(embed_dim, 1)
        # Only the weight is re-initialised; the bias keeps nn.Linear's default.
        head.weight.data.normal_(mean=0.0, std=1 / (embed_dim + 1))

        super().__init__(backbone, head, lora_rank, lora_train_bias)