6
from typing import Any, cast, Dict, List, Optional
7
from urllib.error import HTTPError
9
from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels
10
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
11
from trymerge import get_pr_commit_sha, GitHubPR
14
# This is only a suggestion for now, not a strict requirement
20
# Pattern for release branch names: the literal prefix "release/" followed by a
# non-empty remainder, which is captured as the named group "version".
RELEASE_BRANCH_REGEX = re.compile(r"release/(?P<version>.+)")
23
def parse_args() -> Any:
24
from argparse import ArgumentParser
26
parser = ArgumentParser("cherry pick a landed PR onto a release branch")
28
"--onto-branch", type=str, required=True, help="the target release branch"
31
"--github-actor", type=str, required=True, help="all the world's a stage"
35
choices=["regression", "critical", "fixnewfeature", "docs", "release"],
37
help="the cherry pick category",
39
parser.add_argument("pr_num", type=int)
44
help="the GitHub issue that the cherry pick fixes",
46
parser.add_argument("--dry-run", action="store_true")
48
return parser.parse_args()
51
def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]:
    """
    Resolve the commit SHA for the PR and hand it back only when the PR is
    closed; an open PR yields None so the caller can refuse to cherry pick it.

    NOTE(review): the closed state is what is checked here, on the assumption
    that a closed PR has landed on main -- confirm against get_pr_commit_sha.
    """
    # The SHA lookup always runs first, exactly as before; only the returned
    # value depends on the PR state.
    sha = get_pr_commit_sha(repo, pr)
    if pr.is_closed():
        return sha
    return None
60
def get_release_version(onto_branch: str) -> Optional[str]:
    """
    Extract the version portion from a release branch name.

    The branch name is matched from its start against RELEASE_BRANCH_REGEX.
    On a match the captured "version" group is returned; otherwise an empty
    string is returned (not None, despite the Optional annotation).
    """
    matched = RELEASE_BRANCH_REGEX.match(onto_branch)
    if matched is None:
        return ""
    return matched.group("version")
68
def get_tracker_issues(
    org: str, project: str, onto_branch: str
) -> List[Dict[str, Any]]:
    """
    Find the tracker issue from the repo. The tracker issue needs to have the title
    like [VERSION] Release Tracker following the convention on PyTorch

    Args:
        org: the GitHub organization (repo owner) to query
        project: the repository name to query
        onto_branch: the target branch; its release version selects the issues

    Returns:
        The issues labeled "release tracker" whose title contains the release
        version of onto_branch, or an empty list when onto_branch is not a
        release branch or no tracker issue exists.
    """
    version = get_release_version(onto_branch)
    if not version:
        # An empty version is a substring of every title, so without this guard
        # a non-release branch would match all tracker issues.
        return []

    tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"])
    if not tracker_issues:
        return []

    # Figure out the tracker issue from the list by looking at the title; a
    # missing or null title is treated as an empty one.
    return [
        issue for issue in tracker_issues if version in (issue.get("title") or "")
    ]
