pytorch

Форк
0
/
test_trymerge.py 
1117 строк · 41.7 Кб
1
#!/usr/bin/env python3
2
# Tests implemented in this file are relying on GitHub GraphQL APIs
3
# In order to avoid test flakiness, results of the queries
4
# are cached in gql_mocks.json
5
# PyTorch Lint workflow does not have GITHUB_TOKEN defined to avoid
6
# flakiness, so if you are making changes to merge_rules or
7
# GraphQL queries in trymerge.py, please make sure to delete `gql_mocks.json`
8
# and re-run the test locally with your own PAT
9

10
import gzip
11
import json
12
import os
13
import warnings
14
from hashlib import sha256
15
from typing import Any, Dict, List, Optional
16
from unittest import main, mock, skip, TestCase
17
from urllib.error import HTTPError
18

19
from github_utils import gh_graphql
20
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
21
from trymerge import (
22
    categorize_checks,
23
    DRCI_CHECKRUN_NAME,
24
    find_matching_merge_rule,
25
    get_classifications,
26
    get_drci_classifications,
27
    get_rockset_results,
28
    gh_get_team_members,
29
    GitHubPR,
30
    JobCheckState,
31
    main as trymerge_main,
32
    MandatoryChecksMissingError,
33
    MergeRule,
34
    RE_GHSTACK_DESC,
35
    read_merge_rules,
36
    remove_job_name_suffix,
37
    validate_revert,
38
)
39

40

41
if "GIT_REMOTE_URL" not in os.environ:
42
    os.environ["GIT_REMOTE_URL"] = "https://github.com/pytorch/pytorch"
43

44
GQL_MOCKS = "gql_mocks.json.gz"
45
ROCKSET_MOCKS = "rockset_mocks.json.gz"
46
DRCI_MOCKS = "drci_mocks.json.gz"
47

48

49
def mock_query(
    fallback_function: Any,
    file_name: str,
    key_function: Any,
    *args: Any,
) -> Any:
    """Memoizing dispatcher for the network queries used by these tests.

    Computes ``key_function(*args)`` and, when that key is present in the
    gzipped JSON cache ``file_name`` (stored next to this file), returns the
    cached value.  On a cache miss it calls ``fallback_function(*args)``,
    stores the result in the cache, and returns it.

    Raises:
        RuntimeError: when the fallback fails with HTTP 401/403, with a hint
            on how to regenerate the cache with proper credentials.
        HTTPError: re-raised unchanged for any other HTTP error code.
    """
    gql_db_fname = os.path.join(os.path.dirname(__file__), file_name)

    def get_mocked_queries() -> Any:
        # A missing cache file just means "no entries recorded yet".
        if not os.path.exists(gql_db_fname):
            return {}
        with gzip.open(gql_db_fname, encoding="utf-8", mode="rt") as f:
            return json.load(f)

    def save_mocked_queries(obj: Any) -> None:
        with gzip.open(gql_db_fname, encoding="utf-8", mode="wt") as f:
            json.dump(obj, f, indent=2)
            f.write("\n")

    key = key_function(*args)
    mocked_queries = get_mocked_queries()

    if key in mocked_queries:
        return mocked_queries[key]

    try:
        rc = fallback_function(*args)
    except HTTPError as err:
        if err.code == 401 or err.code == 403:
            err_msg = f"If you are seeing this message during workflow run, please make sure to update {file_name}"
            err_msg += f" locally, by deleting it and running {os.path.basename(__file__)} with"
            err_msg += " GitHub Personal Access Token passed via GITHUB_TOKEN,"
            err_msg += " the rockset api key passed via ROCKSET_API_KEY,"
            err_msg += " and drci api key passed via DRCI_BOT_KEY environment variables"
            if (
                os.getenv("GITHUB_TOKEN") is None
                or os.getenv("ROCKSET_API_KEY") is None
                or os.getenv("DRCI_BOT_KEY") is None
            ):
                err_msg = (
                    "Failed to update cached queries as GITHUB_TOKEN or ROCKSET_API_KEY or DRCI_BOT_KEY "
                    + "is not defined. "
                    + err_msg
                )
            raise RuntimeError(err_msg) from err
        # BUG FIX: any other HTTP error previously fell out of the except
        # block, leaving `rc` unbound and crashing below with an
        # UnboundLocalError.  Re-raise the original error instead.
        raise

    mocked_queries[key] = rc

    save_mocked_queries(mocked_queries)

    return rc
99

100

101
def mocked_gh_graphql(query: str, **kwargs: Any) -> Any:
    """Cached stand-in for gh_graphql, keyed by query hash plus sorted params."""

    def cache_key(q: str, params: Any) -> str:
        digest = sha256(q.encode("utf-8")).hexdigest()
        sorted_params = " ".join(f"{name}={params[name]}" for name in sorted(params))
        return f"query_sha={digest} " + sorted_params

    def run_real_query(q: str, params: Any) -> Any:
        # Only reached on a cache miss, when real credentials are available.
        return gh_graphql(q, **params)

    return mock_query(run_real_query, GQL_MOCKS, cache_key, query, kwargs)
111

112

113
def mocked_rockset_results(head_sha: str, merge_base: str, num_retries: int = 3) -> Any:
    """Cached wrapper around get_rockset_results, keyed by '<head_sha> <merge_base>'."""

    def cache_key(sha: str, base: str) -> str:
        return f"{sha} {base}"

    return mock_query(
        get_rockset_results,
        ROCKSET_MOCKS,
        cache_key,
        head_sha,
        merge_base,
    )
