UltraChat

Форк
0
/
inference_cli.py 
68 строк · 2.1 Кб
1
import torch
2
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
3
from transformers.optimization import get_linear_schedule_with_warmup
4
from tqdm import tqdm
5
import os
6
import argparse
7
from util.inference import generate_stream, SimpleChatIO
8

9
def get_model_tokenizer(model_name_or_path):
10
    model = LlamaForCausalLM.from_pretrained(model_name_or_path, device_map="auto", torch_dtype=torch.float16)
11
    tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)
12
    return model, tokenizer
13

14

15
def chat_loop(
16
    model,
17
    tokenizer,
18
    system_prompt: str,
19
    temperature: float = 0.7,
20
    max_new_tokens: int = 2000,
21
    chatio = SimpleChatIO(),
22
    device = "cuda",
23
    debug: bool = False
24
):
25
    conv = [system_prompt]
26

27
    while True:
28
        try:
29
            inp = chatio.prompt_for_input("User")
30
        except EOFError:
31
            inp = ""
32
        if not inp:
33
            print("exit...")
34
            break
35

36
        conv.append("User: " + inp.strip() + tokenizer.eos_token)
37

38
        
39
        prompt = "\n".join(conv) + "\nAssistant: "
40

41
        gen_params = {
42
            "prompt": prompt,
43
            "temperature": temperature,
44
            "max_new_tokens": max_new_tokens,
45
            "echo": False,
46
        }
47

48
        chatio.prompt_for_output("Assistant")
49
        with torch.inference_mode():
50
            output_stream = generate_stream(model, tokenizer, gen_params, device)
51
        outputs = chatio.stream_output(output_stream)
52
        # NOTE: strip is important to align with the training data.
53
        conv.append("Assistant: " + outputs.strip() + tokenizer.eos_token)
54

55
        if debug:
56
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
57

58

59
if __name__ == "__main__":
60
    parser = argparse.ArgumentParser()
61
    parser.add_argument("--model_path", type=str, default="/path/to/ultralm")
62
    args = parser.parse_args()
63

64
    model, tokenizer = get_model_tokenizer(args.model_path)
65

66
    system_prompt = "User: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, very detailed, and polite answers to the user's questions.</s>"
67

68
    chat_loop(model, tokenizer, system_prompt)
69

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.