haystack

Форк
0
/
test_rag_pipelines_e2e.py 
158 строк · 6.5 Кб
1
import os
2
import json
3
import pytest
4

5
from haystack import Pipeline, Document
6
from haystack.document_stores.in_memory import InMemoryDocumentStore
7
from haystack.components.writers import DocumentWriter
8
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
9
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
10
from haystack.components.generators import OpenAIGenerator
11
from haystack.components.builders.answer_builder import AnswerBuilder
12
from haystack.components.builders.prompt_builder import PromptBuilder
13

14

15
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY"),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_bm25_rag_pipeline(tmp_path):
    """End-to-end test of a BM25-based RAG pipeline.

    Builds a retriever -> prompt builder -> OpenAI generator -> answer builder
    pipeline, draws it, round-trips it through YAML (dump/load), indexes three
    small documents and asks one question per document. For each question it
    checks that exactly one answer is produced, that the answer mentions the
    expected name, and that the document grounding that name was attached.

    Args:
        tmp_path: pytest fixture providing a per-test temporary directory.
    """
    # Create the RAG pipeline
    prompt_template = """
    Given these documents, answer the question.\nDocuments:
    {% for doc in documents %}
        {{ doc.content }}
    {% endfor %}

    \nQuestion: {{question}}
    \nAnswer:
    """
    rag_pipeline = Pipeline()
    rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
    rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
    rag_pipeline.add_component(instance=OpenAIGenerator(), name="llm")
    rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
    rag_pipeline.connect("retriever", "prompt_builder.documents")
    rag_pipeline.connect("prompt_builder", "llm")
    rag_pipeline.connect("llm.replies", "answer_builder.replies")
    rag_pipeline.connect("llm.meta", "answer_builder.meta")
    # The retrieved documents are also fed to the answer builder so that each
    # answer carries the documents it was generated from.
    rag_pipeline.connect("retriever", "answer_builder.documents")

    # Draw the pipeline
    rag_pipeline.draw(tmp_path / "test_bm25_rag_pipeline.png")

    # Serialize the pipeline to YAML and load it back, so the rest of the test
    # exercises the deserialized pipeline rather than the original object.
    yaml_path = tmp_path / "test_bm25_rag_pipeline.yaml"
    with open(yaml_path, "w", encoding="utf-8") as f:
        rag_pipeline.dump(f)
    with open(yaml_path, "r", encoding="utf-8") as f:
        rag_pipeline = Pipeline.load(f)

    # Populate the document store of the *loaded* pipeline's retriever
    documents = [
        Document(content="My name is Jean and I live in Paris."),
        Document(content="My name is Mark and I live in Berlin."),
        Document(content="My name is Giorgio and I live in Rome."),
    ]
    rag_pipeline.get_component("retriever").document_store.write_documents(documents)

    # Query and assert
    questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"]
    answers_spywords = ["Jean", "Mark", "Giorgio"]

    for question, spyword in zip(questions, answers_spywords):
        result = rag_pipeline.run(
            {
                "retriever": {"query": question},
                "prompt_builder": {"question": question},
                "answer_builder": {"query": question},
            }
        )

        assert len(result["answer_builder"]["answers"]) == 1
        generated_answer = result["answer_builder"]["answers"][0]
        assert spyword in generated_answer.data
        assert generated_answer.query == question
        # `hasattr` is always true on the answer dataclass; check the content
        # instead: the document naming the expected person must be attached.
        assert any(spyword in (doc.content or "") for doc in generated_answer.documents)
        assert isinstance(generated_answer.meta, dict)
79

80

81
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY"),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_embedding_retrieval_rag_pipeline(tmp_path):
    """End-to-end test of an embedding-retrieval RAG pipeline.

    Builds a text embedder -> embedding retriever -> prompt builder -> OpenAI
    generator -> answer builder pipeline, draws it, round-trips it through JSON
    (to_dict/from_dict), indexes three small documents with a separate
    embedding + writer pipeline into the loaded retriever's document store, and
    asks one question per document. For each question it checks that exactly
    one answer is produced, that the answer mentions the expected name, and
    that the document grounding that name was attached.

    Args:
        tmp_path: pytest fixture providing a per-test temporary directory.
    """
    # Create the RAG pipeline
    prompt_template = """
    Given these documents, answer the question.\nDocuments:
    {% for doc in documents %}
        {{ doc.content }}
    {% endfor %}

    \nQuestion: {{question}}
    \nAnswer:
    """
    rag_pipeline = Pipeline()
    rag_pipeline.add_component(
        instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
    )
    rag_pipeline.add_component(
        instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
    )
    rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
    rag_pipeline.add_component(instance=OpenAIGenerator(), name="llm")
    rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
    rag_pipeline.connect("text_embedder", "retriever")
    rag_pipeline.connect("retriever", "prompt_builder.documents")
    rag_pipeline.connect("prompt_builder", "llm")
    rag_pipeline.connect("llm.replies", "answer_builder.replies")
    rag_pipeline.connect("llm.meta", "answer_builder.meta")
    # The retrieved documents are also fed to the answer builder so that each
    # answer carries the documents it was generated from.
    rag_pipeline.connect("retriever", "answer_builder.documents")

    # Draw the pipeline
    rag_pipeline.draw(tmp_path / "test_embedding_rag_pipeline.png")

    # Serialize the pipeline to JSON and load it back, so the rest of the test
    # exercises the deserialized pipeline rather than the original object.
    json_path = tmp_path / "test_embedding_rag_pipeline.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(rag_pipeline.to_dict(), f)
    with open(json_path, "r", encoding="utf-8") as f:
        rag_pipeline = Pipeline.from_dict(json.load(f))

    # Populate the document store of the *loaded* pipeline's retriever; the
    # documents must be embedded with the same model as the queries.
    documents = [
        Document(content="My name is Jean and I live in Paris."),
        Document(content="My name is Mark and I live in Berlin."),
        Document(content="My name is Giorgio and I live in Rome."),
    ]
    document_store = rag_pipeline.get_component("retriever").document_store
    indexing_pipeline = Pipeline()
    indexing_pipeline.add_component(
        instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
        name="document_embedder",
    )
    indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer")
    indexing_pipeline.connect("document_embedder", "document_writer")
    indexing_pipeline.run({"document_embedder": {"documents": documents}})

    # Query and assert
    questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"]
    answers_spywords = ["Jean", "Mark", "Giorgio"]

    for question, spyword in zip(questions, answers_spywords):
        result = rag_pipeline.run(
            {
                "text_embedder": {"text": question},
                "prompt_builder": {"question": question},
                "answer_builder": {"query": question},
            }
        )

        assert len(result["answer_builder"]["answers"]) == 1
        generated_answer = result["answer_builder"]["answers"][0]
        assert spyword in generated_answer.data
        assert generated_answer.query == question
        # `hasattr` is always true on the answer dataclass; check the content
        # instead: the document naming the expected person must be attached.
        assert any(spyword in (doc.content or "") for doc in generated_answer.documents)
        assert isinstance(generated_answer.meta, dict)
159

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.