121

122

123
def mocked_drci_classifications(pr_num: int, project: str, num_retries: int = 3) -> Any:
    """Cached wrapper around get_drci_classifications, keyed by '<pr_num> <project>'."""

    def cache_key(num: Any, proj: Any) -> str:
        return f"{num} {proj}"

    return mock_query(
        get_drci_classifications,
        DRCI_MOCKS,
        cache_key,
        pr_num,
        project,
    )
131

132

133
def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
    """Build a stand-in for trymerge's argparse namespace.

    Only `revert` and `force` vary; every other attribute is pinned to the
    values the CLI tests expect (dry-run against PR #76123).
    """

    class FakeArgs:
        def __init__(self) -> None:
            self.revert = revert
            self.force = force
            self.pr_num = 76123
            self.dry_run = True
            self.comment_id = 0
            self.reason = "this is for testing"
            self.ignore_current = False
            self.check_mergeability = False

    return FakeArgs()
146

147

148
def mock_remove_label(
    org: str, repo: str, pr_num: str, label: str, dry_run: bool
) -> None:
    """No-op replacement for gh_remove_label; the tests only care that merge proceeds."""
    return None
152

153

154
def mock_revert(
    repo: GitRepo,
    pr: GitHubPR,
    *,
    dry_run: bool = False,
    comment_id: Optional[int] = None,
    reason: Optional[str] = None,
) -> None:
    """No-op replacement for trymerge.try_revert used by the CLI entry-point tests."""
    return None
163

164

165
def mock_merge(
    pr: GitHubPR,
    repo: GitRepo,
    dry_run: bool = False,
    skip_mandatory_checks: bool = False,
    comment_id: Optional[int] = None,
    timeout_minutes: int = 400,
    stale_pr_days: int = 3,
    ignore_current: bool = False,
) -> None:
    """No-op replacement for trymerge.merge used by the CLI entry-point tests."""
    return None
176

177

178
def mock_gh_get_info() -> Any:
    """Minimal gh_get_pr_info payload: an open, same-repo PR with no changed files."""
    return dict(
        closed=False,
        isCrossRepository=False,
        headRefName="foo",
        baseRefName="bar",
        baseRepository={"defaultBranchRef": {"name": "bar"}},
        files={"nodes": [], "pageInfo": {"hasNextPage": False}},
        changedFiles=0,
    )
188

189

190
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]:
    """Single catch-all rule whose mandatory checks include one that never exists."""
    rule = MergeRule(
        name="mock with nonexistent check",
        patterns=["*"],
        approved_by=[],
        mandatory_checks_name=["Lint", "Facebook CLA Check", "nonexistent"],
        ignore_flaky_failures=True,
    )
    return [rule]
200

201

202
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
    """Two-rule stand-in for read_merge_rules: a catch-all rule plus an XLA-pin rule."""
    catch_all = MergeRule(
        name="super",
        patterns=["*"],
        approved_by=["pytorch/metamates", "ngimel"],
        mandatory_checks_name=[
            "Lint",
            "pull / linux-xenial-cuda11.3-py3.7-gcc7 / build",
        ],
        ignore_flaky_failures=True,
    )
    xla_pin = MergeRule(
        name="xla",
        patterns=[".github/ci_commit_pins/xla.txt"],
        approved_by=["pytorchbot"],
        mandatory_checks_name=[
            "Lint",
            "EasyCLA",
            "pull / linux-focal-py3_8-clang9-xla / build",
            "pull / linux-focal-py3_8-clang9-xla / test (xla, 1, 1, linux.12xlarge)",
        ],
        ignore_flaky_failures=True,
    )
    return [catch_all, xla_pin]
227

228

229
def mocked_read_merge_rules_approvers(
    repo: Any, org: str, project: str
) -> List[MergeRule]:
    """Two catch-all rules that differ only in their allowed approvers."""
    reviewers = MergeRule(
        name="Core Reviewers",
        patterns=["*"],
        approved_by=["1", "2", "3", "4", "5", "6"],
        mandatory_checks_name=[
            "Lint",
            "pull",
        ],
    )
    maintainers = MergeRule(
        name="Core Maintainers",
        patterns=["*"],
        approved_by=["1", "2", "malfet"],
        mandatory_checks_name=[
            "Lint",
            "pull",
        ],
    )
    return [reviewers, maintainers]
252

253

254
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]:
    """Always fail, simulating an unreadable merge-rules file."""
    raise RuntimeError("testing")
256

257

258
def xla_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
    """XLA-pin rule with non-ignorable flaky failures, including an inductor check."""
    xla_rule = MergeRule(
        name=" OSS CI / pytorchbot / XLA",
        patterns=[".github/ci_commit_pins/xla.txt"],
        approved_by=["pytorchbot"],
        mandatory_checks_name=[
            "Lint",
            "EasyCLA",
            "pull / linux-bionic-py3_8-clang8-xla / build",
            "pull / linux-bionic-py3_8-clang8-xla / test (xla, 1, 1, linux.4xlarge)",
            "inductor / cuda11.8-py3.10-gcc7-sm86 / test (inductor_torchbench_dynamic, 1, 1, linux.g5.4xlarge.nvidia.gpu)",
        ],
        ignore_flaky_failures=False,
    )
    return [xla_rule]