95
dry_run: bool = False,
98
Create a local branch to cherry pick the commit and submit it as a pull request
100
current_branch = repo.current_branch()
101
cherry_pick_branch = create_cherry_pick_branch(
102
github_actor, repo, pr, commit_sha, onto_branch
106
org, project = repo.gh_owner_and_name()
110
cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch)
112
tracker_issues_comments = []
113
tracker_issues = get_tracker_issues(org, project, onto_branch)
114
for issue in tracker_issues:
115
issue_number = int(str(issue.get("number", "0")))
121
post_tracker_issue_comment(
133
comment_url = res.get("html_url", "")
135
tracker_issues_comments.append(comment_url)
137
msg = f"The cherry pick PR is at {cherry_pick_pr}"
139
msg += f" and it is linked with issue {fixes}."
140
elif classification in REQUIRES_ISSUE:
141
msg += f" and it is recommended to link a {classification} cherry pick PR with an issue."
143
if tracker_issues_comments:
144
msg += " The following tracker issues are updated:\n"
145
for tracker_issues_comment in tracker_issues_comments:
146
msg += f"* {tracker_issues_comment}\n"
148
post_pr_comment(org, project, pr.pr_num, msg, dry_run)
152
repo.checkout(branch=current_branch)
155
def create_cherry_pick_branch(
    github_actor: str, repo: GitRepo, pr: GitHubPR, commit_sha: str, onto_branch: str
) -> str:
    """
    Check out the target branch, branch off it, cherry pick the given commit
    onto the new branch, and push that branch.

    Returns the name of the newly created local cherry picking branch.
    """
    # Start from the target branch with its submodules brought up to date
    repo.checkout(branch=onto_branch)
    repo._run_git("submodule", "update", "--init", "--recursive")

    # The actor becomes part of the branch name, so every run of characters
    # other than ASCII letters and digits is collapsed into one underscore
    safe_actor = re.sub("[^0-9a-zA-Z]+", "_", github_actor)
    branch_name = f"cherry-pick-{pr.pr_num}-by-{safe_actor}"
    repo.create_branch_and_checkout(branch=branch_name)

    # We might want to support ghstack later
    # No conflict resolution is attempted here; -x records the source commit
    # in the cherry-picked commit message
    repo._run_git("cherry-pick", "-x", commit_sha)
    repo.push(branch=branch_name, dry_run=False)

    return branch_name
182
cherry_pick_branch: str,
186
Submit the cherry pick PR and return the link to the PR
188
org, project = repo.gh_owner_and_name()
190
default_msg = f"Cherry pick #{pr.pr_num} onto {onto_branch} branch"
191
title = pr.info.get("title", default_msg)
192
body = pr.info.get("body", default_msg)
195
response = gh_fetch_url(
196
f"https://api.github.com/repos/{org}/{project}/pulls",
201
"head": cherry_pick_branch,
204
headers={"Accept": "application/vnd.github.v3+json"},
208
cherry_pick_pr = response.get("html_url", "")
209
if not cherry_pick_pr:
211
f"Fail to find the cherry pick PR: {json.dumps(response)}"
214
return str(cherry_pick_pr)
216
except HTTPError as error:
217
msg = f"Fail to submit the cherry pick PR: {error}"
218
raise RuntimeError(msg) from error
222
org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
223
) -> List[Dict[str, Any]]:
225
Post a comment on the PR itself to point to the cherry picking PR when success
226
or print the error when failure
228
internal_debugging = ""
230
run_url = os.getenv("GH_RUN_URL")
231
# Post a comment to tell folks that the PR is being cherry picked
232
if run_url is not None:
233
internal_debugging = "\n".join(
236
"<details><summary>Details for Dev Infra team</summary>",
237
f'Raised by <a href="{run_url}">workflow job</a>\n',
244
(f"### Cherry picking #{pr_num}", f"{msg}", "", f"{internal_debugging}")
246
return gh_post_pr_comment(org, project, pr_num, comment, dry_run)
249
def post_tracker_issue_comment(
257
dry_run: bool = False,
258
) -> List[Dict[str, Any]]:
260
Post a comment on the tracker issue (if any) to record the cherry pick
264
"Link to landed trunk PR (if applicable):",
265
f"* https://github.com/{org}/{project}/pull/{pr_num}",
267
"Link to release branch PR:",
268
f"* {cherry_pick_pr}",
270
"Criteria Category:",
271
" - ".join((classification.capitalize(), fixes.capitalize())),
274
return gh_post_pr_comment(org, project, issue_num, comment, dry_run)
281
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
282
org, project = repo.gh_owner_and_name()
284
pr = GitHubPR(org, project, pr_num)
287
commit_sha = get_merge_commit_sha(repo, pr)
290
f"Refuse to cherry pick #{pr_num} because it hasn't been merged yet"
304
except RuntimeError as error:
306
post_pr_comment(org, project, pr_num, str(error))
311
if __name__ == "__main__":