lmops

Форк
0
/
base_llm.py 
30 строк · 970.0 Байт
1
import torch
2
import torch.nn as nn
3

4
from abc import abstractmethod
5
from typing import List, Optional, Union
6

7

8
class BaseLLM(nn.Module):
9

10
    def __init__(self, model_name_or_path: str, *args, **kwargs):
11
        super().__init__(*args, **kwargs)
12
        self.model_name_or_path = model_name_or_path
13

14
    @abstractmethod
15
    def batch_score(self, input_texts: List[str], output_texts: List[str], **kwargs) -> List[float]:
16
        raise NotImplementedError
17

18
    def score(self, input_text: str, output_text: str, **kwargs) -> float:
19
        return self.batch_score([input_text], [output_text], **kwargs)[0]
20

21
    @abstractmethod
22
    def batch_decode(self, input_texts: List[str], **kwargs) -> List[str]:
23
        raise NotImplementedError
24

25
    def decode(self, input_text: str, **kwargs) -> str:
26
        return self.batch_decode([input_text], **kwargs)[0]
27

28
    def cuda(self, device: Optional[Union[int, torch.device]] = 0):
29
        self.model.to(device)
30
        return self
31

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.