274

275

276
def empty_rockset_results(head_sha: str, merge_base: str) -> List[Dict[str, Any]]:
    """Rockset stub: pretend the query matched no rows at all."""
    return list()
278

279

280
class DummyGitRepo(GitRepo):
    """GitRepo bound to the local checkout, with canned PR/commit answers."""

    def __init__(self) -> None:
        super().__init__(get_git_repo_dir(), get_git_remote_name())

    def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
        # Every PR resolves to the same fake commit sha.
        return ["FakeCommitSha"]

    def commit_message(self, ref: str) -> str:
        # Fixed message for any ref (typo preserved: it is a runtime string).
        return "super awsome commit message"
289

290

291
@mock.patch("trymerge.get_rockset_results", side_effect=empty_rockset_results)
292
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
293
@mock.patch(
294
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
295
)
296
class TestTryMerge(TestCase):
297
    def test_merge_rules_valid(self, *args: Any) -> None:
298
        "Test that merge_rules.yaml can be parsed"
299
        repo = DummyGitRepo()
300
        merge_rules = read_merge_rules(repo, "pytorch", "pytorch")
301
        self.assertGreater(len(merge_rules), 1)
302

303
    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
304
    def test_match_rules(self, *args: Any) -> None:
305
        "Tests that PR passes merge rules"
306
        pr = GitHubPR("pytorch", "pytorch", 109999)
307
        repo = DummyGitRepo()
308
        self.assertTrue(find_matching_merge_rule(pr, repo) is not None)
309

310
    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_raise)
311
    def test_read_merge_rules_fails(self, *args: Any) -> None:
312
        "Tests that PR fails to read the merge rules"
313
        pr = GitHubPR("pytorch", "pytorch", 77700)
314
        repo = DummyGitRepo()
315
        self.assertRaisesRegex(
316
            RuntimeError, "testing", lambda: find_matching_merge_rule(pr, repo)
317
        )
318

319
    @mock.patch(
320
        "trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_approvers
321
    )
322
    def test_match_rules_approvers(self, *args: Any) -> None:
323
        "Tests that PR has the necessary approvers"
324
        repo = DummyGitRepo()
325

326
        pr = GitHubPR("pytorch", "pytorch", 115329)
327
        # Test that all potential approvers across all rules are listed if the
328
        # PR doesn't have one of them
329
        for mock_rule in ["Core Reviewers", "Core Maintainers"]:
330
            self.assertRaisesRegex(
331
                RuntimeError,
332
                mock_rule,
333
                lambda: find_matching_merge_rule(pr, repo),
334
            )
335

336
        pr = GitHubPR("pytorch", "pytorch", 115495)
337
        # Test that PR with the correct approvers doesn't raise any exception
338
        self.assertTrue(find_matching_merge_rule(pr, repo) is not None)
339

340
    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
341
    def test_lint_fails(self, *args: Any) -> None:
342
        "Tests that PR fails mandatory lint check"
343
        pr = GitHubPR("pytorch", "pytorch", 90791)
344
        repo = DummyGitRepo()
345
        self.assertRaises(RuntimeError, lambda: find_matching_merge_rule(pr, repo))
346

347
    def test_get_last_comment(self, *args: Any) -> None:
348
        "Tests that last comment can be fetched"
349
        pr = GitHubPR("pytorch", "pytorch", 71759)
350
        comment = pr.get_last_comment()
351
        self.assertEqual(comment.author_login, "github-actions")
352
        self.assertIsNone(comment.editor_login)
353
        self.assertTrue("You've committed this PR" in comment.body_text)
354

355
    def test_get_author_null(self, *args: Any) -> None:
356
        """Tests that PR author can be computed
357
        If reply contains NULL
358
        """
359
        pr = GitHubPR("pytorch", "pytorch", 71759)
360
        author = pr.get_author()
361
        self.assertTrue(author is not None)
362
        self.assertTrue("@" in author)
363
        self.assertTrue(pr.get_diff_revision() is None)
364

365
        # PR with multiple contributors, but creator id is not among authors
366
        pr = GitHubPR("pytorch", "pytorch", 75095)
367
        self.assertEqual(pr.get_pr_creator_login(), "mruberry")
368
        author = pr.get_author()
369
        self.assertTrue(author is not None)
370

371
    def test_large_diff(self, *args: Any) -> None:
372
        "Tests that PR with 100+ files can be fetched"
373
        pr = GitHubPR("pytorch", "pytorch", 73099)
374
        self.assertTrue(pr.get_changed_files_count() > 100)
375
        flist = pr.get_changed_files()
376
        self.assertEqual(len(flist), pr.get_changed_files_count())
377

378
    def test_internal_changes(self, *args: Any) -> None:
379
        "Tests that PR with internal changes is detected"
380
        pr = GitHubPR("pytorch", "pytorch", 110140)
381
        self.assertTrue(pr.has_internal_changes())
382

383
    def test_comments_pagination(self, *args: Any) -> None:
384
        "Tests that PR with 50+ comments can be fetched"
385
        pr = GitHubPR("pytorch", "pytorch", 31093)
386
        self.assertGreater(len(pr.get_comments()), 50)
387

388
    def test_gql_complexity(self, *args: Any) -> None:
389
        "Fetch comments and conclusions for PR with 60 commits"
390
        # Previous version of GrapQL query used to cause HTTP/502 error
