griptape

Форк
0
/
structure_tester.py 
320 строк · 13.6 Кб
1
from __future__ import annotations
2
import os
3
from attr import field, define
4
from schema import Schema, Literal
5
import logging
6
import json
7
from griptape.artifacts.error_artifact import ErrorArtifact
8

9
from griptape.structures import Agent
10
from griptape.rules import Rule, Ruleset
11
from griptape.tasks import PromptTask
12
from griptape.structures import Structure
13
from griptape.drivers import (
14
    BasePromptDriver,
15
    AmazonBedrockPromptDriver,
16
    AnthropicPromptDriver,
17
    BedrockClaudePromptModelDriver,
18
    BedrockJurassicPromptModelDriver,
19
    BedrockTitanPromptModelDriver,
20
    BedrockLlamaPromptModelDriver,
21
    CoherePromptDriver,
22
    OpenAiChatPromptDriver,
23
    OpenAiCompletionPromptDriver,
24
    AzureOpenAiChatPromptDriver,
25
    AmazonSageMakerPromptDriver,
26
    SageMakerLlamaPromptModelDriver,
27
    SageMakerFalconPromptModelDriver,
28
    GooglePromptDriver,
29
)
30

31

32
def get_enabled_prompt_drivers(prompt_drivers_options) -> list[BasePromptDriver]:
    """Return the prompt drivers from every option whose ``enabled`` flag is set.

    Args:
        prompt_drivers_options: iterable of objects exposing ``prompt_driver``
            and ``enabled`` attributes (``StructureTester.TesterPromptDriverOption``).

    Returns:
        The ``prompt_driver`` of each enabled option, in input order.
    """
    enabled_drivers: list[BasePromptDriver] = []
    for option in prompt_drivers_options:
        if option.enabled:
            enabled_drivers.append(option.prompt_driver)
    return enabled_drivers
38

39

