chat_model.py 
import torch
import tiktoken
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer

from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer


@dataclass
class Response:
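    r"""A single decoded reply together with its token counts and stop reason."""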
    response_text: str
    response_length: int
    prompt_length: int
    finish_reason: Literal["stop", "length"]


class ChatModel:
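    r"""Loads a fine-tuned model once and serves chat, streaming and reward-scoring requests.

    When the checkpoint was trained with ``stage="sft"`` the model is used for text
    generation (``chat``/``stream_chat``); otherwise it is loaded with a value head
    and only ``get_scores`` is meaningful.
    """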

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
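        r"""Parses inference arguments, loads the model and tokenizer, and prepares the chat template.

        A value head is attached when the checkpoint is not an SFT model, and the
        tokenizer pads on the left for generation or on the right for scoring.
        """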
        model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
        self.can_generate = (finetuning_args.stage == "sft")
        self.model, self.tokenizer = load_model_and_tokenizer(
            model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.model = dispatch_model(self.model)
        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)

    def _process_args(
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None,
        **input_kwargs
    ) -> Tuple[Dict[str, Any], int]:
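        r"""Encodes the prompt with the chat template and merges user overrides into the
        default generating arguments; returns ``model.generate`` kwargs and the prompt length.
        """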
        prompt, _ = self.template.encode_oneturn(
            tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
        )
        prompt_length = len(prompt)
        input_ids = torch.tensor([prompt], device=self.model.device)

        do_sample = input_kwargs.pop("do_sample", None)
        temperature = input_kwargs.pop("temperature", None)
        top_p = input_kwargs.pop("top_p", None)
        top_k = input_kwargs.pop("top_k", None)
        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
        max_length = input_kwargs.pop("max_length", None)
        max_new_tokens = input_kwargs.pop("max_new_tokens", None)

        generating_args = self.generating_args.to_dict()
        generating_args.update(dict(
            do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
            temperature=temperature or generating_args["temperature"],
            top_p=top_p or generating_args["top_p"],
            top_k=top_k or generating_args["top_k"],
            num_return_sequences=num_return_sequences or 1,
            repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            pad_token_id=self.tokenizer.pad_token_id
        ))

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:
            generating_args["do_sample"] = True

        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=input_ids,
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor()
        )

        return gen_kwargs, prompt_length

    @torch.inference_mode()
    def chat(
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None,
        **input_kwargs
    ) -> List[Response]:
        r"""
        Args: query, history, system, **input_kwargs

        Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
        """
        gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
        generate_output = self.model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]
        response = self.tokenizer.batch_decode(
            response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        results = []
        for i in range(len(response)):
            # Count tokens up to and including the first EOS; if none was generated, the reply was truncated.
            eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(Response(
                response_text=response[i],
                response_length=response_length,
                prompt_length=prompt_length,
                finish_reason="stop" if len(eos_index) else "length"
            ))

        return results

    @torch.inference_mode()
    def stream_chat(
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None,
        **input_kwargs
    ) -> Generator[str, None, None]:
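        r"""Runs generation in a background thread and yields decoded text pieces as they arrive."""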
        gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer

        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
        thread.start()

        yield from streamer

    @torch.inference_mode()
    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs
    ) -> List[float]:
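        r"""Scores each input string with the value head of a reward model and returns one float per entry."""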
        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
            kwargs = dict(allowed_special="all")
        else:
            kwargs = dict(add_special_tokens=True)

        max_length = input_kwargs.pop("max_length", None)
        device = getattr(self.model.pretrained_model, "device", "cuda")

        inputs = self.tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
            pad_to_multiple_of=8,
            return_tensors="pt",
            **kwargs
        ).to(device)

        input_ids: torch.Tensor = inputs["input_ids"]
        # The value head returns one scalar per input token.
        _, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)

        if getattr(self.model.config, "model_type", None) == "chatglm":
            # ChatGLM returns values with the batch and sequence dimensions swapped.
            values = torch.transpose(values, 0, 1)

        scores = []
        for i in range(input_ids.size(0)):
            # Take the value at the last non-pad token (inputs are right-padded for scoring).
            end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
            end_index = end_indexes[-1].item() if len(end_indexes) else 0
            scores.append(values[i, end_index].nan_to_num().item())

        return scores
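

# Minimal usage sketch (not part of the original file). It assumes an SFT checkpoint so
# that `can_generate` is True; the model path and template name below are placeholders,
# and the accepted keys follow llmtuner's `get_infer_args()`.
if __name__ == "__main__":
    chat_model = ChatModel(dict(
        model_name_or_path="path/to/your-sft-model",  # hypothetical checkpoint path
        template="default",                           # chat template registered in llmtuner
    ))
    # Blocking call: returns a list of Response objects (one per return sequence).
    print(chat_model.chat("Hello, who are you?")[0].response_text)
    # Streaming call: prints the reply piece by piece as it is generated.
    for piece in chat_model.stream_chat("Tell me a short joke."):
        print(piece, end="", flush=True)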