391
        # see https://gist.github.com/malfet/9b93bc7eeddeaf1d84546efc4f0c577f
392
        pr = GitHubPR("pytorch", "pytorch", 68111)
393
        self.assertGreater(len(pr.get_comments()), 20)
394
        # NS(09/27/2023): GitHub seems to recycle older checkruns
395
        # https://github.com/pytorch/pytorch/pull/68111/checks shows 0 runs
396
        # self.assertGreater(len(pr.get_checkrun_conclusions()), 3)
397
        self.assertGreater(pr.get_commit_count(), 60)
398

399
    @skip("GitHub doesn't keep this data anymore")
400
    def test_gql_retrieve_checksuites(self, *args: Any) -> None:
401
        "Fetch comments and conclusions for PR with 60 commits"
402
        pr = GitHubPR("pytorch", "pytorch", 94787)
403
        self.assertEqual(len(pr.get_checkrun_conclusions()), 182)
404

405
    def test_team_members(self, *args: Any) -> None:
406
        "Test fetching team members works"
407
        dev_infra_team = gh_get_team_members("pytorch", "pytorch-dev-infra")
408
        self.assertGreater(len(dev_infra_team), 2)
409
        with self.assertWarns(Warning):
410
            non_existing_team = gh_get_team_members("pytorch", "qwertyuiop")
411
            self.assertEqual(len(non_existing_team), 0)
412

413
    def test_get_author_many_commits(self, *args: Any) -> None:
414
        """Tests that authors for all commits can be fetched"""
415
        pr = GitHubPR("pytorch", "pytorch", 76118)
416
        authors = pr.get_authors()
417
        self.assertGreater(pr.get_commit_count(), 100)
418
        self.assertGreater(len(authors), 50)
419
        self.assertTrue("@" in pr.get_author())
420

421
    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_NE)
422
    def test_pending_status_check(self, *args: Any) -> None:
423
        """Tests that PR with nonexistent/pending status checks fails with the right reason."""
424
        pr = GitHubPR("pytorch", "pytorch", 76118)
425
        repo = DummyGitRepo()
426
        self.assertRaisesRegex(
427
            MandatoryChecksMissingError,
428
            ".*are pending/not yet run.*",
429
            lambda: find_matching_merge_rule(pr, repo),
430
        )
431

432
    def test_get_author_many_reviews(self, *args: Any) -> None:
433
        """Tests that all reviews can be fetched"""
434
        pr = GitHubPR("pytorch", "pytorch", 76123)
435
        approved_by = pr.get_approved_by()
436
        self.assertGreater(len(approved_by), 0)
437
        assert pr._reviews is not None  # to pacify mypy
438
        self.assertGreater(len(pr._reviews), 100)
439

440
    def get_co_authors(self, *args: Any) -> None:
441
        """Tests that co-authors are recognized"""
442
        pr = GitHubPR("pytorch", "pytorch", 118347)
443
        authors = pr.get_authors()
444
        self.assertIn("kit1980", authors)
445
        self.assertIn("Co-authored-by:", pr.gen_commit_message())
446

447
    def test_get_checkruns_many_runs(self, *args: Any) -> None:
448
        """Tests that all checkruns can be fetched"""
449
        pr = GitHubPR("pytorch", "pytorch", 105260)
450
        conclusions = pr.get_checkrun_conclusions()
451
        self.assertEqual(len(conclusions), 221)
452
        self.assertTrue(
453
            "pull / linux-docs / build-docs-cpp-false" in conclusions.keys()
454
        )
455

456
    def test_cancelled_gets_ignored(self, *args: Any) -> None:
457
        """Tests that cancelled workflow does not override existing successfull status"""
458
        pr = GitHubPR("pytorch", "pytorch", 110367)
459
        conclusions = pr.get_checkrun_conclusions()
460
        lint_checks = [name for name in conclusions.keys() if "Lint" in name]
461
        self.assertTrue(len(lint_checks) > 0)
462
        self.assertTrue(
463
            all(conclusions[name].status == "SUCCESS" for name in lint_checks)
464
        )
465

466
    def test_get_review_comment_by_id(self, *args: Any) -> None:
467
        """Tests that even if the comment requested was actually a review instead of a simple comment, we can still find it"""
468
        pr = GitHubPR("pytorch", "pytorch", 107070)
469
        review_comment_id = 1582767635
470
        comment = pr.get_comment_by_id(review_comment_id)
471
        self.assertIsNotNone(comment)
472

473
    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
474
    @mock.patch("trymerge.parse_args", return_value=mock_parse_args(True, False))
475
    @mock.patch("trymerge.try_revert", side_effect=mock_revert)
476
    def test_main_revert(self, mock_revert: Any, *args: Any) -> None:
477
        trymerge_main()
478
        mock_revert.assert_called_once()
479

480
    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
481
    @mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, True))
482
    @mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
483
    @mock.patch("trymerge.merge", side_effect=mock_merge)
484
    def test_main_force(
485
        self, mock_merge: Any, mock_parse_args: Any, *args: Any
486
    ) -> None:
487
        trymerge_main()
488
        mock_merge.assert_called_once_with(
489
            mock.ANY,
490
            mock.ANY,
491
            dry_run=mock.ANY,
492
            skip_mandatory_checks=True,
493
            comment_id=mock.ANY,
494
            ignore_current=False,
495
        )
496

497
    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
498
    @mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, False))
499
    @mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
