# template.py

import tiktoken
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer


logger = get_logger(__name__)


@dataclass
class Template:
    r"""
    A chat prompt template. ``prefix``, ``prompt`` and ``sep`` are token
    sequences given either as literal strings (supporting the ``{{system}}``,
    ``{{query}}`` and ``{{idx}}`` placeholders) or as ``{"token": ...}`` dicts
    naming a single special token. ``stop_words`` are registered as additional
    special tokens, ``use_history`` enables multi-turn histories, and
    ``efficient_eos`` skips the trailing eos token for models whose templates
    end each turn with a stop token of their own.
    """

    prefix: List[Union[str, Dict[str, str]]]
    prompt: List[Union[str, Dict[str, str]]]
    system: str
    sep: List[Union[str, Dict[str, str]]]
    stop_words: List[str]
    use_history: bool
    efficient_eos: bool

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
        system, history = self._format(query, resp, history, system)
        encoded_pairs = self._encode(tokenizer, system, history)
        prompt_ids = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids = prompt_ids + query_ids + resp_ids
        prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
        return prompt_ids, answer_ids

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
        system, history = self._format(query, resp, history, system)
        encoded_pairs = self._encode(tokenizer, system, history)
        return encoded_pairs

    def _format(
        self,
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None
    ) -> Tuple[str, List[Tuple[str, str]]]:
        r"""
        Aligns inputs to the standard format.
        """
        system = system or self.system # use the provided system prompt or fall back to the template default
        history = history if (history and self.use_history) else []
        history = history + [(query, resp)]
        return system, history

    def _get_special_ids(
        self,
        tokenizer: "PreTrainedTokenizer"
    ) -> Tuple[List[int], List[int]]:
        if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
            bos_ids = [tokenizer.bos_token_id]
        else: # baichuan, qwen and gpt2 models have no bos token
            bos_ids = []

        if tokenizer.eos_token_id is None:
            raise ValueError("EOS token is required.")

        if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
            eos_ids = []
        else:
            eos_ids = [tokenizer.eos_token_id]

        return bos_ids, eos_ids

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        system: str,
        history: List[Tuple[str, str]]
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: bos + prefix + sep + query    resp + eos
        Turn t: sep + bos + query             resp + eos
        """
        bos_ids, eos_ids = self._get_special_ids(tokenizer)
        sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
        encoded_pairs = []
        for turn_idx, (query, resp) in enumerate(history):
            if turn_idx == 0:
                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
                if len(prefix_ids) != 0: # has prefix
                    prefix_ids = bos_ids + prefix_ids + sep_ids
                else:
                    prefix_ids = bos_ids
            else:
                prefix_ids = sep_ids + bos_ids

            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx+1))
            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
            encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
        return encoded_pairs

    def _convert_inputs_to_ids(
        self,
        tokenizer: "PreTrainedTokenizer",
        context: List[Union[str, Dict[str, str]]],
        system: Optional[str] = None,
        query: Optional[str] = None,
        idx: Optional[str] = None
    ) -> List[int]:
        r"""
        Converts context to token ids.
        """
        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
            kwargs = dict(allowed_special="all")
        else:
            kwargs = dict(add_special_tokens=False)

        token_ids = []
        for elem in context:
            if isinstance(elem, str):
                elem = elem.replace("{{system}}", system, 1) if system is not None else elem
                elem = elem.replace("{{query}}", query, 1) if query is not None else elem
                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
                if len(elem) != 0:
                    token_ids = token_ids + tokenizer.encode(elem, **kwargs)
            elif isinstance(elem, dict):
                token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
            else:
                raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))

        return token_ids
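
# A worked sketch (added commentary, not part of the original file): for a
# hypothetical template with prefix=["{{system}}"], prompt=["Q: {{query}}\nA:"]
# and sep=["\n"], _encode turns the history [("q1", "r1"), ("q2", "r2")] into
# (with enc() denoting tokenizer.encode without special tokens):
#
#   turn 0: bos + enc(system) + enc("\n") + enc("Q: q1\nA:")   |   enc("r1") + eos
#   turn 1: enc("\n") + bos + enc("Q: q2\nA:")                 |   enc("r2") + eos
#
# encode_oneturn flattens every pair except the last into prompt_ids and keeps
# the last response as answer_ids; encode_multiturn returns the pairs as-is.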


@dataclass
class Llama2Template(Template):

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        system: str,
        history: List[Tuple[str, str]]
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: bos + prefix + query    resp + eos
        Turn t: bos + query             resp + eos
        """
        bos_ids, eos_ids = self._get_special_ids(tokenizer)
        encoded_pairs = []
        for turn_idx, (query, resp) in enumerate(history):
            if turn_idx == 0: # llama2 template has no sep_ids
                query = self.prefix[0].replace("{{system}}", system) + query
            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
            encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
        return encoded_pairs


templates: Dict[str, Template] = {}


def register_template(
    name: str,
    prefix: List[Union[str, Dict[str, str]]],
    prompt: List[Union[str, Dict[str, str]]],
    system: str,
    sep: List[Union[str, Dict[str, str]]],
    stop_words: Optional[List[str]] = [],
    use_history: Optional[bool] = True,
    efficient_eos: Optional[bool] = False
) -> None:
    template_class = Llama2Template if "llama2" in name else Template
    templates[name] = template_class(
        prefix=prefix,
        prompt=prompt,
        system=system,
        sep=sep,
        stop_words=stop_words,
        use_history=use_history,
        efficient_eos=efficient_eos
    )
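
# A hypothetical custom registration (illustrative name and strings, not part
# of the original registry) would follow the same pattern as the built-ins below:
#
# register_template(
#     name="my_custom",
#     prefix=["{{system}}"],
#     prompt=["### User:\n{{query}}\n### Assistant:\n"],
#     system="You are a helpful assistant.",
#     sep=["\n"]
# )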


def get_template_and_fix_tokenizer(
    name: Optional[str],
    tokenizer: "PreTrainedTokenizer"
) -> Optional[Template]:
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token = "<|endoftext|>"
        logger.info("Add eos token: {}".format(tokenizer.eos_token))

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    if name is None:
        return None

    template = templates.get(name, None)
    assert template is not None, "Template {} does not exist.".format(name)
    tokenizer.add_special_tokens(
        dict(additional_special_tokens=template.stop_words),
        replace_additional_special_tokens=False
    )
    return template
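
# Typical call site (a sketch; the checkpoint name is illustrative and assumes
# the transformers library is available):
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
#     template = get_template_and_fix_tokenizer("vicuna", tokenizer)
#     prompt_ids, answer_ids = template.encode_oneturn(
#         tokenizer, query="What is a dataclass?", resp="A decorated class."
#     )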


register_template(
    name="alpaca",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "### Instruction:\n{{query}}\n\n### Response:\n"
    ],
    system=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    ),
    sep=[
        "\n\n"
    ]
)


register_template(
    name="aquila",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "Human: {{query}}###Assistant:"
    ],
    system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions."
    ),
    sep=[
        "###"
    ],
    stop_words=[
        "</s>"
    ],
    efficient_eos=True
)


register_template(
    name="baichuan",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        {"token": "<reserved_102>"}, # user token
        "{{query}}",
        {"token": "<reserved_103>"}  # assistant token
    ],
    system="",
    sep=[],
    efficient_eos=True
)


register_template(
    name="baichuan2",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        {"token": "<reserved_106>"}, # user token
        "{{query}}",
        {"token": "<reserved_107>"}  # assistant token
    ],
    system="",
    sep=[],
    efficient_eos=True
)


register_template(
    name="belle",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "Human: {{query}}\n\nBelle: "
    ],
    system="",
    sep=[
        "\n\n"
    ]
)


register_template(
    name="bluelm",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        {"token": "[|Human|]:"},
        "{{query}}",
        {"token": "[|AI|]:"}
    ],
    system="",
    sep=[]
)


register_template(
    name="chatglm2",
    prefix=[
        {"token": "[gMASK]"},
        {"token": "sop"},
        "{{system}}"
    ],
    prompt=[
        "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
    ],
    system="",
    sep=[
        "\n\n"
    ],
    efficient_eos=True
)


register_template(
    name="chatglm3",
    prefix=[
        {"token": "[gMASK]"},
        {"token": "sop"},
        {"token": "<|system|>"},
        "\n",
        "{{system}}"
    ],
    prompt=[
        {"token": "<|user|>"},
        "\n",
        "{{query}}",
        {"token": "<|assistant|>"},
        "\n" # add an extra newline to avoid error in ChatGLM's process_response method
    ],
    system=(
        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
        "Follow the user's instructions carefully. Respond using markdown."
    ),
    sep=[],
    stop_words=[
        "<|user|>",
        "<|observation|>"
    ],
    efficient_eos=True
)


register_template(
    name="chatglm3_raw", # the raw template for tool tuning
    prefix=[
        {"token": "[gMASK]"},
        {"token": "sop"},
        {"token": "<|system|>"},
        "\n",
        "{{system}}"
    ],
    prompt=[
        {"token": "<|user|>"},
        "\n",
        "{{query}}",
        {"token": "<|assistant|>"}
    ],
    system=(
        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
        "Follow the user's instructions carefully. Respond using markdown."
    ),
    sep=[],
    stop_words=[
        "<|user|>",
        "<|observation|>"
    ],
    efficient_eos=True
)


register_template(
    name="deepseek",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "User: {{query}}\n\nAssistant:"
    ],
    system="",
    sep=[]
)


register_template(
    name="deepseekcoder",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "### Instruction:\n{{query}}\n### Response:\n"
    ],
    system=(
        "You are an AI programming assistant, utilizing the Deepseek Coder model, "
        "developed by Deepseek Company, and you only answer questions related to computer science. "
        "For politically sensitive questions, security and privacy issues, "
        "and other non-computer science questions, you will refuse to answer\n"
    ),
    sep=[
        "\n",
        {"token": "<|EOT|>"},
        "\n"
    ],
    stop_words=[
        "<|EOT|>"
    ],
    efficient_eos=True
)


register_template(
    name="default",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "Human: {{query}}\nAssistant:"
    ],
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    sep=[
        "\n"
    ]
)


register_template(
    name="falcon",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "User: {{query}}\nFalcon:"
    ],
    system="",
    sep=[
        "\n"
    ],
    efficient_eos=True
)


register_template(
    name="intern",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "<|User|>:{{query}}",
        {"token": "<eoh>"},
        "\n<|Bot|>:"
    ],
    system="",
    sep=[
        {"token": "<eoa>"},
        "\n"
    ],
    stop_words=[
        "<eoa>"
    ],
    efficient_eos=True
)


register_template(
    name="llama2",
    prefix=[
        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
    ],
    prompt=[
        "[INST] {{query}} [/INST]"
    ],
    system=(
        "You are a helpful, respectful and honest assistant. "
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    ),
    sep=[]
)
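
# With this template, Llama2Template folds the prefix into the first user query,
# so turn 0 renders as (bos token omitted):
# "[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{query} [/INST]"
# which matches the official Llama-2 chat layout.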


register_template(
    name="llama2_zh",
    prefix=[
        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
    ],
    prompt=[
        "[INST] {{query}} [/INST]"
    ],
    system="You are a helpful assistant. 你是一个乐于助人的助手。",
    sep=[]
)


register_template(
    name="mistral",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "[INST] {{query}} [/INST]"
    ],
    system="",
    sep=[]
)


register_template(
    name="openchat",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "GPT4 Correct User: {{query}}",
        {"token": "<|end_of_turn|>"},
        "GPT4 Correct Assistant:"
    ],
    system="",
    sep=[
        {"token": "<|end_of_turn|>"}
    ],
    stop_words=[
        "<|end_of_turn|>"
    ],
    efficient_eos=True
)


register_template(
    name="qwen",
    prefix=[
        {"token": "<|im_start|>"},
        "system\n{{system}}"
    ],
    prompt=[
        {"token": "<|im_start|>"},
        "user\n{{query}}",
        {"token": "<|im_end|>"},
        "\n",
        {"token": "<|im_start|>"},
        "assistant\n"
    ],
    system="You are a helpful assistant.",
    sep=[
        {"token": "<|im_end|>"},
        "\n"
    ],
    stop_words=[
        "<|im_end|>"
    ],
    efficient_eos=True
)


register_template(
    name="starchat",
    prefix=[
        {"token": "<|system|>"},
        "\n{{system}}",
    ],
    prompt=[
        {"token": "<|user|>"},
        "\n{{query}}",
        {"token": "<|end|>"},
        "\n",
        {"token": "<|assistant|>"}
    ],
    system="",
    sep=[
        {"token": "<|end|>"},
        "\n"
    ],
    stop_words=[
        "<|end|>"
    ],
    efficient_eos=True
)


r"""
Supports language model inference without histories.
"""
register_template(
    name="vanilla",
    prefix=[],
    prompt=[
        "{{query}}"
    ],
    system="",
    sep=[],
    use_history=False
)


register_template(
    name="vicuna",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "USER: {{query}} ASSISTANT:"
    ],
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    sep=[]
)


register_template(
    name="xuanyuan",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "Human: {{query}} Assistant:"
    ],
    system=(
        "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
        "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
        "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
    ),
    sep=[]
)


register_template(
    name="xverse",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "Human: {{query}}\n\nAssistant: "
    ],
    system="",
    sep=[]
)


register_template(
    name="yayi",
    prefix=[
        {"token": "<|System|>"},
        ":\n{{system}}"
    ],
    prompt=[
        {"token": "<|Human|>"},
        ":\n{{query}}\n\n",
        {"token": "<|YaYi|>"},
        ":"
    ],
    system=(
        "You are a helpful, respectful and honest assistant named YaYi "
        "developed by Beijing Wenge Technology Co.,Ltd. "
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    ),
    sep=[
        "\n\n"
    ],
    stop_words=[
        "<|End|>"
    ]
)


register_template(
    name="yi",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
    ],
    system="",
    sep=[
        "<|im_end|>\n"
    ],
    stop_words=[
        "<|im_end|>"
    ],
    efficient_eos=True
)


register_template(
    name="zephyr",
    prefix=[
        {"token": "<|system|>"},
        "\n{{system}}",
        {"token": "</s>"}
    ],
    prompt=[
        {"token": "<|user|>"},
        "\n{{query}}",
        {"token": "</s>"},
        {"token": "<|assistant|>"}
    ],
    system="You are a friendly chatbot who always responds in the style of a pirate",
    sep=[]
)


register_template(
    name="ziya",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        {"token": "<human>"},
        ":{{query}}\n",
        {"token": "<bot>"},
        ":"
    ],
    system="",
    sep=[
        "\n"
    ]
)
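

if __name__ == "__main__":
    # Quick smoke test with a stub tokenizer (added for illustration, not part
    # of the original file). The stub maps each whitespace-separated token to an
    # arbitrary fake id so the turn layout can be inspected without downloading
    # a real model.
    class _StubTokenizer:
        bos_token_id = 1
        eos_token_id = 2
        pad_token_id = 0
        add_bos_token = True

        def encode(self, text, add_special_tokens=False):
            return [10 + hash(tok) % 1000 for tok in text.split()]

        def convert_tokens_to_ids(self, token):
            return 10 + hash(token) % 1000

    template = templates["vicuna"]
    prompt_ids, answer_ids = template.encode_oneturn(
        _StubTokenizer(), query="Hello!", resp="Hi there!"
    )
    print("prompt ids:", prompt_ids)
    print("answer ids:", answer_ids)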