2
# Tests implemented in this file are relying on GitHub GraphQL APIs
3
# In order to avoid test flakiness, results of the queries
4
# are cached in gql_mocks.json
5
# PyTorch Lint workflow does not have GITHUB_TOKEN defined to avoid
6
# flakiness, so if you are making changes to merge_rules or
7
# GraphQL queries in trymerge.py, please make sure to delete `gql_mocks.json`
8
# and re-run the test locally with your own GitHub Personal Access Token (PAT)
14
from hashlib import sha256
15
from typing import Any, Dict, List, Optional
16
from unittest import main, mock, skip, TestCase
17
from urllib.error import HTTPError
19
from github_utils import gh_graphql
20
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
24
find_matching_merge_rule,
26
get_drci_classifications,
31
main as trymerge_main,
32
MandatoryChecksMissingError,
36
remove_job_name_suffix,
41
if "GIT_REMOTE_URL" not in os.environ:
42
os.environ["GIT_REMOTE_URL"] = "https://github.com/pytorch/pytorch"
44
GQL_MOCKS = "gql_mocks.json.gz"
45
ROCKSET_MOCKS = "rockset_mocks.json.gz"
46
DRCI_MOCKS = "drci_mocks.json.gz"
50
fallback_function: Any,
55
gql_db_fname = os.path.join(os.path.dirname(__file__), file_name)
57
def get_mocked_queries() -> Any:
58
if not os.path.exists(gql_db_fname):
60
with gzip.open(gql_db_fname, encoding="utf-8", mode="rt") as f:
63
def save_mocked_queries(obj: Any) -> None:
    """Serialize *obj* as indented JSON into the gzip-compressed mock DB file."""
    with gzip.open(gql_db_fname, mode="wt", encoding="utf-8") as fp:
        json.dump(obj, fp, indent=2)