500
    @mock.patch("trymerge.merge", side_effect=mock_merge)
501
    def test_main_merge(self, mock_merge: Any, *args: Any) -> None:
502
        trymerge_main()
503
        mock_merge.assert_called_once_with(
504
            mock.ANY,
505
            mock.ANY,
506
            dry_run=mock.ANY,
507
            skip_mandatory_checks=False,
508
            comment_id=mock.ANY,
509
            ignore_current=False,
510
        )
511

512
    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
513
    def test_revert_rules(self, *args: Any) -> None:
514
        """Tests that reverts from collaborators are allowed"""
515
        pr = GitHubPR("pytorch", "pytorch", 79694)
516
        repo = DummyGitRepo()
517
        self.assertIsNotNone(validate_revert(repo, pr, comment_id=1189459845))
518

519
    def test_get_changed_files(self, *args: Any) -> None:
520
        """
521
        Tests that the list changed files in a PR doesn't include duplicates
522
        """
523
        pr = GitHubPR("pytorch", "pytorch", 95233)
524
        try:
525
            changed_files = pr.get_changed_files()
526
        except RuntimeError as error:
527
            self.fail(f"get_changed_files throws an exception: {error}")
528

529
        self.assertEqual(len(changed_files), pr.get_changed_files_count())
530

531
    def test_revert_codev_abandoned_diff_succeeds(self, *args: Any) -> None:
532
        pr = GitHubPR("pytorch", "pytorch", 100652)
533

534
        class GitRepoCoDev(DummyGitRepo):
535
            def commit_message(self, ref: str) -> str:
536
                return pr.get_body()
537

538
        repo = GitRepoCoDev()
539
        validate_revert(repo, pr, comment_id=1588195237)
540

541
    def test_pr_changed_submodule_detection(self, *args: Any) -> None:
542
        # Updates submodule during dev-cycle but reverts it later
543
        pr = GitHubPR("pytorch", "pytorch", 95045)
544
        self.assertEqual(pr.get_changed_submodules(), [])
545
        self.assertFalse(pr.has_invalid_submodule_updates())
546

547
        # PR updates ideep
548
        pr = GitHubPR("pytorch", "pytorch", 94939)
549
        self.assertEqual(pr.get_changed_submodules(), ["third_party/ideep"])
550
        self.assertTrue(pr.has_invalid_submodule_updates())
551

552
        # Automated submodule update
553
        pr = GitHubPR("pytorch", "pytorch", 91051)
554
        self.assertEqual(pr.get_changed_submodules(), ["third_party/kineto"])
555
        self.assertFalse(pr.has_invalid_submodule_updates())
556

557
    def test_remove_job_name_suffix(self, *args: Any) -> None:
558
        test_cases = [
559
            {
560
                "name": "linux-bionic-cuda12.1-py3.10-gcc9-sm86 / test (default, 1, 5, linux.g5.4xlarge.nvidia.gpu)",
561
                "expected": "linux-bionic-cuda12.1-py3.10-gcc9-sm86 / test (default)",
562
            },
563
            {
564
                "name": "android-emulator-build-test / build-and-test (default, 1, 1, ubuntu-20.04-16x)",
565
                "expected": "android-emulator-build-test / build-and-test (default)",
566
            },
567
            {
568
                "name": "linux-focal-rocm5.4.2-py3.8 / build",
569
                "expected": "linux-focal-rocm5.4.2-py3.8 / build",
570
            },
571
            {
572
                "name": "libtorch-cpu-shared-with-deps-release-build",
573
                "expected": "libtorch-cpu-shared-with-deps-release-build",
574
            },
575
            {
576
                "name": "manywheel-py3_8-cuda11_8-test / test",
577
                "expected": "manywheel-py3_8-cuda11_8-test / test",
578
            },
579
            {
580
                "name": "lintrunner / linux-job",
581
                "expected": "lintrunner / linux-job",
582
            },
583
            {
584
                "name": "Test `run_test.py` is usable without boto3/rockset",
585
                "expected": "Test `run_test.py` is usable without boto3/rockset",
586
            },
587
        ]
588

589
        for case in test_cases:
590
            self.assertEqual(case["expected"], remove_job_name_suffix(case["name"]))
591

592
    def test_get_merge_base(self, *args: Any) -> None:
593
        pr = GitHubPR("pytorch", "pytorch", 104121)
594

595
        mock_merge_base = "mocked-sha"
596
        with mock.patch(
597
            "trymerge.gh_fetch_merge_base", return_value=mock_merge_base
598
        ) as mocked_gh_fetch_merge_base:
599
            self.assertEqual(mock_merge_base, pr.get_merge_base())
600

601
            # Make sure that consecutive calls will use the same merge base instead of
602
            # making another query
603
            self.assertEqual(mock_merge_base, pr.get_merge_base())
604
            mocked_gh_fetch_merge_base.assert_called_once()
605

606

