llama-index

"""Common utils for embeddings."""

import json
import re
import uuid
from typing import Dict, List, Tuple

from tqdm import tqdm

from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.llms.utils import LLM
from llama_index.legacy.schema import MetadataMode, TextNode


class EmbeddingQAFinetuneDataset(BaseModel):
    """Embedding QA Finetuning Dataset.

    Args:
        queries (Dict[str, str]): Dict id -> query.
        corpus (Dict[str, str]): Dict id -> string.
        relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.

    """

    queries: Dict[str, str]  # dict id -> query
    corpus: Dict[str, str]  # dict id -> string
    relevant_docs: Dict[str, List[str]]  # query id -> list of doc ids
    mode: str = "text"

    @property
    def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
        """Get query, relevant doc ids."""
        return [
            (query, self.relevant_docs[query_id])
            for query_id, query in self.queries.items()
        ]

    def save_json(self, path: str) -> None:
        """Save json."""
        with open(path, "w") as f:
            json.dump(self.dict(), f, indent=4)

    @classmethod
    def from_json(cls, path: str) -> "EmbeddingQAFinetuneDataset":
        """Load json."""
        with open(path) as f:
            data = json.load(f)
        return cls(**data)

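
# Usage sketch: build a tiny dataset by hand, read the query/doc-id pairs,
# and round-trip it through JSON. The ids, strings, and file path below are
# illustrative placeholders.
def _example_dataset_roundtrip() -> None:
    dataset = EmbeddingQAFinetuneDataset(
        queries={"q1": "What is covered in the first chunk?"},
        corpus={"doc1": "Text of the first chunk."},
        relevant_docs={"q1": ["doc1"]},
    )
    # Each pair holds a query string and the ids of its relevant documents.
    assert dataset.query_docid_pairs == [
        ("What is covered in the first chunk?", ["doc1"])
    ]
    dataset.save_json("qa_dataset.json")
    reloaded = EmbeddingQAFinetuneDataset.from_json("qa_dataset.json")
    assert reloaded.queries == dataset.queries
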

DEFAULT_QA_GENERATE_PROMPT_TMPL = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and not prior knowledge,
generate only questions based on the query below.

You are a Teacher/Professor. Your task is to set up \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided.
"""

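
# Sketch of how the template is filled for one chunk (the same call that
# generate_qa_embedding_pairs makes); the context string and question count
# here are illustrative values.
def _example_prompt() -> str:
    return DEFAULT_QA_GENERATE_PROMPT_TMPL.format(
        context_str="Paris is the capital of France.",
        num_questions_per_chunk=2,
    )
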

# generate queries as a convenience function
def generate_qa_embedding_pairs(
    nodes: List[TextNode],
    llm: LLM,
    qa_generate_prompt_tmpl: str = DEFAULT_QA_GENERATE_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> EmbeddingQAFinetuneDataset:
    """Generate examples given a set of nodes."""
    node_dict = {
        node.node_id: node.get_content(metadata_mode=MetadataMode.NONE)
        for node in nodes
    }

    queries = {}
    relevant_docs = {}
    for node_id, text in tqdm(node_dict.items()):
        query = qa_generate_prompt_tmpl.format(
            context_str=text, num_questions_per_chunk=num_questions_per_chunk
        )
        response = llm.complete(query)

        # One question per line; strip leading enumeration such as "1." or "2)".
        result = str(response).strip().split("\n")
        questions = [
            re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
        ]
        questions = [question for question in questions if len(question) > 0]

        # Map each generated question back to the node it came from.
        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [node_id]

    # construct dataset
    return EmbeddingQAFinetuneDataset(
        queries=queries, corpus=node_dict, relevant_docs=relevant_docs
    )
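

# Usage sketch: generate a QA dataset from a couple of hand-made nodes. The
# OpenAI import path and model name are assumptions and a configured
# OPENAI_API_KEY is required; any LLM exposing `.complete()` works the same way.
def _example_generate_dataset() -> EmbeddingQAFinetuneDataset:
    from llama_index.legacy.llms import OpenAI  # assumed import path

    nodes = [
        TextNode(text="Tokyo is the capital of Japan."),
        TextNode(text="The Nile is a river in northeastern Africa."),
    ]
    llm = OpenAI(model="gpt-3.5-turbo")  # illustrative model choice
    return generate_qa_embedding_pairs(nodes, llm=llm, num_questions_per_chunk=2)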
