llama-factory

Форк
0
/
test_toolcall.py 
57 строк · 2.2 Кб
1
import json
2
from typing import Sequence
3

4
from openai import OpenAI
5
from transformers.utils.versions import require_version
6

7

8
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
9

10

11
def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
12
    grade_to_score = {"A": 4, "B": 3, "C": 2}
13
    total_score, total_hour = 0, 0
14
    for grade, hour in zip(grades, hours):
15
        total_score += grade_to_score[grade] * hour
16
        total_hour += hour
17
    return total_score / total_hour
18

19

20
tool_map = {"calculate_gpa": calculate_gpa}
21

22

23
if __name__ == "__main__":
24
    client = OpenAI(
25
        api_key="0",
26
        base_url="http://localhost:8000/v1",
27
    )
28
    tools = [
29
        {
30
            "type": "function",
31
            "function": {
32
                "name": "calculate_gpa",
33
                "description": "Calculate the Grade Point Average (GPA) based on grades and credit hours",
34
                "parameters": {
35
                    "type": "object",
36
                    "properties": {
37
                        "grades": {"type": "array", "items": {"type": "string"}, "description": "The grades"},
38
                        "hours": {"type": "array", "items": {"type": "integer"}, "description": "The credit hours"},
39
                    },
40
                    "required": ["grades", "hours"],
41
                },
42
            },
43
        }
44
    ]
45
    messages = []
46
    messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
47
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
48
    tool_call = result.choices[0].message.tool_calls[0].function
49
    name, arguments = tool_call.name, json.loads(tool_call.arguments)
50
    messages.append(
51
        {"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)}
52
    )
53
    tool_result = tool_map[name](**arguments)
54
    messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
55
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
56
    print(result.choices[0].message.content)
57
    # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.
58

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.