607
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
608
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
609
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
610
@mock.patch(
611
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
612
)
613
class TestBypassFailures(TestCase):
614
    def test_get_classifications(self, *args: Any) -> None:
        """Flaky/broken-trunk classifications and threshold handling for PR 109584."""
        pr = GitHubPR("pytorch", "pytorch", 109584)
        checks = pr.get_checkrun_conclusions()
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        expected = {
            "pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)": "BROKEN_TRUNK",
            "trunk / win-vs2019-cpu-py3 / test (default, 2, 3, windows.4xlarge.nonephemeral)": "FLAKY",
            "pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge)": "FLAKY",
            "pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)": "FLAKY",
        }
        for job_name, classification in expected.items():
            self.assertTrue(checks[job_name].classification == classification)

        # (threshold, expected number of hard failures); None means rely on
        # the default (-1), which ignores all flaky/broken-trunk failures.
        # A threshold at or above the 6 ok failures ignores them all; below
        # it, all 6 count as failed.  The last case mirrors the behavior when
        # ignore_flaky_failures is on.
        scenarios = [(6, 0), (None, 0), (1, 6), (1, 6)]
        for threshold, expected_failed in scenarios:
            if threshold is None:
                pending, failed, ignorable = categorize_checks(
                    checks, list(checks.keys())
                )
            else:
                pending, failed, ignorable = categorize_checks(
                    checks, list(checks.keys()), ok_failed_checks_threshold=threshold
                )
            self.assertTrue(len(pending) == 0)
            self.assertTrue(len(failed) == expected_failed)
            self.assertTrue(len(ignorable["FLAKY"]) == 4)
            self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
681

682
    def test_get_classifications_flaky_fullname(self, *args: Any) -> None:
        """A failure flagged flaky by its full job name must be ignorable."""
        pull_request = GitHubPR("pytorch", "pytorch", 110362)
        conclusions = pull_request.get_checkrun_conclusions()
        conclusions = get_classifications(
            pull_request.pr_num, pull_request.project, conclusions, []
        )
        pending, failed, ignorable = categorize_checks(
            conclusions, list(conclusions.keys())
        )
        self.assertEqual(0, len(pending))
        self.assertEqual(0, len(failed))
        self.assertEqual(1, len(ignorable["FLAKY"]))
695

696
    def test_get_classifications_invalid_cancel(self, *args: Any) -> None:
        """Jobs cancelled by infra are classified UNSTABLE, not FLAKY/BROKEN_TRUNK."""
        pull_request = GitHubPR("pytorch", "pytorch", 110367)
        conclusions = get_classifications(
            pull_request.pr_num,
            pull_request.project,
            pull_request.get_checkrun_conclusions(),
            [],
        )
        pending, failed, ignorable = categorize_checks(
            conclusions, list(conclusions.keys())
        )
        self.assertEqual(0, len(pending))
        self.assertEqual(0, len(failed))
        for classification, count in (("FLAKY", 0), ("BROKEN_TRUNK", 0), ("UNSTABLE", 3)):
            self.assertEqual(count, len(ignorable[classification]))
711

712
    def test_get_classifications_similar_failures(self, *args: Any) -> None:
        """A failure similar to known flaky ones must be classified as flaky."""
        pull_request = GitHubPR("pytorch", "pytorch", 109750)
        conclusions = get_classifications(
            pull_request.pr_num,
            pull_request.project,
            pull_request.get_checkrun_conclusions(),
            [],
        )
        pending, failed, ignorable = categorize_checks(
            conclusions, list(conclusions.keys())
        )
        self.assertEqual(0, len(pending))
        self.assertEqual(0, len(failed))
        self.assertEqual(1, len(ignorable["FLAKY"]))
725

726
    def test_get_classifications_unstable(self, *args: Any) -> None:
        """Jobs marked unstable (via the name suffix or by Dr.CI) are ignorable."""
        pr = GitHubPR("pytorch", "pytorch", 104312)
        checks = pr.get_checkrun_conclusions()
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        workflow_name = "linux-bionic-cuda12.1-py3.10-gcc9-bazel-test"
        job_name = "build-and-test (default, 1, 1, linux.4xlarge.nvidia.gpu, unstable)"
        self.assertTrue(
            checks[f"pull / {workflow_name} / {job_name}"].classification == "UNSTABLE"
        )
        pending, failed, ignorable = categorize_checks(
            checks, list(checks.keys()), ok_failed_checks_threshold=1
        )
        self.assertTrue(len(pending) == 0)
        self.assertTrue(len(failed) == 0)
        self.assertTrue(len(ignorable["UNSTABLE"]) == 1)

        # Add another test case where there is no unstable keyword in the job name, but
        # the job has already been marked as unstable
        pr = GitHubPR("pytorch", "executorch", 3318)
        checks = pr.get_checkrun_conclusions()
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        # (removed a leftover debug `print(checks)` here)
        workflow_name = "test-llama-app"
        job_name = "mobile-job (android)"
        self.assertTrue(
            checks[f"Android / {workflow_name} / {job_name}"].classification
            == "UNSTABLE"
        )
        pending, failed, ignorable = categorize_checks(
            checks, list(checks.keys()), ok_failed_checks_threshold=1
        )
        self.assertTrue(len(pending) == 0)
        self.assertTrue(len(failed) == 0)
        self.assertTrue(len(ignorable["UNSTABLE"]) == 1)
770

