llama-index
327 lines · 11.7 KB
1"""Dataset generation from documents."""
2
3from __future__ import annotations4
5import asyncio6import json7import re8import uuid9from typing import Coroutine, Dict, List, Tuple10
11from deprecated import deprecated12
13from llama_index.legacy import Document, ServiceContext, SummaryIndex14from llama_index.legacy.bridge.pydantic import BaseModel, Field15from llama_index.legacy.ingestion import run_transformations16from llama_index.legacy.postprocessor.node import KeywordNodePostprocessor17from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate18from llama_index.legacy.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT19from llama_index.legacy.prompts.mixin import (20PromptDictType,21PromptMixin,22PromptMixinType,23)
24from llama_index.legacy.schema import BaseNode, MetadataMode, NodeWithScore25
26DEFAULT_QUESTION_GENERATION_PROMPT = """\27Context information is below.
28---------------------
29{context_str}
30---------------------
31Given the context information and not prior knowledge.
32generate only questions based on the below query.
33{query_str}
34"""
35
36
@deprecated(
    "Deprecated in favor of `LabelledRagDataset` which should be used instead.",
    action="always",
)
class QueryResponseDataset(BaseModel):
    """Query Response Dataset.

    The response can be empty if the dataset is generated from documents.

    Args:
        queries (Dict[str, str]): Query id -> query.
        responses (Dict[str, str]): Query id -> response.

    """

    queries: Dict[str, str] = Field(
        default_factory=dict, description="Query id -> query"
    )
    responses: Dict[str, str] = Field(
        default_factory=dict, description="Query id -> response"
    )

    @classmethod
    def from_qr_pairs(
        cls,
        qr_pairs: List[Tuple[str, str]],
    ) -> QueryResponseDataset:
        """Build a dataset from (query, response) pairs."""
        # Positional indices (stringified) serve as stable query ids.
        query_map: Dict[str, str] = {}
        response_map: Dict[str, str] = {}
        for position, (query, response) in enumerate(qr_pairs):
            key = str(position)
            query_map[key] = query
            response_map[key] = response
        return cls(queries=query_map, responses=response_map)

    @property
    def qr_pairs(self) -> List[Tuple[str, str]]:
        """Return (query, response) tuples; every query must have a response."""
        pairs: List[Tuple[str, str]] = []
        for query_id, query in self.queries.items():
            # A query id with no matching response is an error.
            if query_id not in self.responses:
                raise ValueError(f"Query id {query_id} not in responses")
            pairs.append((query, self.responses[query_id]))
        return pairs

    @property
    def questions(self) -> List[str]:
        """All query strings, in insertion order."""
        return list(self.queries.values())

    def save_json(self, path: str) -> None:
        """Serialize the dataset to ``path`` as pretty-printed JSON."""
        serialized = json.dumps(self.dict(), indent=4)
        with open(path, "w") as out_file:
            out_file.write(serialized)

    @classmethod
    def from_json(cls, path: str) -> QueryResponseDataset:
        """Load a dataset previously written by ``save_json``."""
        with open(path) as in_file:
            payload = json.load(in_file)
        return cls(**payload)
