llama-index

Форк
0
72 строки · 2.4 Кб
1
from typing import Any, Optional, Sequence
2

3
from llama_index.legacy.prompts import BasePromptTemplate
4
from llama_index.legacy.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT
5
from llama_index.legacy.prompts.mixin import PromptDictType
6
from llama_index.legacy.response_synthesizers.base import BaseSynthesizer
7
from llama_index.legacy.service_context import ServiceContext
8
from llama_index.legacy.types import RESPONSE_TEXT_TYPE
9

10

11
class Generation(BaseSynthesizer):
12
    def __init__(
13
        self,
14
        simple_template: Optional[BasePromptTemplate] = None,
15
        service_context: Optional[ServiceContext] = None,
16
        streaming: bool = False,
17
    ) -> None:
18
        super().__init__(service_context=service_context, streaming=streaming)
19
        self._input_prompt = simple_template or DEFAULT_SIMPLE_INPUT_PROMPT
20

21
    def _get_prompts(self) -> PromptDictType:
22
        """Get prompts."""
23
        return {"simple_template": self._input_prompt}
24

25
    def _update_prompts(self, prompts: PromptDictType) -> None:
26
        """Update prompts."""
27
        if "simple_template" in prompts:
28
            self._input_prompt = prompts["simple_template"]
29

30
    async def aget_response(
31
        self,
32
        query_str: str,
33
        text_chunks: Sequence[str],
34
        **response_kwargs: Any,
35
    ) -> RESPONSE_TEXT_TYPE:
36
        # NOTE: ignore text chunks and previous response
37
        del text_chunks
38

39
        if not self._streaming:
40
            return await self._service_context.llm.apredict(
41
                self._input_prompt,
42
                query_str=query_str,
43
                **response_kwargs,
44
            )
45
        else:
46
            return self._service_context.llm.stream(
47
                self._input_prompt,
48
                query_str=query_str,
49
                **response_kwargs,
50
            )
51

52
    def get_response(
53
        self,
54
        query_str: str,
55
        text_chunks: Sequence[str],
56
        **response_kwargs: Any,
57
    ) -> RESPONSE_TEXT_TYPE:
58
        # NOTE: ignore text chunks and previous response
59
        del text_chunks
60

61
        if not self._streaming:
62
            return self._service_context.llm.predict(
63
                self._input_prompt,
64
                query_str=query_str,
65
                **response_kwargs,
66
            )
67
        else:
68
            return self._service_context.llm.stream(
69
                self._input_prompt,
70
                query_str=query_str,
71
                **response_kwargs,
72
            )
73

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.