771
    def test_get_classifications_broken_trunk(self, *args: Any) -> None:
        """Broken-trunk failures are ignorable by default but counted when the
        ok_failed_checks_threshold is 0.
        """
        # The mock merge base is the actual value returned by gh_fetch_merge_base
        test_cases = [
            {
                # This PR had one broken trunk failure but it was run on a different shard
                # than the one on the base commit. This should still count as broken trunk
                "pr_num": 104214,
                "related_failure_count": 0,
                "flaky_or_broken_trunk": 1,
            },
            {
                # This PR had one broken trunk failure and it used ghstack
                "pr_num": 105145,
                "related_failure_count": 0,
                "flaky_or_broken_trunk": 1,
            },
            {
                # The failure on the merge base was retried successfully and
                # its conclusion changed from failure to success. We want to
                # keep the failure record from the merge base so that it can
                # be used to detect broken trunk
                "pr_num": 107160,
                "related_failure_count": 0,
                "flaky_or_broken_trunk": 1,
            },
            {
                # This PR used Dr.CI broken trunk classification
                "pr_num": 111253,
                "related_failure_count": 1,
                "flaky_or_broken_trunk": 1,
            },
        ]

        for case in test_cases:
            pr_num = case["pr_num"]
            related_failure_count = case["related_failure_count"]
            flaky_or_broken_trunk = case["flaky_or_broken_trunk"]

            pr = GitHubPR("pytorch", "pytorch", pr_num)
            checks = pr.get_checkrun_conclusions()
            checks = get_classifications(
                pr.pr_num,
                pr.project,
                checks,
                [],
            )

            pending, failed, _ = categorize_checks(checks, list(checks.keys()))
            self.assertTrue(len(pending) == 0)
            self.assertTrue(len(failed) == related_failure_count)

            # When the ok_failed_checks_threshold is set to 0, the broken trunk failure
            # won't be ignored
            pending, failed, _ = categorize_checks(
                checks, list(checks.keys()), ok_failed_checks_threshold=0
            )
            self.assertTrue(len(pending) == 0)
            self.assertTrue(
                len(failed) == flaky_or_broken_trunk + related_failure_count
            )
    def test_ignore_current(self, *args: Any) -> None:
        """Ignore-current classification must not override flaky/broken-trunk."""
        # Test various interactions of the failure classifier to ensure that ignore
        # current checks takes place after other classifications: flaky, unstable,
        # or broken trunk. Only actual new failures should be kept in the list of
        # ignore current checks to use to record force merge with actual failures
        flaky = "pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)"
        broken_trunk = (
            "pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)"
        )

        pr = GitHubPR("pytorch", "pytorch", 109584)
        checks = pr.get_checkrun_conclusions()

        # Known flaky failure takes precedence over ignore current (need to set the
        # merge base here to get the results from Rockset, and that categorize the
        # broken trunk failure too
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [broken_trunk, flaky],
        )
        self.assertTrue(checks[flaky].classification == "FLAKY")
        self.assertTrue(checks[broken_trunk].classification == "BROKEN_TRUNK")
        _, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertTrue(len(failed) == 0)
        self.assertTrue(len(ignorable["IGNORE_CURRENT_CHECK"]) == 0)
        self.assertTrue(len(ignorable["FLAKY"]) == 4)
        self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
    def test_get_classifications_wrong_workflow_name(self, *args: Any) -> None:
        """Classification still works when the check name uses the workflow
        file path instead of the workflow display name.
        """
        pr = GitHubPR("pytorch", "pytorch", 123104)
        checks = pr.get_checkrun_conclusions()

        check_name = "linux-binary-conda / conda-py3_8-cuda11_8-build / build"
        check_name_workflow_path = ".github/workflows/generated-linux-binary-conda-nightly.yml / conda-py3_8-cuda11_8-build / build"

        # Mock a check where the workflow name uses the full path
        checks[check_name_workflow_path] = JobCheckState(
            check_name_workflow_path,
            checks[check_name].url,
            checks[check_name].status,
            checks[check_name].classification,
            checks[check_name].job_id,
            checks[check_name].title,
            checks[check_name].summary,
        )
        del checks[check_name]

        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        pending, failed, ignorable = categorize_checks(
            checks,
            list(checks.keys()),
        )

        self.assertTrue(len(pending) == 0)
        self.assertTrue(len(failed) == 0)
        self.assertTrue(len(ignorable["FLAKY"]) == 1)
        self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 0)
    def test_ignore_failures_older_run_same_workflow(self, *args: Any) -> None:
        """Failures from an older run of the same workflow are ignorable."""
        pr = GitHubPR("pytorch", "pytorch", 129013)
        checks = pr.get_checkrun_conclusions()
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        pending, failed, ignorable = categorize_checks(
            checks,
            list(checks.keys()),
        )
        self.assertTrue(len(pending) == 0)
        self.assertTrue(len(failed) == 0)
        self.assertTrue(len(ignorable["FLAKY"]) == 2)
        self.assertTrue(len(ignorable["UNSTABLE"]) == 13)
    @mock.patch("trymerge.read_merge_rules", side_effect=xla_merge_rules)
    def test_dont_ignore_flaky_failures(self, *args: Any) -> None:
        """
        Regression test for https://github.com/pytorch/test-infra/issues/4126
        """
        pr = GitHubPR("pytorch", "pytorch", 105312)
        repo = DummyGitRepo()
        # Check that failure is classified as flaky but still raises exception
        with warnings.catch_warnings(record=True) as w, self.assertRaises(RuntimeError):
            find_matching_merge_rule(pr, repo)
        self.assertEqual(len(w), 1)
        self.assertIn(
            "1 checks failed but were likely due flakiness or broken trunk",
            str(w[0].message),
        )
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch("trymerge.get_drci_classifications", return_value={})
class TestBypassFailuresOnSandCastle(TestCase):
    """Classification behavior when Dr.CI returns no classifications
    (get_drci_classifications is mocked to an empty dict).
    """

    def test_get_classifications(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 111467)
        checks = pr.get_checkrun_conclusions()
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertTrue(len(pending) == 0)
        self.assertTrue(len(failed) == 0)
        self.assertTrue(len(ignorable["FLAKY"]) == 1)
        self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 1)

    def test_get_classifications_drci_checkrun_not_found(self, *args: Any) -> None:
        """Without a usable Dr.CI summary, no failure can be bypassed."""
        pr = GitHubPR("pytorch", "pytorch", 111467)

        # A Dr.CI checkrun with no summary (None) or an empty summary behaves
        # the same: both failures stay unclassified and are reported as failed
        for drci_summary in (None, ""):
            checks = pr.get_checkrun_conclusions()
            checks[DRCI_CHECKRUN_NAME] = JobCheckState(
                DRCI_CHECKRUN_NAME,
                "",
                "NEUTRAL",
                None,
                1,
                "",
                drci_summary,
            )
            checks = get_classifications(
                pr.pr_num,
                pr.project,
                checks,
                [],
            )
            pending, failed, ignorable = categorize_checks(
                checks, list(checks.keys())
            )
            self.assertTrue(len(pending) == 0)
            self.assertTrue(len(failed) == 2)

        # No Dr.CI checkrun
        checks = pr.get_checkrun_conclusions()
        del checks[DRCI_CHECKRUN_NAME]
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertTrue(len(pending) == 0)
        self.assertTrue(len(failed) == 2)
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestGitHubPRGhstackDependencies(TestCase):
    """Commit-message generation and merging for ghstack PR stacks."""

    def test_pr_dependencies(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 106068)
        msg = pr.gen_commit_message(filter_ghstack=True)
        self.assertEqual(
            msg,
            f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
            "Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
            "Approved by: https://github.com/ezyang, https://github.com/fegin\n",
        )

    def test_pr_dependencies_ghstack(self, *args: Any) -> None:
        pr0 = GitHubPR("pytorch", "pytorch", 106032)
        pr1 = GitHubPR("pytorch", "pytorch", 106033)
        pr2 = GitHubPR("pytorch", "pytorch", 106034)
        pr = GitHubPR("pytorch", "pytorch", 106068)
        msg = pr.gen_commit_message(filter_ghstack=True, ghstack_deps=[pr0, pr1, pr2])
        self.assertEqual(
            msg,
            f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
            "Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
            "Approved by: https://github.com/ezyang, https://github.com/fegin\n"
            "ghstack dependencies: #106032, #106033, #106034\n",
        )

    @skip(
        reason="This test is run against a mutable PR that has changed, so it no longer works. The test should be changed"
    )
    @mock.patch("trymerge.read_merge_rules")
    @mock.patch("trymerge.GitRepo")
    @mock.patch("trymerge.get_ghstack_prs")
    def test_merge_ghstack_into(
        self,
        mock_get_ghstack_prs: mock.MagicMock,
        mock_repo: mock.MagicMock,
        mock_merge_rules: mock.MagicMock,
        *args: Any,
    ) -> None:
        """
        Test that the merge_ghstack_into method works correctly
        """
        pr0 = GitHubPR("pytorch", "pytorch", 106032)
        pr1 = GitHubPR("pytorch", "pytorch", 106033)
        pr2 = GitHubPR("pytorch", "pytorch", 106034)
        pr = GitHubPR("pytorch", "pytorch", 106068)

        # note: in reverse order (e.g. self.pr is the last commit, top of the stack)
        mock_get_ghstack_prs.return_value = [
            (pr0, "rev0"),
            (pr1, "rev1"),
            (pr2, "rev2"),
            (pr, "rev123"),
        ]

        mock_merge_rules.return_value = [
            MergeRule(
                "Mock title", patterns=["*"], approved_by=[], mandatory_checks_name=None
            )
        ]

        mock_repo.cherry_pick.return_value = None
        mock_repo.amend_commit_message.return_value = None

        # Call the method under test
        res = pr.merge_ghstack_into(mock_repo, True)

        self.assertEqual(res, [pr2, pr])

        mock_repo.cherry_pick.assert_any_call("rev2")
        mock_repo.cherry_pick.assert_any_call("rev123")

        self.assertTrue(mock.call("rev1") not in mock_repo.cherry_pick.call_args_list)

        # Verify the first call
        message = mock_repo.amend_commit_message.call_args_list[0].args[0]
        prefix = (
            "[FSDP] Optimize away intermediate `div_` for HSDP (#106034)\n\n\r\n"
            "### Background: Gradient Pre-Divide"
        )
        suffix = (
            "\nPull Request resolved: https://github.com/pytorch/pytorch/pull/106034\nApproved by: \nghstack "
            "dependencies: #106032, #106033\n"
        )

        self.assertTrue(message.startswith(prefix))
        self.assertTrue(message.endswith(suffix))

        # Verify the second call
        mock_repo.amend_commit_message.assert_any_call(
            "[FSDP] Break up `_post_backward_hook` into smaller funcs (#106068)\n\n\n"
            "Differential Revision: ["
            "D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)\n"
            "Pull Request resolved: "
            "https://github.com/pytorch/pytorch/pull/106068\n"
            "Approved by: \n"
            "ghstack dependencies: #106032, #106033, #106034\n"
        )
if __name__ == "__main__":
    # Run the unittest entry point when executed as a script
    main()
Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.