wandb

Форк
0
/
pr-title-bot.py 
203 строки · 6.2 Кб
1
import argparse
2
import os
3
import sys
4
from typing import Optional, Tuple
5

6
if sys.version_info < (3, 8):
7
    from typing_extensions import Literal
8
else:
9
    from typing import Literal
10

11
from github import Github
12
from openai import OpenAI
13
from tenacity import retry, stop_after_attempt, wait_random_exponential
14

15
# Credentials come from the environment; each is None when its variable is unset.
GITHUB_TOKEN = os.getenv("GITHUB_API_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Single OpenAI client shared by every request; created at import time.
client = OpenAI(api_key=OPENAI_API_KEY)

# Model names accepted by the `model` parameter of the title helpers.
Model = Literal["gpt-4", "gpt-3.5-turbo", "vicuna-7b-v1.1"]

# Conventional-commit types as one comma-separated string, interpolated into
# the prompts. Setting the CC_TYPES environment variable replaces the default.
CC_TYPES = os.getenv(
    "CC_TYPES",
    ", ".join(
        (
            "feat",
            "fix",
            "docs",
            "style",
            "refactor",
            "perf",
            "test",
            "build",
            "ci",
            "chore",
            "revert",
            "security",
        )
    ),
)

# Conventional-commit scopes, same format as CC_TYPES. Setting the CC_SCOPES
# environment variable replaces the default.
CC_SCOPES = os.getenv(
    "CC_SCOPES",
    ", ".join(
        (
            "sdk",
            "cli",
            "public-api",
            "integrations",
            "artifacts",
            "media",
            "sweeps",
            "launch",
        )
    ),
)
57

58

59
@retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
def chat_completion_with_backoff(**kwargs):
    """Send a chat completion request through the shared OpenAI client.

    All keyword arguments are forwarded unchanged to
    ``client.chat.completions.create``. The tenacity ``retry`` decorator
    re-invokes the call on failure, with a randomized exponential wait
    (``min=1``, ``max=60``), and stops after 6 attempts.

    Returns:
        Whatever ``client.chat.completions.create`` returns.
    """
    completion = client.chat.completions.create(**kwargs)
    return completion
63

64

65
@retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
def get_pr_info(
    pr_number: int,
    repo_name: str = "wandb/wandb",
    get_diff: bool = True,
) -> Tuple[str, Optional[str]]:
    """Fetch a pull request's title and, optionally, its combined diff.

    Args:
        pr_number: Number of the pull request to look up.
        repo_name: Repository in ``owner/name`` form.
        get_diff: When False, the changed files are not fetched and the
            second element of the result is None.

    Returns:
        A ``(title, diff)`` tuple. ``diff`` is the newline-joined patch of
        every changed file, or None when ``get_diff`` is False.
    """
    pull = Github(GITHUB_TOKEN).get_repo(repo_name).get_pull(pr_number)

    if get_diff:
        patches = []
        for changed in pull.get_files():
            # autogenerated protobuf modules are left out of the diff
            if "_pb2" in changed.filename:
                continue
            # a file with an empty patch is represented by its filename
            patches.append(changed.patch or changed.filename)
        return pull.title, "\n".join(patches)

    return pull.title, None
94

95

96
def check_pr_title(
    pr_number: int,
    model: Model = "gpt-4",
    repo_name: str = "wandb/wandb",
) -> bool:
    """Check whether a PR title follows the conventional commit format.

    The title is fetched without the diff and sent to the chat model, which
    is instructed to answer 'yes' or 'no'. The normalized answer is printed
    before the result is returned.

    Args:
        pr_number: Number of the pull request whose title is checked.
        model: Chat model used for the check.
        repo_name: Repository in ``owner/name`` form.

    Returns:
        True only when the model's answer normalizes to ``yes``.
    """
    pr_title, _ = get_pr_info(pr_number, repo_name=repo_name, get_diff=False)

    # Each fragment ends with a space so that adjacent sentences do not run
    # together once the literals are concatenated.
    system_prompt = (
        "Your task is to check whether the title for a GitHub pull request "
        "follows the conventional commit format: "
        f"<type>(<scope>): <description>. The possible types are: {CC_TYPES}. "
        f"The possible scopes are: {CC_SCOPES}. "
        "The description must start with a verb in the imperative mood and be lower case. "
        "The user will provide the current PR title. You must respond with 'yes' or 'no'."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{pr_title}"},
    ]

    response = chat_completion_with_backoff(
        model=model,
        messages=messages,
    )

    # The v1 OpenAI client returns typed objects, so the reply is read through
    # attributes; dict-style subscripting raises TypeError. `content` may be
    # None, hence the fallback to an empty string.
    answer = response.choices[0].message.content or ""
    # The prompt itself quotes 'yes'/'no', so tolerate surrounding quotes and
    # trailing punctuation in the reply.
    is_compliant = answer.strip().strip("'\"`.!").lower()
    print(is_compliant)

    return is_compliant == "yes"
131

132

133
def generate_pr_title(
    pr_number: int,
    model: Model = "gpt-4",
    repo_name: str = "wandb/wandb",
) -> str:
    """Generate a conventional-commit title for a PR using the given model.

    The current title and diff of the pull request are fetched and sent to
    the chat model together with the allowed types and scopes.

    Args:
        pr_number: Number of the pull request to generate a title for.
        model: Chat model used for the suggestion.
        repo_name: Repository in ``owner/name`` form.

    Returns:
        The model's suggested title with surrounding whitespace removed, or
        an empty string when the model returns no content.
    """
    pr_title, diff = get_pr_info(pr_number, repo_name=repo_name)

    # Each fragment ends with a space so that adjacent sentences do not run
    # together once the literals are concatenated.
    # NOTE(review): the prompt mentions "labels", but only the title and the
    # diff are sent below -- confirm whether labels were meant to be included.
    system_prompt = (
        "Your task is to write a title for a GitHub pull request "
        "that will follow the conventional commit format and capture the essence of the change: "
        f"<type>(<scope>): <description>. The possible types are: {CC_TYPES}. "
        f"The possible scopes are: {CC_SCOPES}. "
        "The description must start with a verb in the imperative mood and be lower case. "
        "For context, the user will provide the current title and the diff of the pull request. "
        "and their corresponding labels. "
        "You must respond in the format: <type>(<scope>): <description>. "
        "Be concise and specific. If you are unsure, keep the original title. "
        "Even if you think the correct type or scope is missing, you must only use the provided options. "
        "Be concise."
    )

    # todo: check context limit, strip diff if too long

    # The user message is built in one step, so a literal "{{DIFF}}" inside
    # the title cannot be mistaken for a template placeholder.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Title: {pr_title}. Diff: {diff}."},
    ]

    completion = chat_completion_with_backoff(
        model=model,
        messages=messages,
    )

    # The v1 OpenAI client returns typed objects, so the reply is read through
    # attributes; dict-style subscripting raises TypeError. `content` may be
    # None, hence the fallback to an empty string.
    suggested_title = completion.choices[0].message.content or ""

    return suggested_title.strip()
176

177

178
if __name__ == "__main__":
    # Command-line entry point with two sub-commands:
    #   generate <pr_number>  -- print a suggested title for the PR
    #   check <pr_number>     -- raise ValueError when the title is not compliant
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")

    # Both sub-commands take the same single positional argument.
    for command, summary in (
        ("generate", "Generate PR title"),
        ("check", "Check PR title"),
    ):
        subparser = subparsers.add_parser(command, help=summary)
        subparser.add_argument("pr_number", type=int, help="Pull Request number")

    args = parser.parse_args()

    if args.command == "generate":
        print(generate_pr_title(args.pr_number))
    elif args.command == "check":
        if not check_pr_title(args.pr_number):
            raise ValueError(
                "PR title is not compliant with the conventional commit recommendations. \n"
                "Comment on your PR with `/suggest-title` to get a suggestion or "
                "`/fix-title` to ask the pr-title-bot to fix it for you."
            )
    else:
        # No sub-command was given. parser.error prints the usage text to
        # stderr and exits with status 2, so the mistake is not reported as
        # success the way a plain print followed by exit status 0 would be.
        parser.error("Invalid command. Use 'generate' or 'check'")
204

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.