lmops

Форк
0
/
model_utils.py 
38 строк · 1.4 Кб
1
import os
2

3
from datasets import Dataset
4

5
from llms import BaseLLM, GPT2, GPTNeo, Llama
6
from evaluation import BaseEval, RandomEval, DenseEval
7
from config import Arguments
8
from logger_config import logger
9

10

11
def build_llm(args: Arguments) -> BaseLLM:
12
    model_name_or_path: str = args.llm_model_name_or_path
13
    if 'gpt2' in model_name_or_path:
14
        if args.llm_max_input_length >= 1024:
15
            args.llm_max_input_length -= max(args.llm_max_decode_length, 128)
16
            logger.warning('GPT2 models cannot handle sequences longer than 1024. '
17
                           'set to {}'.format(args.llm_max_input_length))
18
        llm = GPT2(args=args, model_name_or_path=model_name_or_path)
19
    elif 'gpt-neo' in model_name_or_path:
20
        llm = GPTNeo(args=args, model_name_or_path=model_name_or_path)
21
    elif 'llama' in model_name_or_path:
22
        llm = Llama(args=args, model_name_or_path=model_name_or_path)
23
    else:
24
        raise ValueError('Invalid model name or path: {}'.format(model_name_or_path))
25

26
    return llm
27

28

29
def build_eval_model(args: Arguments, corpus: Dataset) -> BaseEval:
30
    model_name_or_path: str = args.model_name_or_path
31
    if model_name_or_path == 'random':
32
        return RandomEval(args=args, corpus=corpus)
33
    else:
34
        return DenseEval(args=args, corpus=corpus)
35

36

37
def parse_model_id(model_name_or_path: str) -> str:
38
    return os.path.basename(model_name_or_path.strip('/'))[-12:]
39

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.