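# Scraped from a web code viewer: pytorch / cherry_pick.py (312 lines, 9.1 KB).
# Bare integers on their own lines are the viewer's line numbers, not code.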
#!/usr/bin/env python3
2

3
import json
4
import os
5
import re
6
from typing import Any, cast, Dict, List, Optional
7
from urllib.error import HTTPError
8

9
from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels
10
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
11
from trymerge import get_pr_commit_sha, GitHubPR
12

13

14
# Classifications for which linking a GitHub issue is recommended. This is
# only a suggestion for now, not a strict requirement: cherry_pick() merely
# mentions the missing issue in the comment it posts on the original PR.
REQUIRES_ISSUE = {
    "regression",
    "critical",
    "fixnewfeature",
}
# Matches release branch names such as "release/2.1"; everything after
# "release/" is captured as the "version" group.
RELEASE_BRANCH_REGEX = re.compile(r"release/(?P<version>.+)")
21

22

23
def parse_args() -> Any:
    """
    Parse the command-line arguments of the cherry pick script.

    Returns:
        An ``argparse.Namespace`` with the attributes ``onto_branch``,
        ``github_actor``, ``classification``, ``pr_num``, ``fixes`` and
        ``dry_run``.
    """
    from argparse import ArgumentParser

    # This sentence is a description, not a program name. ArgumentParser's
    # first positional parameter is ``prog``, which would put the whole
    # sentence into the "usage:" line in place of the script name.
    parser = ArgumentParser(
        description="cherry pick a landed PR onto a release branch"
    )
    parser.add_argument(
        "--onto-branch", type=str, required=True, help="the target release branch"
    )
    parser.add_argument(
        "--github-actor",
        type=str,
        required=True,
        help="the GitHub user requesting the cherry pick (used in the branch name)",
    )
    parser.add_argument(
        "--classification",
        choices=["regression", "critical", "fixnewfeature", "docs", "release"],
        required=True,
        help="the cherry pick category",
    )
    parser.add_argument(
        "pr_num", type=int, help="the number of the landed PR to cherry pick"
    )
    parser.add_argument(
        "--fixes",
        type=str,
        default="",
        help="the GitHub issue that the cherry pick fixes",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="do not open the cherry pick PR; run the comment helpers in dry-run mode",
    )

    return parser.parse_args()
49

50

51
def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]:
    """
    Return the commit SHA that landed the PR, or None if the PR is still open.

    For simplicity, we only cherry pick PRs that have already landed on main.
    The PR being closed is used as the "landed" signal.
    NOTE(review): presumably PRs are landed by a bot that closes rather than
    merges them; a PR closed without landing also passes this check — confirm
    how get_pr_commit_sha behaves in that case.
    """
    # Check the cheap condition first. For a PR that is still open there is no
    # need to resolve a commit, and any error raised by that lookup would
    # otherwise pre-empt the clear "hasn't been merged yet" message in main().
    if not pr.is_closed():
        return None
    return get_pr_commit_sha(repo, pr)
58

59

60
def get_release_version(onto_branch: str) -> str:
    """
    Return the release version if the target branch is a release branch.

    The version is whatever follows ``release/`` in the branch name, e.g.
    ``"release/2.1"`` -> ``"2.1"``. For any other branch an empty string is
    returned (never None), so callers can simply test the result for
    truthiness.
    """
    match = RELEASE_BRANCH_REGEX.match(onto_branch)
    return match.group("version") if match else ""
66

67

68
def _title_matches_version(title: str, version: str) -> bool:
    """
    Return True if ``version`` appears in ``title`` as a whole version.

    A plain substring test would let version "2.1" match "[v2.10.0]" or
    "[v12.1.0]" trackers. Requiring that the version is not preceded by a
    digit or dot and not followed by a digit avoids that, while still matching
    titles such as "[v2.1.0] Release Tracker" for branch ``release/2.1``.
    """
    pattern = rf"(?<![\d.]){re.escape(version)}(?!\d)"
    return re.search(pattern, title) is not None


def get_tracker_issues(
    org: str, project: str, onto_branch: str
) -> List[Dict[str, Any]]:
    """
    Find the tracker issues for the target release branch.

    Candidate issues carry the "release tracker" label. Following the PyTorch
    convention, a tracker's title looks like "[VERSION] Release Tracker".
    Returns an empty list when ``onto_branch`` is not a release branch or no
    tracker issue matches its version.
    """
    version = get_release_version(onto_branch)
    if not version:
        return []

    tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"])
    if not tracker_issues:
        return []

    # Figure out the tracker issue from the list by looking at the title.
    # A null title is treated as "".
    return [
        issue
        for issue in tracker_issues
        if _title_matches_version(issue.get("title") or "", version)
    ]
85

86

87
def _format_summary_comment(
    cherry_pick_pr: str,
    classification: str,
    fixes: str,
    requires_issue: bool,
    tracker_issues_comments: List[str],
) -> str:
    """
    Build the message posted on the original PR once the cherry pick is done.

    Args:
        cherry_pick_pr: URL of the cherry pick PR ("" when no PR was created).
        classification: the cherry pick category.
        fixes: the issue the cherry pick fixes, or "" if none was given.
        requires_issue: whether linking an issue is recommended for
            ``classification``.
        tracker_issues_comments: URLs of the comments posted on release
            tracker issues.
    """
    msg = f"The cherry pick PR is at {cherry_pick_pr}"
    if fixes:
        msg += f" and it is linked with issue {fixes}."
    elif requires_issue:
        msg += f" and it is recommended to link a {classification} cherry pick PR with an issue."
    else:
        # Close the sentence like the other two branches do, so the
        # tracker-issue sentence below does not run straight into the URL.
        msg += "."

    if tracker_issues_comments:
        msg += " The following tracker issues are updated:\n"
        msg += "".join(f"* {url}\n" for url in tracker_issues_comments)
    return msg


def cherry_pick(
    github_actor: str,
    repo: GitRepo,
    pr: GitHubPR,
    commit_sha: str,
    onto_branch: str,
    classification: str,
    fixes: str,
    dry_run: bool = False,
) -> None:
    """
    Create a local branch to cherry pick the commit and submit it as a pull request.

    Afterwards a comment is posted on every matching release tracker issue,
    and a summary comment is posted on the original PR. With ``dry_run`` the
    PR is not created and ``dry_run`` is forwarded to the comment helpers.
    Once the cherry pick branch exists, the originally checked-out branch is
    restored at the end, even if a later step raises.
    """
    current_branch = repo.current_branch()
    # NOTE(review): this runs before the try block, so if the cherry pick
    # itself fails the original branch is not restored. A checkout in the
    # middle of a conflicted cherry pick would likely fail anyway. Note also
    # that create_cherry_pick_branch pushes the branch even when dry_run is set.
    cherry_pick_branch = create_cherry_pick_branch(
        github_actor, repo, pr, commit_sha, onto_branch
    )

    try:
        org, project = repo.gh_owner_and_name()

        cherry_pick_pr = ""
        if not dry_run:
            cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch)

        tracker_issues_comments: List[str] = []
        for issue in get_tracker_issues(org, project, onto_branch):
            issue_number = int(str(issue.get("number", "0")))
            if not issue_number:
                continue

            res: Any = post_tracker_issue_comment(
                org,
                project,
                issue_number,
                pr.pr_num,
                cherry_pick_pr,
                classification,
                fixes,
                dry_run,
            )
            # post_tracker_issue_comment is annotated to return a list, but the
            # URL is read as if from a single comment object. Only call .get()
            # on a dict, so a list result is skipped instead of raising
            # AttributeError. NOTE(review): e.g. a dry run may hand back an
            # empty list — confirm in github_utils.
            comment_url = res.get("html_url", "") if isinstance(res, dict) else ""
            if comment_url:
                tracker_issues_comments.append(comment_url)

        msg = _format_summary_comment(
            cherry_pick_pr,
            classification,
            fixes,
            classification in REQUIRES_ISSUE,
            tracker_issues_comments,
        )
        post_pr_comment(org, project, pr.pr_num, msg, dry_run)

    finally:
        if current_branch:
            repo.checkout(branch=current_branch)
153

154

155
def create_cherry_pick_branch(
    github_actor: str, repo: GitRepo, pr: GitHubPR, commit_sha: str, onto_branch: str
) -> str:
    """
    Cherry pick ``commit_sha`` onto a fresh local branch, push it and return its name.

    The branch is created from ``onto_branch`` after updating submodules. It
    is named ``cherry-pick-<PR number>-by-<actor>``, where every run of
    characters other than ASCII letters and digits in the actor name becomes
    ``_``. Conflicts are not resolved here: presumably a failing
    ``git cherry-pick`` surfaces as an error from ``repo._run_git``. The push
    is always performed for real (``dry_run=False``).
    """
    repo.checkout(branch=onto_branch)
    repo._run_git("submodule", "update", "--init", "--recursive")

    # Only keep characters that are safe in a branch name.
    safe_actor = re.sub("[^0-9a-zA-Z]+", "_", github_actor)
    branch_name = f"cherry-pick-{pr.pr_num}-by-{safe_actor}"
    repo.create_branch_and_checkout(branch=branch_name)

    # ghstack is not supported (yet). `-x` records the source commit in the
    # message of the cherry-picked commit.
    repo._run_git("cherry-pick", "-x", commit_sha)
    repo.push(branch=branch_name, dry_run=False)
    return branch_name
177

178

179
def _text_or_default(info: Dict[str, Any], key: str, default: str) -> str:
    """Return ``info[key]`` as a string, or ``default`` if it is missing, None or empty."""
    value = info.get(key)
    return str(value) if value else default


def submit_pr(
    repo: GitRepo,
    pr: GitHubPR,
    cherry_pick_branch: str,
    onto_branch: str,
) -> str:
    """
    Submit the cherry pick PR and return the link to the PR.

    The PR is opened from ``cherry_pick_branch`` against ``onto_branch``. It
    reuses the title and body of the original PR, falling back to
    "Cherry pick #<PR> onto <branch> branch" for whichever one is missing or
    empty.

    Raises:
        RuntimeError: if the GitHub request fails with an HTTP error, or the
            response carries no ``html_url``.
    """
    org, project = repo.gh_owner_and_name()

    default_msg = f"Cherry pick #{pr.pr_num} onto {onto_branch} branch"
    # The GitHub API reports an empty PR description as ``"body": null``, so
    # ``info.get("body", default_msg)`` could yield None. Treat any falsy
    # value as missing.
    title = _text_or_default(pr.info, "title", default_msg)
    body = _text_or_default(pr.info, "body", default_msg)

    try:
        response = gh_fetch_url(
            f"https://api.github.com/repos/{org}/{project}/pulls",
            method="POST",
            data={
                "title": title,
                "body": body,
                "head": cherry_pick_branch,
                "base": onto_branch,
            },
            headers={"Accept": "application/vnd.github.v3+json"},
            reader=json.load,
        )
    except HTTPError as error:
        msg = f"Fail to submit the cherry pick PR: {error}"
        raise RuntimeError(msg) from error

    cherry_pick_pr = response.get("html_url", "")
    if not cherry_pick_pr:
        raise RuntimeError(
            f"Fail to find the cherry pick PR: {json.dumps(response)}"
        )
    return str(cherry_pick_pr)
219

220

221
def _format_pr_comment(pr_num: int, msg: str, run_url: Optional[str]) -> str:
    """
    Build the markdown comment for PR ``pr_num``.

    The comment is a heading, then ``msg``. When ``run_url`` is non-empty, a
    collapsed section linking to the workflow run is appended.
    """
    internal_debugging = ""
    # An unset *or empty* run URL yields no link; an empty value would
    # otherwise produce a dead href="".
    if run_url:
        internal_debugging = "\n".join(
            (
                "<details><summary>Details for Dev Infra team</summary>",
                f'Raised by <a href="{run_url}">workflow job</a>\n',
                "</details>",
            )
        )
    return "\n".join((f"### Cherry picking #{pr_num}", f"{msg}", "", internal_debugging))


def post_pr_comment(
    org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
    """
    Post a comment on the PR itself to point to the cherry picking PR on
    success, or to report the error on failure.

    When the ``GH_RUN_URL`` environment variable is set and non-empty, the
    comment also links to the workflow run. Returns whatever
    ``gh_post_pr_comment`` returns.
    """
    comment = _format_pr_comment(pr_num, msg, os.getenv("GH_RUN_URL"))
    return gh_post_pr_comment(org, project, pr_num, comment, dry_run)
247

248

249
def _format_tracker_comment(
    org: str,
    project: str,
    pr_num: int,
    cherry_pick_pr: str,
    classification: str,
    fixes: str,
) -> str:
    """Build the markdown comment that records a cherry pick on a tracker issue."""
    # Only the category word is capitalized. ``fixes`` is passed through
    # verbatim: str.capitalize() upper-cases the first character and
    # lower-cases the rest, which mangles issue URLs ("https://..." ->
    # "Https://..."). An empty ``fixes`` must not leave a dangling " - ".
    criteria = classification.capitalize()
    if fixes:
        criteria = f"{criteria} - {fixes}"
    return "\n".join(
        (
            "Link to landed trunk PR (if applicable):",
            f"* https://github.com/{org}/{project}/pull/{pr_num}",
            "",
            "Link to release branch PR:",
            f"* {cherry_pick_pr}",
            "",
            "Criteria Category:",
            criteria,
        )
    )


def post_tracker_issue_comment(
    org: str,
    project: str,
    issue_num: int,
    pr_num: int,
    cherry_pick_pr: str,
    classification: str,
    fixes: str,
    dry_run: bool = False,
) -> List[Dict[str, Any]]:
    """
    Post a comment on the tracker issue (if any) to record the cherry pick.

    The comment links the original PR ``pr_num`` and the cherry pick PR. It
    states the criteria category: ``classification``, plus ``fixes`` when
    given. Returns whatever ``gh_post_pr_comment`` returns.
    """
    comment = _format_tracker_comment(
        org, project, pr_num, cherry_pick_pr, classification, fixes
    )
    return gh_post_pr_comment(org, project, issue_num, comment, dry_run)
275

276

277
def main() -> None:
    """
    Entry point: cherry pick the landed PR given on the command line.

    A RuntimeError raised along the way is reported as a comment on the PR,
    except in dry-run mode, where it is re-raised instead.
    """
    args = parse_args()
    pr_num = args.pr_num

    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()
    pr = GitHubPR(org, project, pr_num)

    try:
        commit_sha = get_merge_commit_sha(repo, pr)
        if not commit_sha:
            raise RuntimeError(
                f"Refuse to cherry pick #{pr_num} because it hasn't been merged yet"
            )
        cherry_pick(
            github_actor=args.github_actor,
            repo=repo,
            pr=pr,
            commit_sha=commit_sha,
            onto_branch=args.onto_branch,
            classification=args.classification,
            fixes=args.fixes,
            dry_run=args.dry_run,
        )
    except RuntimeError as error:
        if args.dry_run:
            raise error
        post_pr_comment(org, project, pr_num, str(error))
309

310

311
if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
313

# Scraped web-page footer (GitVerse cookie notice, translated from Russian):
# "Cookie usage. We use cookies in accordance with the Privacy Policy and the
# Cookie Policy. By clicking 'Accept', you consent to SberTech JSC processing
# your personal data in order to improve our website and the GitVerse service
# and make them more convenient to use. You can disable cookies yourself in
# your browser settings."