import argparse

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from util.inference import generate_stream, SimpleChatIO


def get_model_tokenizer(model_name_or_path):
    # Load the weights in fp16 and let accelerate place them across available devices.
    model = LlamaForCausalLM.from_pretrained(
        model_name_or_path, device_map="auto", torch_dtype=torch.float16
    )
    tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)
    return model, tokenizer


def chat_loop(
    model,
    tokenizer,
    system_prompt: str,
    temperature: float = 0.7,
    max_new_tokens: int = 2000,
    chatio=SimpleChatIO(),
    device: str = "cuda",  # forwarded to generate_stream; assumed default
):
    # The running conversation is a list of turns, seeded with the system prompt.
    conv = [system_prompt]
    while True:
        inp = chatio.prompt_for_input("User")
        if not inp:
            break
        # Record the user turn, terminated with the EOS token as in training.
        conv.append("User: " + inp.strip() + tokenizer.eos_token)
        prompt = "\n".join(conv) + "\nAssistant: "
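        # Schematically, after the first user turn the serialized prompt reads:
        #
        #   User: A chat between a curious user and ...</s>
        #   User: <first user message></s>
        #   Assistant: 
        #
        # (the system prompt below carries its own "User: " prefix and trailing
        # </s>, and tokenizer.eos_token for LLaMA is "</s>")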
        # Assumed parameter dict; the exact keys are defined by util.inference.generate_stream.
        gen_params = {
            "prompt": prompt,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
        }
        chatio.prompt_for_output("Assistant")
        # Generate without autograd overhead and stream tokens to the console.
        with torch.inference_mode():
            output_stream = generate_stream(model, tokenizer, gen_params, device)
            outputs = chatio.stream_output(output_stream)
        # NOTE: strip is important to align with the training data.
        conv.append("Assistant: " + outputs.strip() + tokenizer.eos_token)
        print("\n", {"prompt": prompt, "outputs": outputs}, "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="/path/to/ultralm")
    args = parser.parse_args()

    model, tokenizer = get_model_tokenizer(args.model_path)
    # The system prompt mirrors the training format: a "User:" turn closed by </s>.
    system_prompt = "User: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, very detailed, and polite answers to the user's questions.</s>"
    chat_loop(model, tokenizer, system_prompt)
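
# Usage sketch (hypothetical file name; point --model_path at a local UltraLM checkpoint):
#   python chat.py --model_path /path/to/ultralm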