llama-index

Форк
0
373 строки · 14.2 Кб
1
import asyncio
import os
import shutil
import subprocess
from argparse import ArgumentParser
from glob import iglob
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union, cast

from llama_index.legacy import (
    Response,
    ServiceContext,
    SimpleDirectoryReader,
    VectorStoreIndex,
)
from llama_index.legacy.bridge.pydantic import BaseModel, Field, validator
from llama_index.legacy.chat_engine import CondenseQuestionChatEngine
from llama_index.legacy.core.response.schema import RESPONSE_TYPE, StreamingResponse
from llama_index.legacy.embeddings.base import BaseEmbedding
from llama_index.legacy.ingestion import IngestionPipeline
from llama_index.legacy.llms import LLM, OpenAI
from llama_index.legacy.query_engine import CustomQueryEngine
from llama_index.legacy.query_pipeline import FnComponent
from llama_index.legacy.query_pipeline.query import QueryPipeline
from llama_index.legacy.readers.base import BaseReader
from llama_index.legacy.response_synthesizers import CompactAndRefine
from llama_index.legacy.utils import get_cache_dir
27

28
# Name of the history file kept inside the CLI's persist directory: every
# ``--files`` argument is appended to it (one per line) after ingestion, and
# ``--create-llama`` reads it back to find the ingested path.
RAG_HISTORY_FILE_NAME: str = "files_history.txt"
29

30

31
def default_ragcli_persist_dir() -> str:
    """Return the default persist directory: ``<llama-index cache dir>/rag_cli``."""
    return str(Path(get_cache_dir()) / "rag_cli")
33

34

35
def query_input(query_str: Optional[str] = None) -> str:
    """Return ``query_str`` unchanged, or ``""`` when it is ``None`` or empty.

    Used as the entry ``FnComponent`` of the default query pipeline, whose
    output is linked to both the retriever and the summarizer.
    """
    return query_str or ""
37

38

39
class QueryPipelineQueryEngine(CustomQueryEngine):
    """Query engine that forwards every query to a ``QueryPipeline``.

    Lets a ``QueryPipeline`` be used wherever a query engine is expected,
    e.g. as the ``query_engine`` of a chat engine.
    """

    query_pipeline: QueryPipeline = Field(
        description="Query Pipeline to use for Q&A.",
    )

    def custom_query(self, query_str: str) -> RESPONSE_TYPE:
        """Run the pipeline synchronously with ``query_str`` as its input."""
        return self.query_pipeline.run(query_str=query_str)

    async def acustom_query(self, query_str: str) -> RESPONSE_TYPE:
        """Run the pipeline asynchronously with ``query_str`` as its input."""
        return await self.query_pipeline.arun(query_str=query_str)
49

50

