llama-index

Форк
0
327 строк · 11.7 Кб
1
"""Dataset generation from documents."""
2

3
from __future__ import annotations
4

5
import asyncio
6
import json
7
import re
8
import uuid
9
from typing import Coroutine, Dict, List, Tuple
10

11
from deprecated import deprecated
12

13
from llama_index.legacy import Document, ServiceContext, SummaryIndex
14
from llama_index.legacy.bridge.pydantic import BaseModel, Field
15
from llama_index.legacy.ingestion import run_transformations
16
from llama_index.legacy.postprocessor.node import KeywordNodePostprocessor
17
from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate
18
from llama_index.legacy.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT
19
from llama_index.legacy.prompts.mixin import (
20
    PromptDictType,
21
    PromptMixin,
22
    PromptMixinType,
23
)
24
from llama_index.legacy.schema import BaseNode, MetadataMode, NodeWithScore
25

26
DEFAULT_QUESTION_GENERATION_PROMPT = """\
27
Context information is below.
28
---------------------
29
{context_str}
30
---------------------
31
Given the context information and not prior knowledge.
32
generate only questions based on the below query.
33
{query_str}
34
"""
35

36

37
@deprecated(
    "Deprecated in favor of `LabelledRagDataset` which should be used instead.",
    action="always",
)
class QueryResponseDataset(BaseModel):
    """Query Response Dataset.

    The response can be empty if the dataset is generated from documents.

    Args:
        queries (Dict[str, str]): Query id -> query.
        responses (Dict[str, str]): Query id -> response.

    """

    queries: Dict[str, str] = Field(
        default_factory=dict, description="Query id -> query"
    )
    responses: Dict[str, str] = Field(
        default_factory=dict, description="Query id -> response"
    )

    @classmethod
    def from_qr_pairs(
        cls,
        qr_pairs: List[Tuple[str, str]],
    ) -> QueryResponseDataset:
        """Create a dataset from (query, response) pairs.

        Ids are assigned as stringified list indices, so the order of
        ``qr_pairs`` is preserved in the resulting mappings.
        """
        # Single pass over the pairs instead of two comprehensions.
        queries: Dict[str, str] = {}
        responses: Dict[str, str] = {}
        for idx, (query, response) in enumerate(qr_pairs):
            queries[str(idx)] = query
            responses[str(idx)] = response
        return cls(queries=queries, responses=responses)

    @property
    def qr_pairs(self) -> List[Tuple[str, str]]:
        """Get (query, response) pairs.

        Raises:
            ValueError: If a query id has no matching response.
        """
        # Validate the full mapping up front, then build the pairs.
        missing = [qid for qid in self.queries if qid not in self.responses]
        if missing:
            raise ValueError(f"Query id {missing[0]} not in responses")

        return [
            (query, self.responses[query_id])
            for query_id, query in self.queries.items()
        ]

    @property
    def questions(self) -> List[str]:
        """Get all questions, in insertion order."""
        return list(self.queries.values())

    def save_json(self, path: str) -> None:
        """Serialize the dataset to ``path`` as indented JSON."""
        # Explicit encoding avoids platform-dependent defaults on write.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.dict(), f, indent=4)

    @classmethod
    def from_json(cls, path: str) -> QueryResponseDataset:
        """Load a dataset previously written by :meth:`save_json`."""
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        return cls(**data)
99

100

