h2o-llmstudio

Форк
0
140 строк · 4.7 Кб
1
import codecs
2
import json
3
import logging
4
import os
5
from typing import Any
6

7
from transformers import AutoTokenizer
8

9
logger = logging.getLogger(__name__)
10

11

12
TEXT_SEPARATOR = "<TEXT_SEPARATOR>"
13

14

15
def get_texts(df, cfg, separator=None):
    """Extract prompt texts from ``df`` according to ``cfg.dataset.prompt_column``.

    Args:
        df: DataFrame holding the prompt column(s).
        cfg: experiment config; ``cfg.dataset.prompt_column`` is either a single
            column name (str) or an iterable of column names to join.
        separator: string placed between columns when joining a multi-column
            prompt. Defaults to ``cfg._tokenizer_sep_token`` when set by
            ``get_tokenizer``, otherwise to ``TEXT_SEPARATOR``.

    Returns:
        numpy array of prompt strings, one per row of ``df``.
    """
    if isinstance(cfg.dataset.prompt_column, str):
        # single-column dataset
        return df[cfg.dataset.prompt_column].astype(str).values

    # multi-column dataset - join the columns with a separator token
    columns = list(cfg.dataset.prompt_column)

    if separator is None:
        separator = getattr(cfg, "_tokenizer_sep_token", TEXT_SEPARATOR)

    join_str = f" {separator} "
    # Convert only the selected sub-frame; the previous implementation also
    # mutated the caller's DataFrame columns in place, which is a surprising
    # side effect and was redundant with the astype(str) below.
    texts = (
        df[columns]
        .astype(str)
        .apply(lambda row: join_str.join(row), axis=1)
        .values
    )
    return texts
35

36

37
def get_tokenizer(cfg: Any):
    """Load and configure the tokenizer for ``cfg.llm_backbone``.

    Builds tokenizer kwargs from the config (HF branch, trust_remote_code,
    HUGGINGFACE_TOKEN env var, plus user JSON kwargs), loads the tokenizer
    with fallbacks for known ``TypeError`` incompatibilities, guarantees the
    eos/pad/bos/cls/sep special tokens exist, and caches tokenizer-derived
    values back on ``cfg`` (``_tokenizer_sep_token``,
    ``_tokenizer_mask_token_id``, ``_tokenizer_eos_token``, stop-word ids,
    ``_vocab_length``).

    Args:
        cfg: experiment config with ``llm_backbone``, ``environment``,
            ``tokenizer`` and (optionally) ``prediction`` sub-configs.

    Returns:
        The configured ``transformers`` tokenizer instance.
    """
    kwargs = dict(
        revision=cfg.environment.huggingface_branch,
        trust_remote_code=cfg.environment.trust_remote_code,
        token=os.getenv("HUGGINGFACE_TOKEN"),
    )

    # user-supplied JSON kwargs override the defaults above
    kwargs.update(json.loads(cfg.tokenizer.tokenizer_kwargs.strip()))

    # We will be able to remove this after
    # https://github.com/huggingface/transformers/pull/30964
    tokenizer_class = AutoTokenizer.from_pretrained(cfg.llm_backbone).__class__
    if tokenizer_class.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
        kwargs["from_slow"] = True

    try:
        tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
    except TypeError as e:
        error_message = str(e)
        if "token" in error_message:
            # TypeError: RWForCausalLM.__init__() got
            # an unexpected keyword argument 'token'
            # pop with default so a missing key cannot raise KeyError here
            # and mask the original TypeError
            kwargs.pop("token", None)
            tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
        elif "not a string" in error_message:
            # https://github.com/h2oai/h2o-llmstudio/issues/623
            kwargs.pop("add_prefix_space", None)
            tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
        else:
            # bare raise preserves the original traceback
            raise

    tokenizer.padding_side = getattr(
        cfg.tokenizer, "_padding_side", tokenizer.padding_side
    )

    # if the eos token is an empty string, we assign it to a token
    if tokenizer.eos_token == "":
        tokenizer.add_special_tokens({"eos_token": "</s>"})
        tokenizer.eos_token = "</s>"

    # fall back through the available special tokens so downstream code can
    # rely on pad/bos/cls/sep always being set
    if tokenizer.pad_token is None:
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token
    if tokenizer.cls_token is None:
        tokenizer.cls_token = tokenizer.eos_token
    if tokenizer.sep_token is None:
        tokenizer.sep_token = tokenizer.eos_token

    cfg._tokenizer_sep_token = tokenizer.sep_token

    if tokenizer.unk_token_id is not None:
        cfg._tokenizer_mask_token_id = tokenizer.unk_token_id
    elif tokenizer.mask_token_id is not None:
        cfg._tokenizer_mask_token_id = tokenizer.mask_token_id
    elif tokenizer.pad_token_id is not None:
        cfg._tokenizer_mask_token_id = tokenizer.pad_token_id
    else:
        # setting the mask token id to the last token in the vocabulary
        # this usually is a safe choice and mostly refers to eos token
        cfg._tokenizer_mask_token_id = len(tokenizer) - 1

    cfg._tokenizer_eos_token = tokenizer.eos_token

    if hasattr(cfg.prediction, "stop_tokens"):
        set_stop_token_ids(cfg, tokenizer)
    cfg.tokenizer._vocab_length = len(tokenizer)

    return tokenizer
110

111

112
def set_stop_token_ids(cfg, tokenizer):
    """Collect stop words from the config and cache their token ids on ``cfg``.

    Gathers stop words from ``cfg.prediction.stop_tokens`` (comma-separated)
    plus the system/prompt/answer markers from ``cfg.dataset``, optionally
    registers the markers as new tokens, then stores the de-duplicated token
    id tensors in ``cfg.tokenizer._stop_words_ids``.
    """
    stop_words = [w for w in cfg.prediction.stop_tokens.split(",") if w]
    cfg.tokenizer._stop_words = stop_words

    marker_texts = (
        cfg.dataset.text_system_start,
        cfg.dataset.text_prompt_start,
        cfg.dataset.text_answer_separator,
    )
    for raw in marker_texts:
        word = codecs.decode(raw, "unicode_escape").strip()
        needs_vocab_entry = (
            word != ""
            and cfg.tokenizer.add_prompt_answer_tokens
            and word not in tokenizer.get_vocab()
        )
        if needs_vocab_entry:
            tokenizer.add_tokens([word])
        stop_words.append(word)

    # drop any empty entries the markers may have contributed
    cfg.tokenizer._stop_words = [w for w in stop_words if w != ""]

    ids = []
    for word in set(cfg.tokenizer._stop_words):
        encoded = tokenizer(word, return_tensors="pt", add_special_tokens=False)
        ids.append(encoded["input_ids"][0])
    cfg.tokenizer._stop_words_ids = ids

    if cfg.environment._local_rank == 0:
        logger.info(f"Stop token ids: {cfg.tokenizer._stop_words_ids}")
141

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.