2
from pathlib import Path
3
from unittest import main, SkipTest, TestCase
7
are_ghstack_branches_in_sync,
15
# Directory that contains this test module.
BASE_DIR = Path(__file__).parent
18
class TestPeekableIterator(TestCase):
    """Exercise iteration and one-element lookahead of ``PeekableIterator``."""

    def test_iterator(self, input_: str = "abcdef") -> None:
        """Iterating yields the characters of ``input_`` in order."""
        iter_ = PeekableIterator(input_)
        for idx, c in enumerate(iter_):
            self.assertEqual(c, input_[idx])

    def test_is_iterable(self) -> None:
        """A ``PeekableIterator`` is a ``collections.abc.Iterator`` instance."""
        from collections.abc import Iterator

        iter_ = PeekableIterator("")
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(iter_, Iterator)

    def test_peek(self, input_: str = "abcdef") -> None:
        """``peek()`` returns the next character, or ``None`` at the end."""
        iter_ = PeekableIterator(input_)
        for idx, c in enumerate(iter_):
            if idx + 1 < len(input_):
                self.assertEqual(iter_.peek(), input_[idx + 1])
            else:
                # NOTE(review): the line between the two assertions is missing
                # from the extract this was rebuilt from; ``else:`` is the
                # presumed original -- confirm against the upstream file.
                self.assertIsNone(iter_.peek())
39
class TestPattern(TestCase):
    """Check the regex that ``patterns_to_regex`` builds from path patterns."""

    def test_double_asterisks(self) -> None:
        """A ``**`` pattern matches a file in ``native/`` and one in ``native/cpu/``."""
        # The list brackets and assignment targets were missing from the
        # extract; the names are taken from their later uses in this method.
        allowed_patterns = [
            "aten/src/ATen/native/**LinearAlgebra*",
        ]
        patterns_re = patterns_to_regex(allowed_patterns)
        fnames = [
            "aten/src/ATen/native/LinearAlgebra.cpp",
            "aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp",
        ]
        for filename in fnames:
            # subTest names the offending path when a match fails.
            with self.subTest(filename=filename):
                self.assertTrue(patterns_re.match(filename))
53
class TestRetriesDecorator(TestCase):
    """Check ``retries_decorator`` on a call that succeeds and one that cannot."""

    def test_simple(self) -> None:
        """A wrapped call that succeeds returns its own result."""
        # NOTE(review): the decorator line and the body of ``foo`` are missing
        # from the extract this was rebuilt from. Both are reconstructed from
        # the parallel ``test_fails`` and the expected value 7 -- confirm the
        # decorator arguments against the upstream file.
        @retries_decorator()
        def foo(x: int, y: int) -> int:
            return x + y

        self.assertEqual(foo(3, 4), 7)

    def test_fails(self) -> None:
        """A wrapped call that keeps failing yields the ``rc`` value (0 here)."""
        # NOTE(review): the body of ``foo`` is missing from the extract;
        # ``return x + y`` is presumed because "a" + 4 raises TypeError, which
        # matches the expected fallback result -- confirm upstream.
        @retries_decorator(rc=0)
        def foo(x: int, y: int) -> int:
            return x + y

        self.assertEqual(foo("a", 4), 0)
69
class TestGitRepo(TestCase):
    """Tests that run against the git checkout two levels above ``BASE_DIR``."""

    def setUp(self) -> None:
        """Open the enclosing repository, or skip when it has no ``.git`` directory."""
        repo_dir = BASE_DIR.parent.parent.absolute()
        if not (repo_dir / ".git").is_dir():
            # NOTE(review): only the message string survived in the extract;
            # the ``raise SkipTest(...)`` wrapper is presumed from the SkipTest
            # import and from ``_skip_if_ref_does_not_exist`` -- confirm upstream.
            raise SkipTest(
                "Can't find git directory, make sure to run this test on real repo checkout"
            )
        self.repo = GitRepo(str(repo_dir))

    def _skip_if_ref_does_not_exist(self, ref: str) -> None:
        """Skip test if ref is missing as stale branches are deleted with time"""
        # ``try:`` was missing from the extract; the ``except`` clause below
        # requires it.
        try:
            self.repo.show_ref(ref)
        except RuntimeError as e:
            raise SkipTest(f"Can't find head ref {ref} due to {str(e)}") from e

    def test_compute_diff(self) -> None:
        """Check that a value derived from the diff against ``HEAD`` has length 64."""
        diff = self.repo.diff("HEAD")
        # NOTE(review): the statement that defines ``sha`` (original line 87,
        # presumably computed from ``diff``) is missing from the extract and
        # cannot be recovered from the visible lines -- restore it from the
        # upstream file before running this test.
        self.assertEqual(len(sha), 64)

    def test_ghstack_branches_in_sync(self) -> None:
        """The ``gh/SS-JIA/206`` stack is expected to be reported as in sync."""
        head_ref = "gh/SS-JIA/206/head"
        self._skip_if_ref_does_not_exist(head_ref)
        self.assertTrue(are_ghstack_branches_in_sync(self.repo, head_ref))

    def test_ghstack_branches_not_in_sync(self) -> None:
        """The ``gh/clee2000/1`` stack is expected to be reported as out of sync."""
        head_ref = "gh/clee2000/1/head"
        self._skip_if_ref_does_not_exist(head_ref)
        self.assertFalse(are_ghstack_branches_in_sync(self.repo, head_ref))
101
if __name__ == "__main__":