lmops
38 строк · 1.4 Кб
1import os
2
3from datasets import Dataset
4
5from llms import BaseLLM, GPT2, GPTNeo, Llama
6from evaluation import BaseEval, RandomEval, DenseEval
7from config import Arguments
8from logger_config import logger
9
10
def build_llm(args: Arguments) -> BaseLLM:
    """Instantiate the LLM wrapper that matches ``args.llm_model_name_or_path``.

    The wrapper class is chosen by (case-sensitive) substring match on the model
    name/path, checked in this order:
    ``'gpt2'`` -> ``GPT2``, ``'gpt-neo'`` -> ``GPTNeo``, ``'llama'`` -> ``Llama``.

    For GPT-2 models, when ``args.llm_max_input_length`` is >= 1024 it is
    overwritten IN PLACE on ``args`` so that the input plus a reserved decode
    budget of ``max(args.llm_max_decode_length, 128)`` tokens fits in a
    1024-token window.

    Args:
        args: run configuration; reads ``llm_model_name_or_path``,
            ``llm_max_input_length`` and ``llm_max_decode_length`` (the latter
            two only for GPT-2 models).

    Returns:
        The constructed LLM wrapper.

    Raises:
        ValueError: if no known model family matches the name/path, or if the
            GPT-2 decode budget leaves no room for input tokens.
    """
    # Context window enforced for GPT-2 models (see the warning message below).
    gpt2_max_positions = 1024
    model_name_or_path: str = args.llm_model_name_or_path
    if 'gpt2' in model_name_or_path:
        # NOTE(review): lengths below 1024 are left untouched even if
        # input + decode would exceed the window; this mirrors the original check.
        if args.llm_max_input_length >= gpt2_max_positions:
            # Reserve at least 128 positions for generated tokens. Cap against the
            # window itself: subtracting from the configured value (as before)
            # left e.g. 2048 - 128 = 1920, still above the 1024 limit.
            reserved = max(args.llm_max_decode_length, 128)
            capped_length = gpt2_max_positions - reserved
            if capped_length <= 0:
                raise ValueError('llm_max_decode_length={} leaves no room for input tokens '
                                 'in the {}-token GPT2 context window'.format(
                                     args.llm_max_decode_length, gpt2_max_positions))
            args.llm_max_input_length = capped_length
            logger.warning('GPT2 models cannot handle sequences longer than 1024. '
                           'set to {}'.format(args.llm_max_input_length))
        llm = GPT2(args=args, model_name_or_path=model_name_or_path)
    elif 'gpt-neo' in model_name_or_path:
        llm = GPTNeo(args=args, model_name_or_path=model_name_or_path)
    elif 'llama' in model_name_or_path:
        llm = Llama(args=args, model_name_or_path=model_name_or_path)
    else:
        raise ValueError('Invalid model name or path: {}'.format(model_name_or_path))

    return llm
27
28
def build_eval_model(args: Arguments, corpus: Dataset) -> BaseEval:
    """Build the evaluation model over ``corpus``.

    ``RandomEval`` is used when ``args.model_name_or_path`` is exactly the
    string ``'random'``; any other value selects ``DenseEval``. Both are
    constructed with the same keyword arguments.

    Args:
        args: run configuration; only ``model_name_or_path`` is inspected here.
        corpus: dataset passed through to the evaluator's constructor.

    Returns:
        The constructed evaluator instance.
    """
    is_random = args.model_name_or_path == 'random'
    evaluator_cls = RandomEval if is_random else DenseEval
    return evaluator_cls(args=args, corpus=corpus)
35
36
def parse_model_id(model_name_or_path: str, *, max_length: int = 12) -> str:
    """Derive a short identifier from a model name or path.

    Leading and trailing ``'/'`` characters are stripped, the final path
    component is taken, and only its last ``max_length`` characters are kept.

    Args:
        model_name_or_path: a model name or path, e.g. ``'org/model'`` or
            ``'/ckpt/model/'``.
        max_length: maximum number of trailing characters to keep. Defaults to
            12, the previously hard-coded value; ``0`` yields ``''``.

    Returns:
        The (possibly truncated) final path component.

    Raises:
        ValueError: if ``max_length`` is negative.

    >>> parse_model_id('intfloat/e5-base-v2/')
    'e5-base-v2'
    """
    if max_length < 0:
        raise ValueError('max_length must be non-negative, got {}'.format(max_length))
    base_name = os.path.basename(model_name_or_path.strip('/'))
    # Slice from an explicit start index: ``base_name[-0:]`` would return the
    # whole string instead of an empty one.
    return base_name[max(len(base_name) - max_length, 0):]
39