lmops
30 строк · 970.0 Байт
1import torch2import torch.nn as nn3
4from abc import abstractmethod5from typing import List, Optional, Union6
7
class BaseLLM(nn.Module):
    """Common interface for language-model wrappers that score and generate text.

    Subclasses implement the batched primitives :meth:`batch_score` and
    :meth:`batch_decode`. The single-example helpers :meth:`score` and
    :meth:`decode` wrap one example in a list, call the batched method and
    return the first result.

    NOTE: ``nn.Module`` does not use ``abc.ABCMeta`` as its metaclass, so the
    ``@abstractmethod`` markers below are not enforced. A subclass that forgets
    to override them can still be instantiated; calling the missing method
    raises ``NotImplementedError`` at runtime instead.
    """

    def __init__(self, model_name_or_path: str, *args, **kwargs):
        """Record the model identifier and initialise the underlying ``nn.Module``.

        Args:
            model_name_or_path: Identifier of the model to wrap, stored as-is on
                ``self.model_name_or_path``. Presumably a hub name or local
                path; this base class does not load anything.
            *args: Forwarded unchanged to ``nn.Module.__init__``.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.model_name_or_path = model_name_or_path

    @abstractmethod
    def batch_score(self, input_texts: List[str], output_texts: List[str], **kwargs) -> List[float]:
        """Return one score per ``(input_texts[i], output_texts[i])`` pair.

        Must be overridden by subclasses; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def score(self, input_text: str, output_text: str, **kwargs) -> float:
        """Score a single input/output pair by delegating to :meth:`batch_score`.

        Extra keyword arguments are forwarded unchanged.
        """
        return self.batch_score([input_text], [output_text], **kwargs)[0]

    @abstractmethod
    def batch_decode(self, input_texts: List[str], **kwargs) -> List[str]:
        """Return one generated string per element of ``input_texts``.

        Must be overridden by subclasses; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def decode(self, input_text: str, **kwargs) -> str:
        """Generate output for a single input by delegating to :meth:`batch_decode`.

        Extra keyword arguments are forwarded unchanged.
        """
        return self.batch_decode([input_text], **kwargs)[0]

    def cuda(self, device: Optional[Union[int, torch.device]] = 0):
        """Move the wrapped ``self.model`` to ``device`` and return ``self``.

        Only ``self.model`` is moved. It must be assigned by the subclass, since
        this base class never creates it. Parameters or buffers registered
        directly on this wrapper are not touched.

        Args:
            device: Target device, passed to ``self.model.to``. An ``int`` is
                treated by torch as a device index (default ``0``). ``None``
                means the current CUDA device, matching ``nn.Module.cuda``.

        Returns:
            ``self``, so calls can be chained.
        """
        if device is None:
            # ``Module.to(None)`` silently leaves everything where it is, while
            # ``nn.Module.cuda(None)`` targets the current CUDA device.
            # ``torch.device("cuda")`` (no index) has exactly that meaning.
            device = torch.device("cuda")
        self.model.to(device)
        return self