llama-index

Форк
0
99 строк · 3.1 Кб
1
import logging
from typing import Optional

import requests
4

5
# Checkpoints used when the caller does not name a model explicitly.
DEFAULT_HUGGINGFACE_EMBEDDING_MODEL: str = "BAAI/bge-small-en"
DEFAULT_INSTRUCT_MODEL: str = "hkunlp/instructor-base"

# Instruction prefixes prepended to texts/queries before embedding.
# The INSTRUCTOR defaults were originally pulled from:
# https://github.com/langchain-ai/langchain/blob/v0.0.257/libs/langchain/langchain/embeddings/huggingface.py#L10
# NOTE: the English prefixes end with a trailing space; the Chinese one does not.
DEFAULT_EMBED_INSTRUCTION: str = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION: str = (
    "Represent the question for retrieving supporting documents: "
)

# Query-side prefixes for the BGE family, per language.
DEFAULT_QUERY_BGE_INSTRUCTION_EN: str = (
    "Represent this question for searching relevant passages: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_ZH: str = "为这个句子生成表示以用于检索相关文章:"
18

19
# Hugging Face checkpoints that take the BGE query instruction.
# Expands, in this order, to:
#   BAAI/bge-small-en, BAAI/bge-small-en-v1.5, BAAI/bge-base-en,
#   BAAI/bge-base-en-v1.5, BAAI/bge-large-en, BAAI/bge-large-en-v1.5,
#   followed by the same six names with "zh" in place of "en".
BGE_MODELS = tuple(
    f"BAAI/bge-{size}-{lang}{release}"
    for lang in ("en", "zh")
    for size in ("small", "base", "large")
    for release in ("", "-v1.5")
)

# INSTRUCTOR checkpoints, under both the "hku-nlp" and "hkunlp" organisation
# spellings. Expands, in this order, to:
#   hku-nlp/instructor-base, hku-nlp/instructor-large, hku-nlp/instructor-xl,
#   hkunlp/instructor-base, hkunlp/instructor-large, hkunlp/instructor-xl.
INSTRUCTOR_MODELS = tuple(
    f"{org}/instructor-{size}"
    for org in ("hku-nlp", "hkunlp")
    for size in ("base", "large", "xl")
)
41

42

43
def get_query_instruct_for_model_name(model_name: Optional[str]) -> str:
    """Return the query-side instruction prefix expected by ``model_name``.

    INSTRUCTOR checkpoints get the generic retrieval-question prompt. BGE
    checkpoints get the BGE prompt, in Chinese when the name contains
    ``"zh"`` and in English otherwise. Any other model, including ``None``,
    gets an empty string.
    """
    if model_name in INSTRUCTOR_MODELS:
        return DEFAULT_QUERY_INSTRUCTION
    if model_name not in BGE_MODELS:
        return ""
    # Only reached for names listed in BGE_MODELS, so model_name is a str here.
    is_chinese = "zh" in model_name
    return (
        DEFAULT_QUERY_BGE_INSTRUCTION_ZH
        if is_chinese
        else DEFAULT_QUERY_BGE_INSTRUCTION_EN
    )
52

53

54
def format_query(
    query: str, model_name: Optional[str], instruction: Optional[str] = None
) -> str:
    """Prefix ``query`` with the instruction its embedding model expects.

    Args:
        query: Raw query text.
        model_name: Model id used to look up a default query instruction when
            ``instruction`` is ``None``.
        instruction: Explicit prefix that overrides the per-model default.

    Returns:
        ``"<instruction> <query>"`` with leading/trailing whitespace stripped.
    """
    prefix = (
        get_query_instruct_for_model_name(model_name)
        if instruction is None
        else instruction
    )
    # Stripping the joined string means instruction="" yields just the query
    # (minus its own outer whitespace). This is a deliberate way for callers
    # to opt out of the prefix.
    combined = f"{prefix} {query}"
    return combined.strip()
62

63

64
def get_text_instruct_for_model_name(model_name: Optional[str]) -> str:
    """Return the document-side instruction prefix expected by ``model_name``.

    Only INSTRUCTOR checkpoints use a document instruction. Every other model,
    including ``None``, gets an empty string.
    """
    if model_name in INSTRUCTOR_MODELS:
        return DEFAULT_EMBED_INSTRUCTION
    return ""
67

68

69
def format_text(
    text: str, model_name: Optional[str], instruction: Optional[str] = None
) -> str:
    """Prefix a document ``text`` with the instruction its model expects.

    Args:
        text: Raw document text.
        model_name: Model id used to look up a default document instruction
            when ``instruction`` is ``None``.
        instruction: Explicit prefix that overrides the per-model default.

    Returns:
        ``"<instruction> <text>"`` with leading/trailing whitespace stripped.
    """
    prefix = (
        get_text_instruct_for_model_name(model_name)
        if instruction is None
        else instruction
    )
    # Stripping the joined string means instruction="" yields just the text
    # (minus its own outer whitespace). This is a deliberate way for callers
    # to opt out of the prefix.
    combined = f"{prefix} {text}"
    return combined.strip()
77

78

79
def get_pooling_mode(
    model_name: Optional[str], *, timeout: Optional[float] = 10.0
) -> str:
    """Infer the pooling mode of a Hugging Face embedding model.

    Fetches ``1_Pooling/config.json`` from the model's repository on
    huggingface.co and inspects its ``pooling_mode_mean_tokens`` flag.

    Args:
        model_name: Hub repository id, e.g. ``"BAAI/bge-small-en"``.
        timeout: Seconds to wait for the Hub, passed straight to
            ``requests.get``. ``None`` waits indefinitely.

    Returns:
        ``"mean"`` if the config enables mean-token pooling, otherwise
        ``"cls"``. ``"cls"`` is also returned, with a logged warning, when
        ``model_name`` is empty/``None`` or the config cannot be fetched or
        parsed.
    """
    logger = logging.getLogger(__name__)
    fallback_msg = (
        "Could not load pooling config for model %r; "
        "pooling mode is defaulted to 'cls'."
    )

    if not model_name:
        # Nothing to look up; avoid requesting ".../None/raw/..." URLs.
        logger.warning(fallback_msg, model_name)
        return "cls"

    pooling_config_url = (
        f"https://huggingface.co/{model_name}/raw/main/1_Pooling/config.json"
    )

    try:
        response = requests.get(pooling_config_url, timeout=timeout)
        # Route 4xx/5xx (e.g. a missing file) to the fallback instead of
        # trying to parse an error page.
        response.raise_for_status()
        config_data = response.json()
    except (requests.exceptions.RequestException, ValueError):
        # ValueError covers JSON decode errors on requests versions whose
        # JSONDecodeError is not a RequestException subclass.
        logger.warning(fallback_msg, model_name)
        return "cls"

    if not isinstance(config_data, dict):
        # Valid JSON but not an object: there are no flags to read.
        logger.warning(fallback_msg, model_name)
        return "cls"

    if config_data.get("pooling_mode_mean_tokens", False):
        return "mean"
    # CLS pooling, or any pooling mode this helper does not recognise.
    return "cls"
100

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.