40
@define
class StructureTester:
    """Test harness that runs a griptape ``Structure`` against a catalog of
    prompt drivers and grades the structure's output with an LLM judge.

    NOTE(review): ``PROMPT_DRIVERS`` below reads ``os.environ[...]`` at
    class-definition (import) time, so importing this module raises
    ``KeyError`` unless every referenced environment variable is set.
    """

    @define
    class TesterPromptDriverOption:
        """Pairs a prompt driver with a flag that controls whether it is used."""

        # The driver instance under test.
        prompt_driver: BasePromptDriver = field()
        # When False, get_enabled_prompt_drivers() filters this entry out.
        enabled: bool = field()

    # Registry of every driver configuration the tester knows about, keyed by a
    # human-readable id. Individual entries are switched on/off via `enabled`.
    PROMPT_DRIVERS = {
        "OPENAI_CHAT_35": TesterPromptDriverOption(
            prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"]),
            enabled=True,
        ),
        "OPENAI_CHAT_35_TURBO_1106": TesterPromptDriverOption(
            prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo-1106", api_key=os.environ["OPENAI_API_KEY"]),
            enabled=True,
        ),
        "OPENAI_CHAT_35_TURBO_INSTRUCT": TesterPromptDriverOption(
            prompt_driver=OpenAiCompletionPromptDriver(
                model="gpt-3.5-turbo-instruct", api_key=os.environ["OPENAI_API_KEY"]
            ),
            enabled=True,
        ),
        "OPENAI_CHAT_4": TesterPromptDriverOption(
            prompt_driver=OpenAiChatPromptDriver(model="gpt-4", api_key=os.environ["OPENAI_API_KEY"]), enabled=True
        ),
        "OPENAI_CHAT_4o": TesterPromptDriverOption(
            prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", api_key=os.environ["OPENAI_API_KEY"]), enabled=True
        ),
        "OPENAI_CHAT_4_1106_PREVIEW": TesterPromptDriverOption(
            prompt_driver=OpenAiChatPromptDriver(model="gpt-4-1106-preview", api_key=os.environ["OPENAI_API_KEY"]),
            enabled=True,
        ),
        "OPENAI_COMPLETION_DAVINCI": TesterPromptDriverOption(
            prompt_driver=OpenAiCompletionPromptDriver(api_key=os.environ["OPENAI_API_KEY"], model="text-davinci-003"),
            enabled=True,
        ),
        "AZURE_CHAT_35_TURBO": TesterPromptDriverOption(
            prompt_driver=AzureOpenAiChatPromptDriver(
                api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
                model="gpt-35-turbo",
                azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"],
                azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
            ),
            enabled=True,
        ),
        "AZURE_CHAT_35_TURBO_16K": TesterPromptDriverOption(
            prompt_driver=AzureOpenAiChatPromptDriver(
                api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
                model="gpt-35-turbo-16k",
                azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_ID"],
                azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
            ),
            enabled=True,
        ),
        "AZURE_CHAT_4": TesterPromptDriverOption(
            prompt_driver=AzureOpenAiChatPromptDriver(
                api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
                model="gpt-4",
                azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"],
                azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
            ),
            enabled=True,
        ),
        "AZURE_CHAT_4_32K": TesterPromptDriverOption(
            prompt_driver=AzureOpenAiChatPromptDriver(
                api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
                model="gpt-4-32k",
                azure_deployment=os.environ["AZURE_OPENAI_4_32K_DEPLOYMENT_ID"],
                azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
            ),
            enabled=True,
        ),
        "ANTHROPIC_CLAUDE_2_INSTANT": TesterPromptDriverOption(
            prompt_driver=AnthropicPromptDriver(model="claude-instant-1.2", api_key=os.environ["ANTHROPIC_API_KEY"]),
            enabled=True,
        ),
        "ANTHROPIC_CLAUDE_2": TesterPromptDriverOption(
            prompt_driver=AnthropicPromptDriver(model="claude-2.0", api_key=os.environ["ANTHROPIC_API_KEY"]),
            enabled=True,
        ),
        "ANTHROPIC_CLAUDE_2.1": TesterPromptDriverOption(
            prompt_driver=AnthropicPromptDriver(model="claude-2.1", api_key=os.environ["ANTHROPIC_API_KEY"]),
            enabled=True,
        ),
        "ANTHROPIC_CLAUDE_3_OPUS": TesterPromptDriverOption(
            prompt_driver=AnthropicPromptDriver(
                model="claude-3-opus-20240229", api_key=os.environ["ANTHROPIC_API_KEY"]
            ),
            enabled=True,
        ),
        "ANTHROPIC_CLAUDE_3_SONNET": TesterPromptDriverOption(
            prompt_driver=AnthropicPromptDriver(
                model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"]
            ),
            enabled=True,
        ),
        "ANTHROPIC_CLAUDE_3_HAIKU": TesterPromptDriverOption(
            prompt_driver=AnthropicPromptDriver(
                model="claude-3-haiku-20240307", api_key=os.environ["ANTHROPIC_API_KEY"]
            ),
            enabled=True,
        ),
        "COHERE_COMMAND": TesterPromptDriverOption(
            prompt_driver=CoherePromptDriver(model="command", api_key=os.environ["COHERE_API_KEY"]), enabled=True
        ),
        "BEDROCK_TITAN": TesterPromptDriverOption(
            prompt_driver=AmazonBedrockPromptDriver(
                model="amazon.titan-tg1-large", prompt_model_driver=BedrockTitanPromptModelDriver()
            ),
            enabled=True,
        ),
        "BEDROCK_CLAUDE_INSTANT": TesterPromptDriverOption(
            prompt_driver=AmazonBedrockPromptDriver(
                model="anthropic.claude-instant-v1", prompt_model_driver=BedrockClaudePromptModelDriver()
            ),
            enabled=True,
        ),
        "BEDROCK_CLAUDE_2": TesterPromptDriverOption(
            prompt_driver=AmazonBedrockPromptDriver(
                model="anthropic.claude-v2", prompt_model_driver=BedrockClaudePromptModelDriver()
            ),
            enabled=True,
        ),
        "BEDROCK_CLAUDE_2.1": TesterPromptDriverOption(
            prompt_driver=AmazonBedrockPromptDriver(
                model="anthropic.claude-v2:1", prompt_model_driver=BedrockClaudePromptModelDriver()
            ),
            enabled=True,
        ),
        "BEDROCK_CLAUDE_3_SONNET": TesterPromptDriverOption(
            prompt_driver=AmazonBedrockPromptDriver(
                model="anthropic.claude-3-sonnet-20240229-v1:0", prompt_model_driver=BedrockClaudePromptModelDriver()
            ),
            enabled=True,
        ),
        "BEDROCK_CLAUDE_3_HAIKU": TesterPromptDriverOption(
            prompt_driver=AmazonBedrockPromptDriver(
                model="anthropic.claude-3-haiku-20240307-v1:0", prompt_model_driver=BedrockClaudePromptModelDriver()
            ),
            enabled=True,
        ),
        "BEDROCK_J2": TesterPromptDriverOption(
            prompt_driver=AmazonBedrockPromptDriver(
                model="ai21.j2-ultra", prompt_model_driver=BedrockJurassicPromptModelDriver()
            ),
            enabled=True,
        ),
        # max_attempts=1: Llama entries do not retry on failure.
        "BEDROCK_LLAMA2_13B": TesterPromptDriverOption(
            prompt_driver=AmazonBedrockPromptDriver(
                model="meta.llama2-13b-chat-v1", prompt_model_driver=BedrockLlamaPromptModelDriver(), max_attempts=1
            ),
            enabled=True,
        ),
        "BEDROCK_LLAMA2_70B": TesterPromptDriverOption(
            prompt_driver=AmazonBedrockPromptDriver(
                model="meta.llama2-70b-chat-v1", prompt_model_driver=BedrockLlamaPromptModelDriver(), max_attempts=1
            ),
            enabled=True,
        ),
        # SageMaker entries are disabled by default (endpoint names come from env vars).
        "SAGEMAKER_LLAMA_7B": TesterPromptDriverOption(
            prompt_driver=AmazonSageMakerPromptDriver(
                model=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"],
                prompt_model_driver=SageMakerLlamaPromptModelDriver(max_tokens=4096),
            ),
            enabled=False,
        ),
        "SAGEMAKER_FALCON_7b": TesterPromptDriverOption(
            prompt_driver=AmazonSageMakerPromptDriver(
                model=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"],
                prompt_model_driver=SageMakerFalconPromptModelDriver(),
            ),
            enabled=False,
        ),
        "GOOGLE_GEMINI_PRO": TesterPromptDriverOption(
            prompt_driver=GooglePromptDriver(model="gemini-pro", api_key=os.environ["GOOGLE_API_KEY"]), enabled=True
        ),
    }
    # Toolkit tasks only use this curated subset of drivers; all other task
    # kinds below accept every enabled driver in PROMPT_DRIVERS.
    TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(
        [
            PROMPT_DRIVERS["OPENAI_CHAT_4"],
            PROMPT_DRIVERS["OPENAI_CHAT_4_1106_PREVIEW"],
            PROMPT_DRIVERS["AZURE_CHAT_4"],
            PROMPT_DRIVERS["AZURE_CHAT_4_32K"],
            PROMPT_DRIVERS["ANTHROPIC_CLAUDE_3_OPUS"],
            PROMPT_DRIVERS["GOOGLE_GEMINI_PRO"],
        ]
    )
    TOOL_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
    PROMPT_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
    TEXT_SUMMARY_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
    TEXT_QUERY_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
    JSON_EXTRACTION_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
    CSV_EXTRACTION_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
    RULE_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())

    # The structure under test; supplied by the caller per instance.
    structure: Structure = field()

    @classmethod
    def prompt_driver_id_fn(cls, prompt_driver) -> str:
        """Return a readable id combining the driver's class name and model."""
        return f"{prompt_driver.__class__.__name__}-{prompt_driver.model}"

    def verify_structure_output(self, structure) -> dict:
        """Grade a finished structure's output with an LLM judge.

        Builds a judge Agent (Azure OpenAI, JSON response format) that is shown
        the structure's task names, rules, input prompt, and output text, and
        must answer with a JSON object matching ``output_schema``.

        Args:
            structure: a structure that has already been run (its
                ``output_task.output`` is read).

        Returns:
            dict with keys ``correct`` (bool) and ``explanation`` (str).

        Raises:
            json.JSONDecodeError: if the judge's reply is not valid JSON.
        """
        # Schema the judge must follow; rendered into the Formatting ruleset below.
        output_schema = Schema(
            {
                Literal("correct", description="Whether the output was correct or not."): bool,
                Literal(
                    "explanation", description="A brief explanation of why you felt the output was correct or not."
                ): str,
            }
        )
        # Collect everything the judge needs from the structure under test.
        task_names = [task.__class__.__name__ for task in structure.tasks]
        prompt = structure.input_task.input.to_text()
        actual = structure.output_task.output.to_text()
        rules = [rule.value for ruleset in structure.input_task.all_rulesets for rule in ruleset.rules]

        agent = Agent(
            rulesets=[
                Ruleset(
                    name="Formatting",
                    rules=[
                        Rule(
                            f"Output a json object matching this schema: {output_schema.json_schema('Output Schema')}."
                        )
                    ],
                ),
                Ruleset(
                    name="Context",
                    rules=[
                        Rule(
                            "Your objective is to determine whether an LLM generated an acceptable output for a given tasks, prompt, and rules."
                        ),
                        Rule("The output does not need to be perfect, but it should be acceptable"),
                        Rule("Do not make any assumptions about how the output should be formatted."),
                        Rule(
                            "Do not worry about the accuracy of the output, only that it is an appropriate response to the prompt."
                        ),
                    ],
                ),
            ],
            # The judge model is fixed: Azure OpenAI gpt-4o, forced into JSON mode
            # so its reply can be json.loads()-ed below.
            prompt_driver=AzureOpenAiChatPromptDriver(
                api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
                model="gpt-4o",
                azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"],
                azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
                response_format="json_object",
            ),
            tasks=[
                # Jinja template: adjacent string literals concatenate into one prompt.
                PromptTask(
                    "\nTasks: {{ task_names }}"
                    '\n{% if rules %}Rules: """{{ rules }}"""{% endif %}'
                    '\nPrompt: """{{ prompt }}"""'
                    '\nOutput: """{{ output }}"""',
                    context={
                        "prompt": prompt,
                        "output": actual,
                        "task_names": ", ".join(task_names),
                        "rules": ", ".join(rules),
                    },
                )
            ],
            logger_level=logging.DEBUG,
        )
        agent.logger.debug("Determining correctness of output.")
        # Run the judge and parse its JSON verdict.
        result = json.loads(agent.run().output_task.output.to_text())
        explanation = result["explanation"]

        agent.logger.debug(explanation)

        return result

    def run(self, prompt, assert_correctness: bool = True) -> dict:
        """Run the wrapped structure on ``prompt`` and verify its output.

        Args:
            prompt: input passed to ``self.structure.run``.
            assert_correctness: when True (default), assert that the verdict's
                ``correct`` flag is set.

        Returns:
            The verdict dict (``correct``/``explanation``).
        """
        result = self.structure.run(prompt)
        # An ErrorArtifact means the run itself failed — skip the LLM judge.
        if isinstance(result.output_task.output, ErrorArtifact):
            verified_result = {"correct": False, "explanation": f"ErrorArtifact: {result.output_task.output.to_text()}"}
        else:
            verified_result = self.verify_structure_output(self.structure)

        # NOTE(review): plain assert is stripped under `python -O`; acceptable
        # here presumably because this runs under pytest — confirm.
        if assert_correctness:
            assert verified_result["correct"]

        return verified_result
321

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.