100
@deprecated(
    "Deprecated in favor of `RagDatasetGenerator` which should be used instead.",
    action="always",
)
class DatasetGenerator(PromptMixin):
    """Generate dataset (question / question-answer pairs) based on the
    given documents.

    NOTE: this is a beta feature, subject to change!

    Args:
        nodes (List[Node]): List of nodes. (Optional)
        service_context (ServiceContext): Service Context.
        num_questions_per_chunk: number of questions to be generated per
            chunk. Each document is chunked of size 512 words.
        text_question_template: Question generation template.
        text_qa_template: QA template used when answering generated questions.
        question_gen_query: Question generation query.
        metadata_mode: How node metadata is rendered into the context text.
        show_progress: If True, display tqdm progress bars for async calls.
    """

    def __init__(
        self,
        nodes: List[BaseNode],
        service_context: ServiceContext | None = None,
        num_questions_per_chunk: int = 10,
        text_question_template: BasePromptTemplate | None = None,
        text_qa_template: BasePromptTemplate | None = None,
        question_gen_query: str | None = None,
        metadata_mode: MetadataMode = MetadataMode.NONE,
        show_progress: bool = False,
    ) -> None:
        """Init params."""
        if service_context is None:
            # Fix: dropped the redundant `service_context or ...` — inside the
            # None branch the fallback is always taken.
            service_context = ServiceContext.from_defaults(chunk_size_limit=3000)
        self.service_context = service_context
        self.text_question_template = text_question_template or PromptTemplate(
            DEFAULT_QUESTION_GENERATION_PROMPT
        )
        self.text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT
        # Runtime prompt string kept byte-identical (continuation lines start
        # at column 0 so no extra spaces are embedded in the string).
        self.question_gen_query = (
            question_gen_query
            or f"You are a Teacher/Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided."
        )
        self.nodes = nodes
        self._metadata_mode = metadata_mode
        self._show_progress = show_progress

    @classmethod
    def from_documents(
        cls,
        documents: List[Document],
        service_context: ServiceContext | None = None,
        num_questions_per_chunk: int = 10,
        text_question_template: BasePromptTemplate | None = None,
        text_qa_template: BasePromptTemplate | None = None,
        question_gen_query: str | None = None,
        required_keywords: List[str] | None = None,
        exclude_keywords: List[str] | None = None,
        show_progress: bool = False,
    ) -> DatasetGenerator:
        """Generate dataset from documents."""
        if service_context is None:
            # Fix: same redundant `service_context or ...` removed as in __init__.
            service_context = ServiceContext.from_defaults(chunk_size_limit=3000)

        nodes = run_transformations(
            documents, service_context.transformations, show_progress=show_progress
        )

        # Use the keyword postprocessor to filter nodes; it operates on
        # scored nodes, so wrap / unwrap around the call.
        required_keywords = required_keywords or []
        exclude_keywords = exclude_keywords or []
        node_postprocessor = KeywordNodePostprocessor(
            service_context=service_context,
            required_keywords=required_keywords,
            exclude_keywords=exclude_keywords,
        )
        node_with_scores = [NodeWithScore(node=node) for node in nodes]
        node_with_scores = node_postprocessor.postprocess_nodes(node_with_scores)
        nodes = [node_with_score.node for node_with_score in node_with_scores]

        return cls(
            nodes=nodes,
            service_context=service_context,
            num_questions_per_chunk=num_questions_per_chunk,
            text_question_template=text_question_template,
            text_qa_template=text_qa_template,
            question_gen_query=question_gen_query,
            show_progress=show_progress,
        )

    async def _agenerate_dataset(
        self,
        nodes: List[BaseNode],
        num: int | None = None,
        generate_response: bool = False,
    ) -> QueryResponseDataset:
        """Generate questions (and optionally answers) per node.

        Args:
            nodes: Nodes to generate questions from.
            num: Cap on the total number of queries kept in the result.
            generate_response: If True, also answer each generated question.

        Returns:
            QueryResponseDataset with uuid query ids (responses empty unless
            ``generate_response`` is True).
        """
        query_tasks: List[Coroutine] = []
        queries: Dict[str, str] = {}
        responses_dict: Dict[str, str] = {}

        if self._show_progress:
            from tqdm.asyncio import tqdm_asyncio

            async_module = tqdm_asyncio
        else:
            async_module = asyncio

        summary_indices: List[SummaryIndex] = []
        for node in nodes:
            # `num` bounds the number of question-generation tasks scheduled;
            # each task may yield several questions (truncated again below).
            if num is not None and len(query_tasks) >= num:
                break
            index = SummaryIndex.from_documents(
                [
                    Document(
                        text=node.get_content(metadata_mode=self._metadata_mode),
                        metadata=node.metadata,
                    )
                ],
                service_context=self.service_context,
            )

            # NOTE: the question-generation prompt is deliberately passed in
            # the `text_qa_template` slot so the engine "answers" with a list
            # of questions about the node's content.
            query_engine = index.as_query_engine(
                service_context=self.service_context,
                text_qa_template=self.text_question_template,
                use_async=True,
            )
            task = query_engine.aquery(
                self.question_gen_query,
            )
            query_tasks.append(task)
            summary_indices.append(index)

        responses = await async_module.gather(*query_tasks)
        for idx, response in enumerate(responses):
            result = str(response).strip().split("\n")
            # Strip leading enumeration markers like "1." / "2)" and drop
            # blank lines.
            cleaned_questions = [
                re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
            ]
            cleaned_questions = [
                question for question in cleaned_questions if len(question) > 0
            ]
            cur_queries = {
                str(uuid.uuid4()): question for question in cleaned_questions
            }
            queries.update(cur_queries)

            if generate_response:
                index = summary_indices[idx]
                # Fix: build the QA engine once per node instead of once per
                # question — it does not depend on the individual query.
                qa_query_engine = index.as_query_engine(
                    service_context=self.service_context,
                    text_qa_template=self.text_qa_template,
                )
                cur_query_items = list(cur_queries.items())
                cur_query_keys = [query_id for query_id, _ in cur_query_items]
                qr_tasks = [
                    qa_query_engine.aquery(query) for _, query in cur_query_items
                ]
                qr_responses = await async_module.gather(*qr_tasks)
                for query_id, qa_response in zip(cur_query_keys, qr_responses):
                    responses_dict[query_id] = str(qa_response)
            # Fix: removed a dead `else: pass` branch here.

        query_ids = list(queries.keys())
        if num is not None:
            query_ids = query_ids[:num]
            # Truncate queries (and responses) to the subset of query ids.
            queries = {query_id: queries[query_id] for query_id in query_ids}
            if generate_response:
                responses_dict = {
                    query_id: responses_dict[query_id] for query_id in query_ids
                }

        return QueryResponseDataset(queries=queries, responses=responses_dict)

    async def agenerate_questions_from_nodes(self, num: int | None = None) -> List[str]:
        """Generate questions for each node (async); returns question strings."""
        dataset = await self._agenerate_dataset(
            self.nodes, num=num, generate_response=False
        )
        return dataset.questions

    async def agenerate_dataset_from_nodes(
        self, num: int | None = None
    ) -> QueryResponseDataset:
        """Generate a question/answer dataset for each node (async)."""
        return await self._agenerate_dataset(
            self.nodes, num=num, generate_response=True
        )

    def generate_questions_from_nodes(self, num: int | None = None) -> List[str]:
        """Generate questions for each node (sync wrapper)."""
        return asyncio.run(self.agenerate_questions_from_nodes(num=num))

    def generate_dataset_from_nodes(
        self, num: int | None = None
    ) -> QueryResponseDataset:
        """Generate a question/answer dataset for each node (sync wrapper)."""
        return asyncio.run(self.agenerate_dataset_from_nodes(num=num))

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "text_question_template": self.text_question_template,
            "text_qa_template": self.text_qa_template,
        }

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt modules."""
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "text_question_template" in prompts:
            self.text_question_template = prompts["text_question_template"]
        if "text_qa_template" in prompts:
            self.text_qa_template = prompts["text_qa_template"]