51
class RagCLI(BaseModel):
    """
    CLI tool for chatting with the output of an IngestionPipeline via a QueryPipeline.
    """

    ingestion_pipeline: IngestionPipeline = Field(
        description="Ingestion pipeline to run for RAG ingestion."
    )
    verbose: bool = Field(
        description="Whether to print out verbose information during execution.",
        default=False,
    )
    persist_dir: str = Field(
        description="Directory to persist ingestion pipeline.",
        default_factory=default_ragcli_persist_dir,
    )
    llm: LLM = Field(
        description="Language model to use for response generation.",
        default_factory=lambda: OpenAI(model="gpt-3.5-turbo", streaming=True),
    )
    query_pipeline: Optional[QueryPipeline] = Field(
        description="Query Pipeline to use for Q&A.",
        default=None,
    )
    # Explicit ``default=None`` (like ``query_pipeline``); when omitted, the
    # ``chat_engine`` validator builds one from the query pipeline.
    chat_engine: Optional[CondenseQuestionChatEngine] = Field(
        description="Chat engine to use for chatting.",
        default=None,
    )
    file_extractor: Optional[Dict[str, BaseReader]] = Field(
        description="File extractor to use for extracting text from files.",
        default=None,
    )

    class Config:
        arbitrary_types_allowed = True

    @validator("query_pipeline", always=True)
    def query_pipeline_from_ingestion_pipeline(
        cls, query_pipeline: Any, values: Dict[str, Any]
    ) -> Optional[QueryPipeline]:
        """
        If query_pipeline is not provided, create one from ingestion_pipeline.

        The default pipeline is ``query -> retriever -> summarizer`` over the
        ingestion pipeline's vector store. Returns ``None`` when the ingestion
        pipeline has no vector store.
        """
        if query_pipeline is not None:
            return query_pipeline

        ingestion_pipeline = cast(IngestionPipeline, values["ingestion_pipeline"])
        if ingestion_pipeline.vector_store is None:
            return None
        verbose = cast(bool, values["verbose"])
        query_component = FnComponent(
            fn=query_input, output_key="output", req_params={"query_str"}
        )
        llm = cast(LLM, values["llm"])

        # Get embed_model from transformations if possible (presumably the model
        # that embedded the stored nodes); otherwise fall back to "default".
        embed_model = None
        if ingestion_pipeline.transformations is not None:
            for transformation in ingestion_pipeline.transformations:
                if isinstance(transformation, BaseEmbedding):
                    embed_model = transformation
                    break

        service_context = ServiceContext.from_defaults(
            llm=llm, embed_model=embed_model or "default"
        )
        retriever = VectorStoreIndex.from_vector_store(
            ingestion_pipeline.vector_store, service_context=service_context
        ).as_retriever(similarity_top_k=8)
        response_synthesizer = CompactAndRefine(
            service_context=service_context, streaming=True, verbose=verbose
        )

        # Define the query pipeline: the query string feeds the retriever, and
        # the summarizer receives both the retrieved nodes and the query string.
        query_pipeline = QueryPipeline(verbose=verbose)
        query_pipeline.add_modules(
            {
                "query": query_component,
                "retriever": retriever,
                "summarizer": response_synthesizer,
            }
        )
        query_pipeline.add_link("query", "retriever")
        query_pipeline.add_link("retriever", "summarizer", dest_key="nodes")
        query_pipeline.add_link("query", "summarizer", dest_key="query_str")
        return query_pipeline

    @validator("chat_engine", always=True)
    def chat_engine_from_query_pipeline(
        cls, chat_engine: Any, values: Dict[str, Any]
    ) -> Optional[CondenseQuestionChatEngine]:
        """
        If chat_engine is not provided, create one from query_pipeline.

        Returns ``None`` when no query pipeline is available.
        """
        if chat_engine is not None:
            return chat_engine

        # If query_pipeline is absent or None in ``values``, try to build it.
        if values.get("query_pipeline", None) is None:
            values["query_pipeline"] = cls.query_pipeline_from_ingestion_pipeline(
                query_pipeline=None, values=values
            )

        query_pipeline = cast(QueryPipeline, values["query_pipeline"])
        if query_pipeline is None:
            return None
        query_engine = QueryPipelineQueryEngine(query_pipeline=query_pipeline)  # type: ignore
        verbose = cast(bool, values["verbose"])
        llm = cast(LLM, values["llm"])
        return CondenseQuestionChatEngine.from_defaults(
            query_engine=query_engine, llm=llm, verbose=verbose
        )

    async def handle_cli(
        self,
        files: Optional[str] = None,
        question: Optional[str] = None,
        chat: bool = False,
        verbose: bool = False,
        clear: bool = False,
        create_llama: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Entrypoint for local document RAG CLI tool.

        Steps run in this order, each only when requested: clear the persist
        directory, ingest ``files``, run create-llama, answer ``question``,
        open the chat REPL. Declining the clear confirmation aborts everything.

        Args:
            files: Glob pattern (``**`` supported) of files/directories to ingest.
            question: A single question to answer.
            chat: Open an interactive chat REPL.
            verbose: Print verbose information during execution.
            clear: Delete the persist directory after confirmation.
            create_llama: Scaffold a create-llama app from the ingested path.
            **kwargs: Other argparse attributes (e.g. ``command``, ``func``);
                ignored.
        """
        if clear:
            # Delete self.persist_dir including all subdirectories and files.
            if os.path.exists(self.persist_dir):
                # Ask for confirmation
                response = input(
                    f"Are you sure you want to delete data within {self.persist_dir}? [y/N] "
                )
                if response.strip().lower() != "y":
                    print("Aborted.")
                    return
                self._remove_persist_dir()
            print(f"Successfully cleared {self.persist_dir}")

        self.verbose = verbose
        if self.verbose:
            print("Saving/Loading from persist_dir: ", self.persist_dir)

        if files is not None:
            await self._ingest_files(files, verbose)

        if create_llama:
            self._run_create_llama()

        if question is not None:
            await self.handle_question(question)
        if chat:
            await self.start_chat_repl()

    def _history_file_path(self) -> str:
        """Return the path of the ``--files`` history file inside persist_dir."""
        return os.path.join(self.persist_dir, RAG_HISTORY_FILE_NAME)

    def _remove_persist_dir(self) -> None:
        """Delete persist_dir (a directory tree, or a file/symlink) without a shell."""
        # shutil/os calls take the path as a single argument, so spaces or shell
        # metacharacters in persist_dir are never split or interpreted, and no
        # external ``rm`` binary is required.
        if os.path.isdir(self.persist_dir) and not os.path.islink(self.persist_dir):
            shutil.rmtree(self.persist_dir)
        else:
            os.remove(self.persist_dir)

    async def _ingest_files(self, files: str, verbose: bool) -> None:
        """Load every path matching the ``files`` glob, run and persist the
        ingestion pipeline, then record ``files`` in the history file."""
        ingestion_pipeline = cast(IngestionPipeline, self.ingestion_pipeline)
        documents = []
        for matched_path in iglob(files, recursive=True):
            matched_path = os.path.abspath(matched_path)
            # Directories are read as a whole; anything else as a single file.
            if os.path.isdir(matched_path):
                source: Dict[str, Any] = {"input_dir": matched_path}
            else:
                source = {"input_files": [matched_path]}
            reader = SimpleDirectoryReader(
                filename_as_id=True,
                file_extractor=self.file_extractor,
                **source,
            )
            documents.extend(reader.load_data(show_progress=verbose))

        if not documents:
            print(f"Warning: no documents were loaded from `{files}`.")

        await ingestion_pipeline.arun(show_progress=verbose, documents=documents)
        ingestion_pipeline.persist(persist_dir=self.persist_dir)

        # Append the `--files` argument to the history file (read back by
        # --create-llama); make sure the directory exists before opening it.
        os.makedirs(self.persist_dir, exist_ok=True)
        with open(self._history_file_path(), "a", encoding="utf-8") as f:
            f.write(files + "\n")

    def _run_create_llama(self) -> None:
        """Run ``npx create-llama`` on the single previously ingested path,
        printing an explanatory message instead when that is not possible."""
        npx_path = shutil.which("npx")
        if npx_path is None:
            print(
                "`npx` is not installed. Please install it by calling `npm install -g npx`"
            )
            return

        history_file_path = Path(self._history_file_path())
        if not history_file_path.exists():
            print(
                "No data has been ingested, "
                "please specify `--files` to create llama dataset."
            )
            return

        with open(history_file_path, encoding="utf-8") as f:
            stored_paths = {line.strip() for line in f if line.strip()}

        if not stored_paths:
            print(
                "No data has been ingested, "
                "please specify `--files` to create llama dataset."
            )
            return
        if len(stored_paths) > 1:
            print(
                "Multiple files or folders were ingested, which is not supported by create-llama. "
                "Please call `llamaindex-cli rag --clear` to clear the cache first, "
                "then call `llamaindex-cli rag --files` again with a single folder or file"
            )
            return

        path = stored_paths.pop()
        if "*" in path:
            print(
                "Glob pattern is not supported by create-llama. "
                "Please call `llamaindex-cli rag --clear` to clear the cache first, "
                "then call `llamaindex-cli rag --files` again with a single folder or file."
            )
        elif not os.path.exists(path):
            print(
                f"The path {path} does not exist. "
                "Please call `llamaindex-cli rag --clear` to clear the cache first, "
                "then call `llamaindex-cli rag --files` again with a single folder or file."
            )
        else:
            print(f"Calling create-llama using data from {path} ...")
            command_args = [
                npx_path,
                "create-llama@latest",
                "--frontend",
                "--template",
                "streaming",
                "--framework",
                "fastapi",
                "--ui",
                "shadcn",
                "--vector-db",
                "none",
                "--engine",
                "context",
                "--files",
                path,
            ]
            # Argument list, no shell: the path is passed as one argument even if
            # it contains spaces. The exit status is deliberately not checked.
            subprocess.run(command_args, check=False)

    async def handle_question(self, question: str) -> None:
        """Answer ``question`` through the chat engine and print the response.

        Raises:
            ValueError: If no query pipeline is available.
        """
        if self.query_pipeline is None:
            raise ValueError("query_pipeline is not defined.")
        query_pipeline = cast(QueryPipeline, self.query_pipeline)
        # Apply the verbosity chosen for this CLI invocation.
        query_pipeline.verbose = self.verbose
        chat_engine = cast(CondenseQuestionChatEngine, self.chat_engine)
        response = chat_engine.chat(question)

        if isinstance(response, StreamingResponse):
            response.print_response_stream()
        else:
            response = cast(Response, response)
            print(response)

    async def start_chat_repl(self) -> None:
        """
        Start a REPL for chatting with the agent.

        Raises:
            ValueError: If no query pipeline is available.
        """
        if self.query_pipeline is None:
            raise ValueError("query_pipeline is not defined.")
        chat_engine = cast(CondenseQuestionChatEngine, self.chat_engine)
        chat_engine.streaming_chat_repl()

    @classmethod
    def add_parser_args(
        cls,
        parser: Union[ArgumentParser, Any],
        instance_generator: Callable[[], "RagCLI"],
    ) -> None:
        """Register the RAG CLI arguments on ``parser`` and bind ``func`` so that
        ``args.func(args)`` runs ``handle_cli`` on the instance returned by
        ``instance_generator``."""
        parser.add_argument(
            "-q",
            "--question",
            type=str,
            help="The question you want to ask.",
            required=False,
        )

        parser.add_argument(
            "-f",
            "--files",
            type=str,
            help=(
                "The name of the file or directory you want to ask a question about, "
                'such as "file.pdf".'
            ),
        )
        parser.add_argument(
            "-c",
            "--chat",
            help="If flag is present, opens a chat REPL.",
            action="store_true",
        )
        parser.add_argument(
            "-v",
            "--verbose",
            help="Whether to print out verbose information during execution.",
            action="store_true",
        )
        parser.add_argument(
            "--clear",
            help="Clears out all currently embedded data.",
            action="store_true",
        )
        parser.add_argument(
            "--create-llama",
            help="Create a LlamaIndex application with your embedded data.",
            required=False,
            action="store_true",
        )
        # Every parsed attribute (including ``command`` and ``func``) is passed to
        # handle_cli as a keyword; the unknown ones land in its **kwargs.
        parser.set_defaults(
            func=lambda args: asyncio.run(instance_generator().handle_cli(**vars(args)))
        )

    def cli(self) -> None:
        """
        Entrypoint for CLI tool.

        Exposes a single required ``rag`` sub-command bound to this instance.
        """
        parser = ArgumentParser(description="LlamaIndex RAG Q&A tool.")
        subparsers = parser.add_subparsers(
            title="commands", dest="command", required=True
        )
        llamarag_parser = subparsers.add_parser(
            "rag", help="Ask a question to a document / a directory of documents."
        )
        self.add_parser_args(llamarag_parser, lambda: self)
        # Parse the command-line arguments
        args = parser.parse_args()

        # Call the appropriate function based on the command
        args.func(args)
374

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.