moe-llava
/
predict.py
53 lines · 2.1 KB
1import torch
2from PIL import Image
3from moellava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
4from moellava.conversation import conv_templates, SeparatorStyle
5from moellava.model.builder import load_pretrained_model
6from moellava.utils import disable_torch_init
7from moellava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
8
def main(
    image='moellava/serve/examples/extreme_ironing.jpg',
    inp='What is unusual about this image?',
    model_path='LanguageBind/MoE-LLaVA-xxxxxxxxxxxxxxxx',  # choose a model
    device='cuda',
    conv_mode="phi",  # phi or qwen or stablelm
    load_4bit=False,
    load_8bit=False,
    temperature=0.2,
    max_new_tokens=1024,
):
    """Ask one question about one image with a MoE-LLaVA checkpoint and print the answer.

    Every argument defaults to the value this script previously hard-coded, so
    calling ``main()`` with no arguments keeps the original behaviour, while
    other callers can reuse the pipeline with their own inputs.

    Args:
        image: Path to the image file; it is opened with PIL and converted to RGB.
        inp: The question to ask about the image.
        model_path: Checkpoint id/path passed to ``load_pretrained_model``.
            The default is a placeholder and must be replaced with a real model.
        device: Device string forwarded to ``load_pretrained_model``.
        conv_mode: Key into ``conv_templates`` ("phi", "qwen" or "stablelm").
            Presumably it should match the checkpoint's base LLM; verify.
        load_4bit: Forwarded unchanged to ``load_pretrained_model``.
        load_8bit: Forwarded unchanged to ``load_pretrained_model``.
        temperature: Sampling temperature for ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Prints the question (prefixed with the user role), then the decoded answer.
    """
    disable_torch_init()
    model_name = get_model_name_from_path(model_path)
    # The fourth return value (context length) is not needed here.
    tokenizer, model, processor, _ = load_pretrained_model(
        model_path, None, model_name, load_8bit, load_4bit, device=device)
    image_processor = processor['image']
    conv = conv_templates[conv_mode].copy()
    roles = conv.roles

    # Context manager closes the image file once preprocessing is done
    # (the original left the handle open).
    with Image.open(image) as img:
        pixel_values = image_processor.preprocess(
            img.convert('RGB'), return_tensors='pt')['pixel_values']
    image_tensor = pixel_values.to(model.device, dtype=torch.float16)

    # The question is the user turn (appended under roles[0] below), so label
    # it with roles[0]; roles[1] is the slot reserved for the model's reply.
    print(f"{roles[0]}: {inp}")
    user_message = DEFAULT_IMAGE_TOKEN + '\n' + inp
    conv.append_message(roles[0], user_message)
    conv.append_message(roles[1], None)  # empty reply slot for the model to fill
    prompt = conv.get_prompt()

    # Place the prompt on the model's device (same as image_tensor) instead of
    # hard-coding .cuda(), which ignored the `device` setting.
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).to(model.device)

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    # Decode only the tokens after the prompt. This assumes generate() returns
    # prompt + new tokens, as the original slicing did; confirm for this model.
    outputs = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
    # The stopping criterion presumably halts after stop_str is emitted, so a
    # non-special stop string can remain at the end of the text; drop it.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)].strip()
    print(outputs)
48
# Entry point. Per the original usage note, launch with the DeepSpeed launcher:
#     deepspeed predict.py
if __name__ == '__main__':
    main()
54