h2o-llmstudio
import codecs
import json
import logging
import os
from typing import Any

from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


TEXT_SEPARATOR = "<TEXT_SEPARATOR>"

def get_texts(df, cfg, separator=None):
    """Extracts the prompt texts from a dataframe as an array of strings."""
    if isinstance(cfg.dataset.prompt_column, str):
        # single column dataset
        texts = df[cfg.dataset.prompt_column].astype(str)
        texts = texts.values
    else:
        # multi-column dataset - prepend (if necessary) and join
        columns = list(cfg.dataset.prompt_column)

        for column in columns:
            df[column] = df[column].astype(str)

        if separator is None:
            separator = getattr(cfg, "_tokenizer_sep_token", TEXT_SEPARATOR)

        join_str = f" {separator} "
        texts = df[columns].astype(str)
        texts = texts.apply(lambda x: join_str.join(x), axis=1).values

    return texts

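# Illustrative usage sketch (not from the original file): shows how get_texts
# joins a multi-column prompt when no separator has been derived from the
# tokenizer yet. The DataFrame columns and the SimpleNamespace stand-in for
# the H2O LLM Studio cfg object are hypothetical.
def _example_get_texts():
    from types import SimpleNamespace

    import pandas as pd

    df = pd.DataFrame({"question": ["What is 2+2?"], "context": ["Basic math."]})
    cfg = SimpleNamespace(
        dataset=SimpleNamespace(prompt_column=("question", "context"))
    )
    # cfg has no _tokenizer_sep_token yet, so the TEXT_SEPARATOR fallback is
    # placed between the two columns:
    # array(['What is 2+2? <TEXT_SEPARATOR> Basic math.'], dtype=object)
    return get_texts(df, cfg)
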
def get_tokenizer(cfg: Any):
    """Loads the backbone tokenizer and patches missing special tokens."""
    kwargs = dict(
        revision=cfg.environment.huggingface_branch,
        trust_remote_code=cfg.environment.trust_remote_code,
        token=os.getenv("HUGGINGFACE_TOKEN"),
    )

    kwargs.update(json.loads(cfg.tokenizer.tokenizer_kwargs.strip()))

    # We will be able to remove this after
    # https://github.com/huggingface/transformers/pull/30964
    tokenizer_class = AutoTokenizer.from_pretrained(cfg.llm_backbone).__class__
    if tokenizer_class.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
        kwargs["from_slow"] = True

    try:
        tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
    except TypeError as e:
        error_message = str(e)
        if "token" in error_message:
            # TypeError: RWForCausalLM.__init__() got
            # an unexpected keyword argument 'token'
            kwargs.pop("token")
            tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
        elif "not a string" in error_message:
            # https://github.com/h2oai/h2o-llmstudio/issues/623
            kwargs.pop("add_prefix_space")
            tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
        else:
            raise e

    tokenizer.padding_side = getattr(
        cfg.tokenizer, "_padding_side", tokenizer.padding_side
    )

    # if the eos token is an empty string, we assign it to a token
    if tokenizer.eos_token == "":
        tokenizer.add_special_tokens({"eos_token": "</s>"})
        tokenizer.eos_token = "</s>"

    if tokenizer.pad_token is None:
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token
    if tokenizer.cls_token is None:
        tokenizer.cls_token = tokenizer.eos_token
    if tokenizer.sep_token is None:
        tokenizer.sep_token = tokenizer.eos_token

    cfg._tokenizer_sep_token = tokenizer.sep_token

    if tokenizer.unk_token_id is not None:
        cfg._tokenizer_mask_token_id = tokenizer.unk_token_id
    elif tokenizer.mask_token_id is not None:
        cfg._tokenizer_mask_token_id = tokenizer.mask_token_id
    elif tokenizer.pad_token_id is not None:
        cfg._tokenizer_mask_token_id = tokenizer.pad_token_id
    else:
        # setting the mask token id to the last token in the vocabulary
        # this usually is a safe choice and mostly refers to eos token
        cfg._tokenizer_mask_token_id = len(tokenizer) - 1

    cfg._tokenizer_eos_token = tokenizer.eos_token

    if hasattr(cfg.prediction, "stop_tokens"):
        set_stop_token_ids(cfg, tokenizer)
    cfg.tokenizer._vocab_length = len(tokenizer)

    return tokenizer

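# Illustrative usage sketch (not from the original file): a minimal cfg that
# exercises get_tokenizer. The attribute names mirror the fields accessed
# above; the backbone name and tokenizer kwargs are hypothetical examples.
def _example_get_tokenizer():
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        llm_backbone="h2oai/h2o-danube2-1.8b-base",  # hypothetical choice
        environment=SimpleNamespace(
            huggingface_branch="main",
            trust_remote_code=False,
            _local_rank=0,
        ),
        tokenizer=SimpleNamespace(tokenizer_kwargs='{"use_fast": true}'),
        prediction=SimpleNamespace(),  # no stop_tokens -> stop-word setup skipped
    )
    tokenizer = get_tokenizer(cfg)
    # get_tokenizer also leaves derived fields on cfg, e.g.
    # cfg._tokenizer_eos_token and cfg.tokenizer._vocab_length.
    return tokenizer
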
def set_stop_token_ids(cfg, tokenizer):
    """Collects stop words from the config and stores their token ids on cfg.tokenizer."""
    cfg.tokenizer._stop_words = list(
        filter(None, cfg.prediction.stop_tokens.split(","))
    )
    for stop_word in [
        cfg.dataset.text_system_start,
        cfg.dataset.text_prompt_start,
        cfg.dataset.text_answer_separator,
    ]:
        stop_word = codecs.decode(stop_word, "unicode_escape").strip()
        if (
            stop_word != ""
            and cfg.tokenizer.add_prompt_answer_tokens
            and (stop_word not in tokenizer.get_vocab())
        ):
            tokenizer.add_tokens([stop_word])
        cfg.tokenizer._stop_words.append(stop_word)
    cfg.tokenizer._stop_words = [
        stop_word for stop_word in cfg.tokenizer._stop_words if stop_word != ""
    ]
    cfg.tokenizer._stop_words_ids = []
    for stop_word in set(cfg.tokenizer._stop_words):
        cfg.tokenizer._stop_words_ids.append(
            tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)[
                "input_ids"
            ][0]
        )
    if cfg.environment._local_rank == 0:
        logger.info(f"Stop token ids: {cfg.tokenizer._stop_words_ids}")

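# Illustrative sketch (not from the original file): why set_stop_token_ids
# decodes stop words with "unicode_escape". Escape sequences entered as text
# in the config (e.g. a literal backslash followed by "n") become real
# control characters before being stripped and looked up in the vocabulary.
# The string below is a hypothetical prompt-start token.
def _example_unicode_escape():
    raw = "<|prompt|>\\n"  # stored as two characters: backslash and "n"
    decoded = codecs.decode(raw, "unicode_escape").strip()
    assert decoded == "<|prompt|>"  # decoded newline is stripped off the end
    return decoded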