101
@deprecated(
    "Deprecated in favor of `RagDatasetGenerator` which should be used instead.",
    action="always",
)
class DatasetGenerator(PromptMixin):
    """Generate dataset (question/ question-answer pairs) \
    based on the given documents.

    NOTE: this is a beta feature, subject to change!

    Args:
        nodes (List[Node]): List of nodes. (Optional)
        service_context (ServiceContext): Service Context.
        num_questions_per_chunk: number of question to be \
        generated per chunk. Each document is chunked of size 512 words.
        text_question_template: Question generation template.
        text_qa_template: Template used to answer the generated questions
            when responses are requested.
        question_gen_query: Question generation query.
        metadata_mode: Controls how node metadata is rendered into the
            context text passed to the LLM.
        show_progress: Whether to show an async progress bar (tqdm).

    """

    def __init__(
        self,
        nodes: List[BaseNode],
        service_context: ServiceContext | None = None,
        num_questions_per_chunk: int = 10,
        text_question_template: BasePromptTemplate | None = None,
        text_qa_template: BasePromptTemplate | None = None,
        question_gen_query: str | None = None,
        metadata_mode: MetadataMode = MetadataMode.NONE,
        show_progress: bool = False,
    ) -> None:
        """Init params."""
        if service_context is None:
            # Default context sized to keep whole chunks in one prompt.
            service_context = ServiceContext.from_defaults(chunk_size_limit=3000)
        self.service_context = service_context
        self.text_question_template = text_question_template or PromptTemplate(
            DEFAULT_QUESTION_GENERATION_PROMPT
        )
        self.text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT
        # NOTE: the line continuations (including their leading spaces) are
        # part of the default prompt string — kept verbatim for
        # backward compatibility.
        self.question_gen_query = (
            question_gen_query
            or f"You are a Teacher/Professor. Your task is to setup \
                        {num_questions_per_chunk} questions for an upcoming \
                        quiz/examination. The questions should be diverse in nature \
                            across the document. Restrict the questions to the \
                                context information provided."
        )
        self.nodes = nodes
        self._metadata_mode = metadata_mode
        self._show_progress = show_progress

    @classmethod
    def from_documents(
        cls,
        documents: List[Document],
        service_context: ServiceContext | None = None,
        num_questions_per_chunk: int = 10,
        text_question_template: BasePromptTemplate | None = None,
        text_qa_template: BasePromptTemplate | None = None,
        question_gen_query: str | None = None,
        required_keywords: List[str] | None = None,
        exclude_keywords: List[str] | None = None,
        show_progress: bool = False,
    ) -> DatasetGenerator:
        """Generate dataset from documents.

        Documents are split into nodes via the service context's
        transformations, then filtered by keyword before generation.
        """
        if service_context is None:
            service_context = ServiceContext.from_defaults(chunk_size_limit=3000)

        nodes = run_transformations(
            documents, service_context.transformations, show_progress=show_progress
        )

        # Use a node postprocessor to filter nodes by keyword.
        node_postprocessor = KeywordNodePostprocessor(
            service_context=service_context,
            required_keywords=required_keywords or [],
            exclude_keywords=exclude_keywords or [],
        )
        node_with_scores = [NodeWithScore(node=node) for node in nodes]
        node_with_scores = node_postprocessor.postprocess_nodes(node_with_scores)
        nodes = [node_with_score.node for node_with_score in node_with_scores]

        return cls(
            nodes=nodes,
            service_context=service_context,
            num_questions_per_chunk=num_questions_per_chunk,
            text_question_template=text_question_template,
            text_qa_template=text_qa_template,
            question_gen_query=question_gen_query,
            show_progress=show_progress,
        )

    @staticmethod
    def _parse_questions(raw_response: str) -> List[str]:
        """Split a raw LLM response into cleaned, non-empty question strings.

        Strips leading enumeration markers such as ``1)``, ``2.`` or ``3 ``.
        """
        lines = raw_response.strip().split("\n")
        cleaned = [re.sub(r"^\d+[\).\s]", "", line).strip() for line in lines]
        return [question for question in cleaned if len(question) > 0]

    async def _agenerate_dataset(
        self,
        nodes: List[BaseNode],
        num: int | None = None,
        generate_response: bool = False,
    ) -> QueryResponseDataset:
        """Node question generator.

        Args:
            nodes: Nodes to generate questions from (one index per node).
            num: Optional cap on the number of generated queries.
            generate_response: If True, also answer each generated question
                against its source node's index.
        """
        query_tasks: List[Coroutine] = []
        queries: Dict[str, str] = {}
        responses_dict: Dict[str, str] = {}

        if self._show_progress:
            from tqdm.asyncio import tqdm_asyncio

            # tqdm_asyncio.gather mirrors asyncio.gather with a progress bar.
            async_module = tqdm_asyncio
        else:
            async_module = asyncio

        # One summary index per node: questions are generated against that
        # node's content alone, so later answers stay grounded in the node.
        summary_indices: List[SummaryIndex] = []
        for node in nodes:
            if num is not None and len(query_tasks) >= num:
                break
            index = SummaryIndex.from_documents(
                [
                    Document(
                        text=node.get_content(metadata_mode=self._metadata_mode),
                        metadata=node.metadata,
                    )
                ],
                service_context=self.service_context,
            )

            # The question template is passed as the engine's QA template so
            # that the "answer" produced is a list of questions.
            query_engine = index.as_query_engine(
                service_context=self.service_context,
                text_qa_template=self.text_question_template,
                use_async=True,
            )
            query_tasks.append(query_engine.aquery(self.question_gen_query))
            summary_indices.append(index)

        responses = await async_module.gather(*query_tasks)
        for idx, response in enumerate(responses):
            cleaned_questions = self._parse_questions(str(response))
            cur_queries = {
                str(uuid.uuid4()): question for question in cleaned_questions
            }
            queries.update(cur_queries)

            if generate_response:
                # Answer each question against the index it came from.
                index = summary_indices[idx]
                qr_tasks = []
                cur_query_items = list(cur_queries.items())
                for _, query in cur_query_items:
                    qa_query_engine = index.as_query_engine(
                        service_context=self.service_context,
                        text_qa_template=self.text_qa_template,
                    )
                    qr_tasks.append(qa_query_engine.aquery(query))
                qr_responses = await async_module.gather(*qr_tasks)
                for (query_id, _), qa_response in zip(cur_query_items, qr_responses):
                    responses_dict[query_id] = str(qa_response)

        query_ids = list(queries.keys())
        if num is not None:
            query_ids = query_ids[:num]
            # Truncate queries, responses to the subset of query ids.
            queries = {query_id: queries[query_id] for query_id in query_ids}
            if generate_response:
                responses_dict = {
                    query_id: responses_dict[query_id] for query_id in query_ids
                }

        return QueryResponseDataset(queries=queries, responses=responses_dict)

    async def agenerate_questions_from_nodes(self, num: int | None = None) -> List[str]:
        """Generates questions for each node (async)."""
        dataset = await self._agenerate_dataset(
            self.nodes, num=num, generate_response=False
        )
        return dataset.questions

    async def agenerate_dataset_from_nodes(
        self, num: int | None = None
    ) -> QueryResponseDataset:
        """Generates question/answer pairs for each node (async)."""
        return await self._agenerate_dataset(
            self.nodes, num=num, generate_response=True
        )

    def generate_questions_from_nodes(self, num: int | None = None) -> List[str]:
        """Synchronous wrapper around :meth:`agenerate_questions_from_nodes`."""
        return asyncio.run(self.agenerate_questions_from_nodes(num=num))

    def generate_dataset_from_nodes(
        self, num: int | None = None
    ) -> QueryResponseDataset:
        """Synchronous wrapper around :meth:`agenerate_dataset_from_nodes`."""
        return asyncio.run(self.agenerate_dataset_from_nodes(num=num))

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "text_question_template": self.text_question_template,
            "text_qa_template": self.text_qa_template,
        }

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt modules (none for this generator)."""
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "text_question_template" in prompts:
            self.text_question_template = prompts["text_question_template"]
        if "text_qa_template" in prompts:
            self.text_qa_template = prompts["text_qa_template"]
328

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.