68
key = key_function(*args)
69
mocked_queries = get_mocked_queries()
71
if key in mocked_queries:
72
return mocked_queries[key]
75
rc = fallback_function(*args)
76
except HTTPError as err:
77
if err.code == 401 or err.code == 403:
78
err_msg = f"If you are seeing this message during workflow run, please make sure to update {file_name}"
79
err_msg += f" locally, by deleting it and running {os.path.basename(__file__)} with"
80
err_msg += " GitHub Personal Access Token passed via GITHUB_TOKEN,"
81
err_msg += " the rockset api key passed via ROCKSET_API_KEY,"
82
err_msg += " and drci api key passed via DRCI_BOT_KEY environment variables"
84
os.getenv("GITHUB_TOKEN") is None
85
or os.getenv("ROCKSET_API_KEY") is None
86
or os.getenv("DRCI_BOT_KEY") is None
89
"Failed to update cached queries as GITHUB_TOKEN or ROCKSET_API_KEY or DRCI_BOT_KEY "
93
raise RuntimeError(err_msg) from err
94
mocked_queries[key] = rc
96
save_mocked_queries(mocked_queries)
101
def mocked_gh_graphql(query: str, **kwargs: Any) -> Any:
102
def key_function(query: str, kwargs: Any) -> str:
103
return f"query_sha={sha256(query.encode('utf-8')).hexdigest()} " + " ".join(
104
[f"{k}={kwargs[k]}" for k in sorted(kwargs.keys())]
107
def gh_graphql_wrapper(query: str, kwargs: Any) -> Any:
108
return gh_graphql(query, **kwargs)
110
return mock_query(gh_graphql_wrapper, GQL_MOCKS, key_function, query, kwargs)
113
def mocked_rockset_results(head_sha: str, merge_base: str, num_retries: int = 3) -> Any:
117
lambda x, y: f"{x} {y}",
123
def mocked_drci_classifications(pr_num: int, project: str, num_retries: int = 3) -> Any:
125
get_drci_classifications,
127
lambda x, y: f"{x} {y}",
133
def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
135
def __init__(self) -> None:
141
self.reason = "this is for testing"
142
self.ignore_current = False
143
self.check_mergeability = False
148
def mock_remove_label(
149
org: str, repo: str, pr_num: str, label: str, dry_run: bool
158
dry_run: bool = False,
159
comment_id: Optional[int] = None,
160
reason: Optional[str] = None,
168
dry_run: bool = False,
169
skip_mandatory_checks: bool = False,
170
comment_id: Optional[int] = None,
171
timeout_minutes: int = 400,
172
stale_pr_days: int = 3,
173
ignore_current: bool = False,
178
def mock_gh_get_info() -> Any:
181
"isCrossRepository": False,
182
"headRefName": "foo",
183
"baseRefName": "bar",
184
"baseRepository": {"defaultBranchRef": {"name": "bar"}},
185
"files": {"nodes": [], "pageInfo": {"hasNextPage": False}},
190
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]:
193
name="mock with nonexistent check",
196
mandatory_checks_name=["Lint", "Facebook CLA Check", "nonexistent"],
197
ignore_flaky_failures=True,
202
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
207
approved_by=["pytorch/metamates", "ngimel"],
208
mandatory_checks_name=[
210
"pull / linux-xenial-cuda11.3-py3.7-gcc7 / build",
212
ignore_flaky_failures=True,
216
patterns=[".github/ci_commit_pins/xla.txt"],
217
approved_by=["pytorchbot"],
218
mandatory_checks_name=[
221
"pull / linux-focal-py3_8-clang9-xla / build",
222
"pull / linux-focal-py3_8-clang9-xla / test (xla, 1, 1, linux.12xlarge)",
224
ignore_flaky_failures=True,
229
def mocked_read_merge_rules_approvers(
230
repo: Any, org: str, project: str
234
name="Core Reviewers",
236
approved_by=["1", "2", "3", "4", "5", "6"],
237
mandatory_checks_name=[
243
name="Core Maintainers",
245
approved_by=["1", "2", "malfet"],
246
mandatory_checks_name=[
254
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]:
255
raise RuntimeError("testing")
258
def xla_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
261
name=" OSS CI / pytorchbot / XLA",
262
patterns=[".github/ci_commit_pins/xla.txt"],
263
approved_by=["pytorchbot"],
264
mandatory_checks_name=[
267
"pull / linux-bionic-py3_8-clang8-xla / build",
268
"pull / linux-bionic-py3_8-clang8-xla / test (xla, 1, 1, linux.4xlarge)",
269
"inductor / cuda11.8-py3.10-gcc7-sm86 / test (inductor_torchbench_dynamic, 1, 1, linux.g5.4xlarge.nvidia.gpu)",
271
ignore_flaky_failures=False,
276
def empty_rockset_results(head_sha: str, merge_base: str) -> List[Dict[str, Any]]:
280
class DummyGitRepo(GitRepo):
    """Minimal GitRepo stub that answers every query with canned data."""

    def __init__(self) -> None:
        super().__init__(get_git_repo_dir(), get_git_remote_name())

    def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
        # Pretend exactly one fake commit resolves any PR number.
        return ["FakeCommitSha"]

    def commit_message(self, ref: str) -> str:
        # Canned message (original spelling kept intentionally — it is runtime data
        # that other tests may compare against).
        return "super awsome commit message"
291
@mock.patch("trymerge.get_rockset_results", side_effect=empty_rockset_results)
292
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
294
"trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
296
class TestTryMerge(TestCase):
297
def test_merge_rules_valid(self, *args: Any) -> None:
    """Sanity-check that merge_rules.yaml parses into more than one rule."""
    git_repo = DummyGitRepo()
    rules = read_merge_rules(git_repo, "pytorch", "pytorch")
    self.assertGreater(len(rules), 1)
303
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
def test_match_rules(self, *args: Any) -> None:
    """A PR that satisfies the mocked merge rules yields a matching rule."""
    pull = GitHubPR("pytorch", "pytorch", 109999)
    git_repo = DummyGitRepo()
    self.assertIsNotNone(find_matching_merge_rule(pull, git_repo))
310
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_raise)
311
def test_read_merge_rules_fails(self, *args: Any) -> None:
312
"Tests that PR fails to read the merge rules"
313
pr = GitHubPR("pytorch", "pytorch", 77700)
314
repo = DummyGitRepo()
315
self.assertRaisesRegex(
316
RuntimeError, "testing", lambda: find_matching_merge_rule(pr, repo)
320
"trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_approvers
322
def test_match_rules_approvers(self, *args: Any) -> None:
323
"Tests that PR has the necessary approvers"
324
repo = DummyGitRepo()
326
pr = GitHubPR("pytorch", "pytorch", 115329)
327
# Test that all potential approvers across all rules are listed if the
328
# PR doesn't have one of them
329
for mock_rule in ["Core Reviewers", "Core Maintainers"]:
330
self.assertRaisesRegex(
333
lambda: find_matching_merge_rule(pr, repo),
336
pr = GitHubPR("pytorch", "pytorch", 115495)
337
# Test that PR with the correct approvers doesn't raise any exception
338
self.assertTrue(find_matching_merge_rule(pr, repo) is not None)
340
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
def test_lint_fails(self, *args: Any) -> None:
    """A PR failing the mandatory Lint check must not match any merge rule."""
    pull = GitHubPR("pytorch", "pytorch", 90791)
    git_repo = DummyGitRepo()
    with self.assertRaises(RuntimeError):
        find_matching_merge_rule(pull, git_repo)
347
def test_get_last_comment(self, *args: Any) -> None:
    """The most recent comment on a PR can be fetched and inspected."""
    pull = GitHubPR("pytorch", "pytorch", 71759)
    last = pull.get_last_comment()
    self.assertEqual(last.author_login, "github-actions")
    self.assertIsNone(last.editor_login)
    self.assertIn("You've committed this PR", last.body_text)
355
def test_get_author_null(self, *args: Any) -> None:
356
"""Tests that PR author can be computed
357
If reply contains NULL
359
pr = GitHubPR("pytorch", "pytorch", 71759)
360
author = pr.get_author()
361
self.assertTrue(author is not None)
362
self.assertTrue("@" in author)
363
self.assertTrue(pr.get_diff_revision() is None)
365
# PR with multiple contributors, but creator id is not among authors
366
pr = GitHubPR("pytorch", "pytorch", 75095)
367
self.assertEqual(pr.get_pr_creator_login(), "mruberry")
368
author = pr.get_author()
369
self.assertTrue(author is not None)
371
def test_large_diff(self, *args: Any) -> None:
    """File-list pagination returns every file of a 100+ file PR."""
    pull = GitHubPR("pytorch", "pytorch", 73099)
    self.assertGreater(pull.get_changed_files_count(), 100)
    changed = pull.get_changed_files()
    self.assertEqual(len(changed), pull.get_changed_files_count())
378
def test_internal_changes(self, *args: Any) -> None:
    """A PR carrying internal changes is detected as such."""
    pull = GitHubPR("pytorch", "pytorch", 110140)
    self.assertTrue(pull.has_internal_changes())
383
def test_comments_pagination(self, *args: Any) -> None:
    """Comment pagination returns all 50+ comments of a busy PR."""
    pull = GitHubPR("pytorch", "pytorch", 31093)
    self.assertGreater(len(pull.get_comments()), 50)
388
def test_gql_complexity(self, *args: Any) -> None:
    """Fetch comments and conclusions for a PR with 60 commits."""
    # An earlier version of the GraphQL query used to trigger HTTP/502 here;
    # see https://gist.github.com/malfet/9b93bc7eeddeaf1d84546efc4f0c577f
    pull = GitHubPR("pytorch", "pytorch", 68111)
    self.assertGreater(len(pull.get_comments()), 20)
    # NS(09/27/2023): GitHub seems to recycle older checkruns
    # https://github.com/pytorch/pytorch/pull/68111/checks shows 0 runs
    # self.assertGreater(len(pr.get_checkrun_conclusions()), 3)
    self.assertGreater(pull.get_commit_count(), 60)
399
@skip("GitHub doesn't keep this data anymore")
def test_gql_retrieve_checksuites(self, *args: Any) -> None:
    """All 182 checkrun conclusions of the PR should be retrievable."""
    pull = GitHubPR("pytorch", "pytorch", 94787)
    self.assertEqual(len(pull.get_checkrun_conclusions()), 182)
405
def test_team_members(self, *args: Any) -> None:
    """Team membership queries work; unknown teams warn and return empty."""
    members = gh_get_team_members("pytorch", "pytorch-dev-infra")
    self.assertGreater(len(members), 2)
    with self.assertWarns(Warning):
        missing = gh_get_team_members("pytorch", "qwertyuiop")
    self.assertEqual(len(missing), 0)
413
def test_get_author_many_commits(self, *args: Any) -> None:
    """Authors are collected across every commit of a 100+ commit PR."""
    pull = GitHubPR("pytorch", "pytorch", 76118)
    all_authors = pull.get_authors()
    self.assertGreater(pull.get_commit_count(), 100)
    self.assertGreater(len(all_authors), 50)
    self.assertIn("@", pull.get_author())
421
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_NE)
422
def test_pending_status_check(self, *args: Any) -> None:
423
"""Tests that PR with nonexistent/pending status checks fails with the right reason."""
424
pr = GitHubPR("pytorch", "pytorch", 76118)
425
repo = DummyGitRepo()
426
self.assertRaisesRegex(
427
MandatoryChecksMissingError,
428
".*are pending/not yet run.*",
429
lambda: find_matching_merge_rule(pr, repo),
432
def test_get_author_many_reviews(self, *args: Any) -> None:
    """Review pagination returns the full set of 100+ reviews."""
    pull = GitHubPR("pytorch", "pytorch", 76123)
    self.assertGreater(len(pull.get_approved_by()), 0)
    assert pull._reviews is not None  # to pacify mypy
    self.assertGreater(len(pull._reviews), 100)
440
def test_get_co_authors(self, *args: Any) -> None:
    """Co-authored-by trailers are recognized in authors and commit message.

    NOTE: this was previously named ``get_co_authors`` (missing the
    ``test_`` prefix), so unittest never collected or executed it.
    """
    pull = GitHubPR("pytorch", "pytorch", 118347)
    self.assertIn("kit1980", pull.get_authors())
    self.assertIn("Co-authored-by:", pull.gen_commit_message())
447
def test_get_checkruns_many_runs(self, *args: Any) -> None:
448
"""Tests that all checkruns can be fetched"""
449
pr = GitHubPR("pytorch", "pytorch", 105260)
450
conclusions = pr.get_checkrun_conclusions()
451
self.assertEqual(len(conclusions), 221)
453
"pull / linux-docs / build-docs-cpp-false" in conclusions.keys()
456
def test_cancelled_gets_ignored(self, *args: Any) -> None:
457
"""Tests that cancelled workflow does not override existing successfull status"""
458
pr = GitHubPR("pytorch", "pytorch", 110367)
459
conclusions = pr.get_checkrun_conclusions()
460
lint_checks = [name for name in conclusions.keys() if "Lint" in name]
461
self.assertTrue(len(lint_checks) > 0)
463
all(conclusions[name].status == "SUCCESS" for name in lint_checks)
466
def test_get_review_comment_by_id(self, *args: Any) -> None:
    """A review-comment id resolves even though it is not a simple comment."""
    pull = GitHubPR("pytorch", "pytorch", 107070)
    # 1582767635 is a *review* comment, not an issue comment.
    self.assertIsNotNone(pull.get_comment_by_id(1582767635))
473
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
474
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(True, False))
475
@mock.patch("trymerge.try_revert", side_effect=mock_revert)
476
def test_main_revert(self, mock_revert: Any, *args: Any) -> None:
478
mock_revert.assert_called_once()
480
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
481
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, True))
482
@mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
483
@mock.patch("trymerge.merge", side_effect=mock_merge)
485
self, mock_merge: Any, mock_parse_args: Any, *args: Any
488
mock_merge.assert_called_once_with(
492
skip_mandatory_checks=True,
494
ignore_current=False,
497
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
498
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, False))
499
@mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
500
@mock.patch("trymerge.merge", side_effect=mock_merge)
501
def test_main_merge(self, mock_merge: Any, *args: Any) -> None:
503
mock_merge.assert_called_once_with(
507
skip_mandatory_checks=False,
509
ignore_current=False,
512
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
def test_revert_rules(self, *args: Any) -> None:
    """Reverts issued by collaborators pass validation."""
    pull = GitHubPR("pytorch", "pytorch", 79694)
    git_repo = DummyGitRepo()
    self.assertIsNotNone(validate_revert(git_repo, pull, comment_id=1189459845))
519
def test_get_changed_files(self, *args: Any) -> None:
521
Tests that the list changed files in a PR doesn't include duplicates
523
pr = GitHubPR("pytorch", "pytorch", 95233)
525
changed_files = pr.get_changed_files()
526
except RuntimeError as error:
527
self.fail(f"get_changed_files throws an exception: {error}")
529
self.assertEqual(len(changed_files), pr.get_changed_files_count())
531
def test_revert_codev_abandoned_diff_succeeds(self, *args: Any) -> None:
532
pr = GitHubPR("pytorch", "pytorch", 100652)
534
class GitRepoCoDev(DummyGitRepo):
535
def commit_message(self, ref: str) -> str:
538
repo = GitRepoCoDev()
539
validate_revert(repo, pr, comment_id=1588195237)
541
def test_pr_changed_submodule_detection(self, *args: Any) -> None:
542
# Updates submodule during dev-cycle but reverts it later
543
pr = GitHubPR("pytorch", "pytorch", 95045)
544
self.assertEqual(pr.get_changed_submodules(), [])
545
self.assertFalse(pr.has_invalid_submodule_updates())
548
pr = GitHubPR("pytorch", "pytorch", 94939)
549
self.assertEqual(pr.get_changed_submodules(), ["third_party/ideep"])
550
self.assertTrue(pr.has_invalid_submodule_updates())
552
# Automated submodule update
553
pr = GitHubPR("pytorch", "pytorch", 91051)
554
self.assertEqual(pr.get_changed_submodules(), ["third_party/kineto"])
555
self.assertFalse(pr.has_invalid_submodule_updates())
557
def test_remove_job_name_suffix(self, *args: Any) -> None:
560
"name": "linux-bionic-cuda12.1-py3.10-gcc9-sm86 / test (default, 1, 5, linux.g5.4xlarge.nvidia.gpu)",
561
"expected": "linux-bionic-cuda12.1-py3.10-gcc9-sm86 / test (default)",
564
"name": "android-emulator-build-test / build-and-test (default, 1, 1, ubuntu-20.04-16x)",
565
"expected": "android-emulator-build-test / build-and-test (default)",
568
"name": "linux-focal-rocm5.4.2-py3.8 / build",
569
"expected": "linux-focal-rocm5.4.2-py3.8 / build",
572
"name": "libtorch-cpu-shared-with-deps-release-build",
573
"expected": "libtorch-cpu-shared-with-deps-release-build",
576
"name": "manywheel-py3_8-cuda11_8-test / test",
577
"expected": "manywheel-py3_8-cuda11_8-test / test",
580
"name": "lintrunner / linux-job",
581
"expected": "lintrunner / linux-job",
584
"name": "Test `run_test.py` is usable without boto3/rockset",
585
"expected": "Test `run_test.py` is usable without boto3/rockset",
589
for case in test_cases:
590
self.assertEqual(case["expected"], remove_job_name_suffix(case["name"]))
592
def test_get_merge_base(self, *args: Any) -> None:
593
pr = GitHubPR("pytorch", "pytorch", 104121)
595
mock_merge_base = "mocked-sha"
597
"trymerge.gh_fetch_merge_base", return_value=mock_merge_base
598
) as mocked_gh_fetch_merge_base:
599
self.assertEqual(mock_merge_base, pr.get_merge_base())
601
# Make sure that consecutive calls will use the same merge base instead of
602
# making another query
603
self.assertEqual(mock_merge_base, pr.get_merge_base())
604
mocked_gh_fetch_merge_base.assert_called_once()
607
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
608
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
609
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
611
"trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
613
class TestBypassFailures(TestCase):
614
def test_get_classifications(self, *args: Any) -> None:
615
pr = GitHubPR("pytorch", "pytorch", 109584)
616
checks = pr.get_checkrun_conclusions()
617
checks = get_classifications(
625
"pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)"
631
"trunk / win-vs2019-cpu-py3 / test (default, 2, 3, windows.4xlarge.nonephemeral)"
637
"pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge)"
643
"pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)"
648
# Set the threshold larger or equal to the number of ok failures
649
pending, failed, ignorable = categorize_checks(
650
checks, list(checks.keys()), ok_failed_checks_threshold=6
652
self.assertTrue(len(pending) == 0)
653
self.assertTrue(len(failed) == 0)
654
self.assertTrue(len(ignorable["FLAKY"]) == 4)
655
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
657
# Not set any threshold, defaults to -1 to ignore all flaky and broken trunk failures
658
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
659
self.assertTrue(len(pending) == 0)
660
self.assertTrue(len(failed) == 0)
661
self.assertTrue(len(ignorable["FLAKY"]) == 4)
662
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
664
# Set the threshold lower than the number of ok failures
665
pending, failed, ignorable = categorize_checks(
666
checks, list(checks.keys()), ok_failed_checks_threshold=1
668
self.assertTrue(len(pending) == 0)
669
self.assertTrue(len(failed) == 6)
670
self.assertTrue(len(ignorable["FLAKY"]) == 4)
671
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
673
# Set the threshold to 0 like when ignore_flaky_failures is on
674
pending, failed, ignorable = categorize_checks(
675
checks, list(checks.keys()), ok_failed_checks_threshold=1
677
self.assertTrue(len(pending) == 0)
678
self.assertTrue(len(failed) == 6)
679
self.assertTrue(len(ignorable["FLAKY"]) == 4)
680
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
682
def test_get_classifications_flaky_fullname(self, *args: Any) -> None:
683
pr = GitHubPR("pytorch", "pytorch", 110362)
684
checks = pr.get_checkrun_conclusions()
685
checks = get_classifications(
691
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
692
self.assertTrue(len(pending) == 0)
693
self.assertTrue(len(failed) == 0)
694
self.assertTrue(len(ignorable["FLAKY"]) == 1)
696
def test_get_classifications_invalid_cancel(self, *args: Any) -> None:
697
pr = GitHubPR("pytorch", "pytorch", 110367)
698
checks = pr.get_checkrun_conclusions()
699
checks = get_classifications(
705
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
706
self.assertTrue(len(pending) == 0)
707
self.assertTrue(len(failed) == 0)
708
self.assertTrue(len(ignorable["FLAKY"]) == 0)
709
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 0)
710
self.assertTrue(len(ignorable["UNSTABLE"]) == 3)
712
def test_get_classifications_similar_failures(self, *args: Any) -> None:
713
pr = GitHubPR("pytorch", "pytorch", 109750)
714
checks = pr.get_checkrun_conclusions()
715
checks = get_classifications(
721
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
722
self.assertTrue(len(pending) == 0)
723
self.assertTrue(len(failed) == 0)
724
self.assertTrue(len(ignorable["FLAKY"]) == 1)
726
def test_get_classifications_unstable(self, *args: Any) -> None:
727
pr = GitHubPR("pytorch", "pytorch", 104312)
728
checks = pr.get_checkrun_conclusions()
729
checks = get_classifications(
735
workflow_name = "linux-bionic-cuda12.1-py3.10-gcc9-bazel-test"
736
job_name = "build-and-test (default, 1, 1, linux.4xlarge.nvidia.gpu, unstable)"
738
checks[f"pull / {workflow_name} / {job_name}"].classification == "UNSTABLE"
740
pending, failed, ignorable = categorize_checks(
741
checks, list(checks.keys()), ok_failed_checks_threshold=1
743
self.assertTrue(len(pending) == 0)
744
self.assertTrue(len(failed) == 0)
745
self.assertTrue(len(ignorable["UNSTABLE"]) == 1)
747
# Add another test case where there is no unstable keyword in the job name, but
748
# the job has already been marked as unstable
749
pr = GitHubPR("pytorch", "executorch", 3318)
750
checks = pr.get_checkrun_conclusions()
751
checks = get_classifications(
758
workflow_name = "test-llama-app"
759
job_name = "mobile-job (android)"
761
checks[f"Android / {workflow_name} / {job_name}"].classification
764
pending, failed, ignorable = categorize_checks(
765
checks, list(checks.keys()), ok_failed_checks_threshold=1
767
self.assertTrue(len(pending) == 0)
768
self.assertTrue(len(failed) == 0)
769
self.assertTrue(len(ignorable["UNSTABLE"]) == 1)
771
def test_get_classifications_broken_trunk(self, *args: Any) -> None:
772
# The mock merge base is the actual value returned by gh_fetch_merge_base
775
# This PR had one broken trunk failure but it was run on a different shard
776
# than the one on the base commit. This should still count as broken trunk
778
"related_failure_count": 0,
779
"flaky_or_broken_trunk": 1,
782
# This PR had one broken trunk failure and it used ghstack
784
"related_failure_count": 0,
785
"flaky_or_broken_trunk": 1,
788
# The failure on the merge base was retried successfully and
789
# its conclusion changed from failure to success. We want to
790
# keep the failure record from the merge base so that it can
791
# be used to detect broken trunk
793
"related_failure_count": 0,
794
"flaky_or_broken_trunk": 1,
797
# This PR used Dr.CI broken trunk classification
799
"related_failure_count": 1,
800
"flaky_or_broken_trunk": 1,
804
for case in test_cases:
805
pr_num = case["pr_num"]
806
related_failure_count = case["related_failure_count"]
807
flaky_or_broken_trunk = case["flaky_or_broken_trunk"]
809
pr = GitHubPR("pytorch", "pytorch", pr_num)
810
checks = pr.get_checkrun_conclusions()
811
checks = get_classifications(
818
pending, failed, _ = categorize_checks(checks, list(checks.keys()))
819
self.assertTrue(len(pending) == 0)
820
self.assertTrue(len(failed) == related_failure_count)
822
# When the ok_failed_checks_threshold is set to 0, the broken trunk failure
824
pending, failed, _ = categorize_checks(
825
checks, list(checks.keys()), ok_failed_checks_threshold=0
827
self.assertTrue(len(pending) == 0)
829
len(failed) == flaky_or_broken_trunk + related_failure_count
832
def test_ignore_current(self, *args: Any) -> None:
833
# Test various interactions of the failure classifier to ensure that ignore
834
# current checks takes place after other classifications: flaky, unstable,
835
# or broken trunk. Only actual new failures should be kept in the list of
836
# ignore current checks to use to record force merge with actual failures
837
flaky = "pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)"
839
"pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)"
842
pr = GitHubPR("pytorch", "pytorch", 109584)
843
checks = pr.get_checkrun_conclusions()
845
# Known flaky failure takes precedence over ignore current (need to set the
846
# merge base here to get the results from Rockset, and that categorize the
847
# broken trunk failure too
848
checks = get_classifications(
852
[broken_trunk, flaky],
854
self.assertTrue(checks[flaky].classification == "FLAKY")
855
self.assertTrue(checks[broken_trunk].classification == "BROKEN_TRUNK")
856
_, failed, ignorable = categorize_checks(checks, list(checks.keys()))
857
self.assertTrue(len(failed) == 0)
858
self.assertTrue(len(ignorable["IGNORE_CURRENT_CHECK"]) == 0)
859
self.assertTrue(len(ignorable["FLAKY"]) == 4)
860
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
862
def test_get_classifications_wrong_workflow_name(self, *args: Any) -> None:
863
pr = GitHubPR("pytorch", "pytorch", 123104)
864
checks = pr.get_checkrun_conclusions()
866
check_name = "linux-binary-conda / conda-py3_8-cuda11_8-build / build"
867
check_name_workflow_path = ".github/workflows/generated-linux-binary-conda-nightly.yml / conda-py3_8-cuda11_8-build / build"
869
# Mock a check where the workflow name uses the full path
870
checks[check_name_workflow_path] = JobCheckState(
871
check_name_workflow_path,
872
checks[check_name].url,
873
checks[check_name].status,
874
checks[check_name].classification,
875
checks[check_name].job_id,
876
checks[check_name].title,
877
checks[check_name].summary,
879
del checks[check_name]
881
checks = get_classifications(
887
pending, failed, ignorable = categorize_checks(
892
self.assertTrue(len(pending) == 0)
893
self.assertTrue(len(failed) == 0)
894
self.assertTrue(len(ignorable["FLAKY"]) == 1)
895
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 0)
897
def test_ignore_failures_older_run_same_workflow(self, *args: Any) -> None:
898
pr = GitHubPR("pytorch", "pytorch", 129013)
899
checks = pr.get_checkrun_conclusions()
900
checks = get_classifications(
906
pending, failed, ignorable = categorize_checks(
910
self.assertTrue(len(pending) == 0)
911
self.assertTrue(len(failed) == 0)
912
self.assertTrue(len(ignorable["FLAKY"]) == 2)
913
self.assertTrue(len(ignorable["UNSTABLE"]) == 13)
915
@mock.patch("trymerge.read_merge_rules", side_effect=xla_merge_rules)
916
def test_dont_ignore_flaky_failures(self, *args: Any) -> None:
918
Regression test for https://github.com/pytorch/test-infra/issues/4126
920
pr = GitHubPR("pytorch", "pytorch", 105312)
921
repo = DummyGitRepo()
922
# Check that failure is classified as flaky but still raises exception
923
with warnings.catch_warnings(record=True) as w, self.assertRaises(RuntimeError):
924
rule = find_matching_merge_rule(pr, repo)
925
self.assertEqual(len(w), 1)
927
"1 checks failed but were likely due flakiness or broken trunk",
932
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
933
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
934
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
935
@mock.patch("trymerge.get_drci_classifications", return_value={})
936
class TestBypassFailuresOnSandCastle(TestCase):
937
def test_get_classifications(self, *args: Any) -> None:
938
pr = GitHubPR("pytorch", "pytorch", 111467)
939
checks = pr.get_checkrun_conclusions()
940
checks = get_classifications(
946
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
947
self.assertTrue(len(pending) == 0)
948
self.assertTrue(len(failed) == 0)
949
self.assertTrue(len(ignorable["FLAKY"]) == 1)
950
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 1)
952
def test_get_classifications_drci_checkrun_not_found(self, *args: Any) -> None:
953
pr = GitHubPR("pytorch", "pytorch", 111467)
956
checks = pr.get_checkrun_conclusions()
957
checks[DRCI_CHECKRUN_NAME] = JobCheckState(
966
checks = get_classifications(
972
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
973
self.assertTrue(len(pending) == 0)
974
self.assertTrue(len(failed) == 2)
977
checks = pr.get_checkrun_conclusions()
978
checks[DRCI_CHECKRUN_NAME] = JobCheckState(
987
checks = get_classifications(
993
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
994
self.assertTrue(len(pending) == 0)
995
self.assertTrue(len(failed) == 2)
998
checks = pr.get_checkrun_conclusions()
999
del checks[DRCI_CHECKRUN_NAME]
1000
checks = get_classifications(
1006
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
1007
self.assertTrue(len(pending) == 0)
1008
self.assertTrue(len(failed) == 2)
1011
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
1012
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
1013
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
1015
"trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
1017
class TestGitHubPRGhstackDependencies(TestCase):
1018
def test_pr_dependencies(self, *args: Any) -> None:
1019
pr = GitHubPR("pytorch", "pytorch", 106068)
1020
msg = pr.gen_commit_message(filter_ghstack=True)
1023
f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
1024
"Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
1025
"Approved by: https://github.com/ezyang, https://github.com/fegin\n",
1028
def test_pr_dependencies_ghstack(self, *args: Any) -> None:
1029
pr0 = GitHubPR("pytorch", "pytorch", 106032)
1030
pr1 = GitHubPR("pytorch", "pytorch", 106033)
1031
pr2 = GitHubPR("pytorch", "pytorch", 106034)
1032
pr = GitHubPR("pytorch", "pytorch", 106068)
1033
msg = pr.gen_commit_message(filter_ghstack=True, ghstack_deps=[pr0, pr1, pr2])
1036
f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
1037
"Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
1038
"Approved by: https://github.com/ezyang, https://github.com/fegin\n"
1039
"ghstack dependencies: #106032, #106033, #106034\n",
1043
reason="This test is run against a mutable PR that has changed, so it no longer works. The test should be changed"
1045
@mock.patch("trymerge.read_merge_rules")
1046
@mock.patch("trymerge.GitRepo")
1047
@mock.patch("trymerge.get_ghstack_prs")
1048
def test_merge_ghstack_into(
1050
mock_get_ghstack_prs: mock.MagicMock,
1051
mock_repo: mock.MagicMock,
1052
mock_merge_rules: mock.MagicMock,
1056
Test that the merge_ghstack_into method works correctly
1058
pr0 = GitHubPR("pytorch", "pytorch", 106032)
1059
pr1 = GitHubPR("pytorch", "pytorch", 106033)
1060
pr2 = GitHubPR("pytorch", "pytorch", 106034)
1061
pr = GitHubPR("pytorch", "pytorch", 106068)
1063
# note: in reverse order (e.g. self.pr is the last commit, top of the stack)
1064
mock_get_ghstack_prs.return_value = [
1071
mock_merge_rules.return_value = [
1073
"Mock title", patterns=["*"], approved_by=[], mandatory_checks_name=None
1077
mock_repo.cherry_pick.return_value = None
1078
mock_repo.amend_commit_message.return_value = None
1080
# Call the method under test
1081
res = pr.merge_ghstack_into(mock_repo, True)
1083
self.assertEqual(res, [pr2, pr])
1085
mock_repo.cherry_pick.assert_any_call("rev2")
1086
mock_repo.cherry_pick.assert_any_call("rev123")
1088
self.assertTrue(mock.call("rev1") not in mock_repo.cherry_pick.call_args_list)
1090
# Verify the first call
1091
message = mock_repo.amend_commit_message.call_args_list[0].args[0]
1093
"[FSDP] Optimize away intermediate `div_` for HSDP (#106034)\n\n\r\n"
1094
"### Background: Gradient Pre-Divide"
1097
"\nPull Request resolved: https://github.com/pytorch/pytorch/pull/106034\nApproved by: \nghstack "
1098
"dependencies: #106032, #106033\n"
1101
self.assertTrue(message.startswith(prefix))
1102
self.assertTrue(message.endswith(suffix))
1104
# Verify the second call
1105
mock_repo.amend_commit_message.assert_any_call(
1106
"[FSDP] Break up `_post_backward_hook` into smaller funcs (#106068)\n\n\n"
1107
"Differential Revision: ["
1108
"D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)\n"
1109
"Pull Request resolved: "
1110
"https://github.com/pytorch/pytorch/pull/106068\n"
1112
"ghstack dependencies: #106032, #106033, #106034\n"
1116
if __name__ == "__main__":