1
from __future__ import annotations
3
from attr import field, define
4
from schema import Schema, Literal
7
from griptape.artifacts.error_artifact import ErrorArtifact
9
from griptape.structures import Agent
10
from griptape.rules import Rule, Ruleset
11
from griptape.tasks import PromptTask
12
from griptape.structures import Structure
13
from griptape.drivers import (
15
AmazonBedrockPromptDriver,
16
AnthropicPromptDriver,
17
BedrockClaudePromptModelDriver,
18
BedrockJurassicPromptModelDriver,
19
BedrockTitanPromptModelDriver,
20
BedrockLlamaPromptModelDriver,
22
OpenAiChatPromptDriver,
23
OpenAiCompletionPromptDriver,
24
AzureOpenAiChatPromptDriver,
25
AmazonSageMakerPromptDriver,
26
SageMakerLlamaPromptModelDriver,
27
SageMakerFalconPromptModelDriver,
32
def get_enabled_prompt_drivers(prompt_drivers_options) -> list[BasePromptDriver]:
    """Filter driver options down to the drivers that are switched on.

    Args:
        prompt_drivers_options: iterable of objects exposing ``prompt_driver``
            and ``enabled`` attributes (see ``TesterPromptDriverOption``).

    Returns:
        The ``prompt_driver`` of every option whose ``enabled`` flag is truthy.
    """
    # NOTE(review): the `return [` / closing `]` lines were missing from the
    # garbled source; restored as the obvious list comprehension.
    return [
        prompt_driver_option.prompt_driver
        for prompt_driver_option in prompt_drivers_options
        if prompt_driver_option.enabled
    ]
43
@define
class TesterPromptDriverOption:
    """Couples a prompt driver with a flag saying whether tests should exercise it.

    NOTE(review): the ``@define`` decorator line was missing from the garbled
    source; it is required for the attrs ``field()`` declarations below and for
    the keyword constructor used throughout this file
    (``TesterPromptDriverOption(prompt_driver=..., enabled=...)``).
    """

    # The concrete driver instance handed to structures/tasks under test.
    prompt_driver: BasePromptDriver = field()
    # When False, get_enabled_prompt_drivers() drops this option from every capability list.
    enabled: bool = field()
48
"OPENAI_CHAT_35": TesterPromptDriverOption(
49
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"]),
52
"OPENAI_CHAT_35_TURBO_1106": TesterPromptDriverOption(
53
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo-1106", api_key=os.environ["OPENAI_API_KEY"]),
56
"OPENAI_CHAT_35_TURBO_INSTRUCT": TesterPromptDriverOption(
57
prompt_driver=OpenAiCompletionPromptDriver(
58
model="gpt-3.5-turbo-instruct", api_key=os.environ["OPENAI_API_KEY"]
62
"OPENAI_CHAT_4": TesterPromptDriverOption(
63
prompt_driver=OpenAiChatPromptDriver(model="gpt-4", api_key=os.environ["OPENAI_API_KEY"]), enabled=True
65
"OPENAI_CHAT_4o": TesterPromptDriverOption(
66
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", api_key=os.environ["OPENAI_API_KEY"]), enabled=True
68
"OPENAI_CHAT_4_1106_PREVIEW": TesterPromptDriverOption(
69
prompt_driver=OpenAiChatPromptDriver(model="gpt-4-1106-preview", api_key=os.environ["OPENAI_API_KEY"]),
72
"OPENAI_COMPLETION_DAVINCI": TesterPromptDriverOption(
73
prompt_driver=OpenAiCompletionPromptDriver(api_key=os.environ["OPENAI_API_KEY"], model="text-davinci-003"),
76
"AZURE_CHAT_35_TURBO": TesterPromptDriverOption(
77
prompt_driver=AzureOpenAiChatPromptDriver(
78
api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
80
azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"],
81
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
85
"AZURE_CHAT_35_TURBO_16K": TesterPromptDriverOption(
86
prompt_driver=AzureOpenAiChatPromptDriver(
87
api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
88
model="gpt-35-turbo-16k",
89
azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_ID"],
90
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
94
"AZURE_CHAT_4": TesterPromptDriverOption(
95
prompt_driver=AzureOpenAiChatPromptDriver(
96
api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
98
azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"],
99
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
103
"AZURE_CHAT_4_32K": TesterPromptDriverOption(
104
prompt_driver=AzureOpenAiChatPromptDriver(
105
api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
107
azure_deployment=os.environ["AZURE_OPENAI_4_32K_DEPLOYMENT_ID"],
108
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
112
"ANTHROPIC_CLAUDE_2_INSTANT": TesterPromptDriverOption(
113
prompt_driver=AnthropicPromptDriver(model="claude-instant-1.2", api_key=os.environ["ANTHROPIC_API_KEY"]),
116
"ANTHROPIC_CLAUDE_2": TesterPromptDriverOption(
117
prompt_driver=AnthropicPromptDriver(model="claude-2.0", api_key=os.environ["ANTHROPIC_API_KEY"]),
120
"ANTHROPIC_CLAUDE_2.1": TesterPromptDriverOption(
121
prompt_driver=AnthropicPromptDriver(model="claude-2.1", api_key=os.environ["ANTHROPIC_API_KEY"]),
124
"ANTHROPIC_CLAUDE_3_OPUS": TesterPromptDriverOption(
125
prompt_driver=AnthropicPromptDriver(
126
model="claude-3-opus-20240229", api_key=os.environ["ANTHROPIC_API_KEY"]
130
"ANTHROPIC_CLAUDE_3_SONNET": TesterPromptDriverOption(
131
prompt_driver=AnthropicPromptDriver(
132
model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"]
136
"ANTHROPIC_CLAUDE_3_HAIKU": TesterPromptDriverOption(
137
prompt_driver=AnthropicPromptDriver(
138
model="claude-3-haiku-20240307", api_key=os.environ["ANTHROPIC_API_KEY"]
142
"COHERE_COMMAND": TesterPromptDriverOption(
143
prompt_driver=CoherePromptDriver(model="command", api_key=os.environ["COHERE_API_KEY"]), enabled=True
145
"BEDROCK_TITAN": TesterPromptDriverOption(
146
prompt_driver=AmazonBedrockPromptDriver(
147
model="amazon.titan-tg1-large", prompt_model_driver=BedrockTitanPromptModelDriver()
151
"BEDROCK_CLAUDE_INSTANT": TesterPromptDriverOption(
152
prompt_driver=AmazonBedrockPromptDriver(
153
model="anthropic.claude-instant-v1", prompt_model_driver=BedrockClaudePromptModelDriver()
157
"BEDROCK_CLAUDE_2": TesterPromptDriverOption(
158
prompt_driver=AmazonBedrockPromptDriver(
159
model="anthropic.claude-v2", prompt_model_driver=BedrockClaudePromptModelDriver()
163
"BEDROCK_CLAUDE_2.1": TesterPromptDriverOption(
164
prompt_driver=AmazonBedrockPromptDriver(
165
model="anthropic.claude-v2:1", prompt_model_driver=BedrockClaudePromptModelDriver()
169
"BEDROCK_CLAUDE_3_SONNET": TesterPromptDriverOption(
170
prompt_driver=AmazonBedrockPromptDriver(
171
model="anthropic.claude-3-sonnet-20240229-v1:0", prompt_model_driver=BedrockClaudePromptModelDriver()
175
"BEDROCK_CLAUDE_3_HAIKU": TesterPromptDriverOption(
176
prompt_driver=AmazonBedrockPromptDriver(
177
model="anthropic.claude-3-haiku-20240307-v1:0", prompt_model_driver=BedrockClaudePromptModelDriver()
181
"BEDROCK_J2": TesterPromptDriverOption(
182
prompt_driver=AmazonBedrockPromptDriver(
183
model="ai21.j2-ultra", prompt_model_driver=BedrockJurassicPromptModelDriver()
187
"BEDROCK_LLAMA2_13B": TesterPromptDriverOption(
188
prompt_driver=AmazonBedrockPromptDriver(
189
model="meta.llama2-13b-chat-v1", prompt_model_driver=BedrockLlamaPromptModelDriver(), max_attempts=1
193
"BEDROCK_LLAMA2_70B": TesterPromptDriverOption(
194
prompt_driver=AmazonBedrockPromptDriver(
195
model="meta.llama2-70b-chat-v1", prompt_model_driver=BedrockLlamaPromptModelDriver(), max_attempts=1
199
"SAGEMAKER_LLAMA_7B": TesterPromptDriverOption(
200
prompt_driver=AmazonSageMakerPromptDriver(
201
model=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"],
202
prompt_model_driver=SageMakerLlamaPromptModelDriver(max_tokens=4096),
206
"SAGEMAKER_FALCON_7b": TesterPromptDriverOption(
207
prompt_driver=AmazonSageMakerPromptDriver(
208
model=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"],
209
prompt_model_driver=SageMakerFalconPromptModelDriver(),
213
"GOOGLE_GEMINI_PRO": TesterPromptDriverOption(
214
prompt_driver=GooglePromptDriver(model="gemini-pro", api_key=os.environ["GOOGLE_API_KEY"]), enabled=True
217
# Capability lists: each task-type test module parametrizes over one of these.
# ToolkitTask needs function-calling-grade models, so it gets a curated subset;
# every other capability currently runs against all enabled drivers.
TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(
    [
        PROMPT_DRIVERS[key]
        for key in (
            "OPENAI_CHAT_4",
            "OPENAI_CHAT_4_1106_PREVIEW",
            "AZURE_CHAT_4",
            "AZURE_CHAT_4_32K",
            "ANTHROPIC_CLAUDE_3_OPUS",
            "GOOGLE_GEMINI_PRO",
        )
    ]
)
TOOL_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
PROMPT_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
TEXT_SUMMARY_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
TEXT_QUERY_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
JSON_EXTRACTION_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
CSV_EXTRACTION_TASK_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
RULE_CAPABLE_PROMPT_DRIVERS = get_enabled_prompt_drivers(PROMPT_DRIVERS.values())
235
# The Structure under test: run() executes it and verify_structure_output() grades it.
# (attrs field; the enclosing class header is not visible in this chunk.)
structure: Structure = field()
238
def prompt_driver_id_fn(cls, prompt_driver) -> str:
    """Return a readable id for a driver: ``<DriverClassName>-<model>``.

    NOTE(review): naming suggests this serves as a pytest parametrize ``ids``
    callable — confirm against the test modules.
    """
    driver_name = type(prompt_driver).__name__
    return f"{driver_name}-{prompt_driver.model}"
241
def verify_structure_output(self, structure) -> dict:
    """Ask a judge LLM whether *structure*'s output is acceptable; returns the judge's JSON verdict.

    NOTE(review): this span is garbled — interleaved extraction line numbers and
    missing source lines (the Schema closure, Agent construction, rules list and
    prompt-template assembly are incomplete here). Code tokens are left
    untouched; only comments were added.
    """
242
# Judge must answer JSON matching {"correct": bool, "explanation": str}.
output_schema = Schema(
244
Literal("correct", description="Whether the output was correct or not."): bool,
246
"explanation", description="A brief explanation of why you felt the output was correct or not."
250
# Collect what the structure actually did: task types, original prompt,
# produced output, and any rules that constrained it.
task_names = [task.__class__.__name__ for task in structure.tasks]
251
prompt = structure.input_task.input.to_text()
252
actual = structure.output_task.output.to_text()
253
rules = [rule.value for ruleset in structure.input_task.all_rulesets for rule in ruleset.rules]
261
f"Output a json object matching this schema: {output_schema.json_schema('Output Schema')}."
269
"Your objective is to determine whether an LLM generated an acceptable output for a given tasks, prompt, and rules."
271
Rule("The output does not need to be perfect, but it should be acceptable"),
272
Rule("Do not make any assumptions about how the output should be formatted."),
274
"Do not worry about the accuracy of the output, only that it is an appropriate response to the prompt."
279
# The judge agent runs on Azure GPT-4 with JSON-mode responses enforced.
prompt_driver=AzureOpenAiChatPromptDriver(
280
api_key=os.environ["AZURE_OPENAI_API_KEY_1"],
282
azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"],
283
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
284
response_format="json_object",
288
# Jinja-style template the judge sees: tasks, optional rules, prompt, output.
"\nTasks: {{ task_names }}"
289
'\n{% if rules %}Rules: """{{ rules }}"""{% endif %}'
290
'\nPrompt: """{{ prompt }}"""'
291
'\nOutput: """{{ output }}"""',
295
"task_names": ", ".join(task_names),
296
"rules": ", ".join(rules),
300
logger_level=logging.DEBUG,
302
agent.logger.debug("Determining correctness of output.")
303
# The judge's reply is parsed as JSON and its explanation logged for debugging.
result = json.loads(agent.run().output_task.output.to_text())
304
explanation = result["explanation"]
306
agent.logger.debug(explanation)
310
def run(self, prompt, assert_correctness: bool = True) -> dict:
    """Run the wrapped structure on *prompt* and grade its output.

    Args:
        prompt: input text forwarded to ``Structure.run``.
        assert_correctness: when True (default), fail the calling test with a
            bare ``assert`` if the verdict marks the output incorrect.

    Returns:
        dict with at least ``correct`` (bool) and ``explanation`` (str).
    """
    result = self.structure.run(prompt)

    if isinstance(result.output_task.output, ErrorArtifact):
        # A driver/task failure short-circuits judging: an error is never "correct".
        verified_result = {
            "correct": False,
            "explanation": f"ErrorArtifact: {result.output_task.output.to_text()}",
        }
    else:
        # NOTE(review): this `else:` line was missing from the garbled source;
        # without it the ErrorArtifact verdict above would be unconditionally
        # overwritten. Restored so only successful runs are graded by the judge.
        verified_result = self.verify_structure_output(self.structure)

    if assert_correctness:
        # Intentional bare assert: this helper is designed to run under pytest.
        assert verified_result["correct"]

    return verified_result