pytorch

Форк
0
/
test_digraph.py 
131 строка · 3.6 Кб
1
# Owner(s): ["oncall: package/deploy"]
2

3
from torch.package._digraph import DiGraph
4
from torch.testing._internal.common_utils import run_tests
5

6

7
try:
8
    from .common import PackageTestCase
9
except ImportError:
10
    # Support the case where we run this file directly.
11
    from common import PackageTestCase
12

13

14
class TestDiGraph(PackageTestCase):
15
    """Test the DiGraph structure we use to represent dependencies in PackageExporter"""
16

17
    def test_successors(self):
18
        g = DiGraph()
19
        g.add_edge("foo", "bar")
20
        g.add_edge("foo", "baz")
21
        g.add_node("qux")
22

23
        self.assertIn("bar", list(g.successors("foo")))
24
        self.assertIn("baz", list(g.successors("foo")))
25
        self.assertEqual(len(list(g.successors("qux"))), 0)
26

27
    def test_predecessors(self):
28
        g = DiGraph()
29
        g.add_edge("foo", "bar")
30
        g.add_edge("foo", "baz")
31
        g.add_node("qux")
32

33
        self.assertIn("foo", list(g.predecessors("bar")))
34
        self.assertIn("foo", list(g.predecessors("baz")))
35
        self.assertEqual(len(list(g.predecessors("qux"))), 0)
36

37
    def test_successor_not_in_graph(self):
38
        g = DiGraph()
39
        with self.assertRaises(ValueError):
40
            g.successors("not in graph")
41

42
    def test_predecessor_not_in_graph(self):
43
        g = DiGraph()
44
        with self.assertRaises(ValueError):
45
            g.predecessors("not in graph")
46

47
    def test_node_attrs(self):
48
        g = DiGraph()
49
        g.add_node("foo", my_attr=1, other_attr=2)
50
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)
51
        self.assertEqual(g.nodes["foo"]["other_attr"], 2)
52

53
    def test_node_attr_update(self):
54
        g = DiGraph()
55
        g.add_node("foo", my_attr=1)
56
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)
57

58
        g.add_node("foo", my_attr="different")
59
        self.assertEqual(g.nodes["foo"]["my_attr"], "different")
60

61
    def test_edges(self):
62
        g = DiGraph()
63
        g.add_edge(1, 2)
64
        g.add_edge(2, 3)
65
        g.add_edge(1, 3)
66
        g.add_edge(4, 5)
67

68
        edge_list = list(g.edges)
69
        self.assertEqual(len(edge_list), 4)
70

71
        self.assertIn((1, 2), edge_list)
72
        self.assertIn((2, 3), edge_list)
73
        self.assertIn((1, 3), edge_list)
74
        self.assertIn((4, 5), edge_list)
75

76
    def test_iter(self):
77
        g = DiGraph()
78
        g.add_node(1)
79
        g.add_node(2)
80
        g.add_node(3)
81

82
        nodes = set()
83
        nodes.update(g)
84

85
        self.assertEqual(nodes, {1, 2, 3})
86

87
    def test_contains(self):
88
        g = DiGraph()
89
        g.add_node("yup")
90

91
        self.assertTrue("yup" in g)
92
        self.assertFalse("nup" in g)
93

94
    def test_contains_non_hashable(self):
95
        g = DiGraph()
96
        self.assertFalse([1, 2, 3] in g)
97

98
    def test_forward_closure(self):
99
        g = DiGraph()
100
        g.add_edge("1", "2")
101
        g.add_edge("2", "3")
102
        g.add_edge("5", "4")
103
        g.add_edge("4", "3")
104
        self.assertTrue(g.forward_transitive_closure("1") == {"1", "2", "3"})
105
        self.assertTrue(g.forward_transitive_closure("4") == {"4", "3"})
106

107
    def test_all_paths(self):
108
        g = DiGraph()
109
        g.add_edge("1", "2")
110
        g.add_edge("1", "7")
111
        g.add_edge("7", "8")
112
        g.add_edge("8", "3")
113
        g.add_edge("2", "3")
114
        g.add_edge("5", "4")
115
        g.add_edge("4", "3")
116

117
        result = g.all_paths("1", "3")
118
        # to get rid of indeterminism
119
        actual = {i.strip("\n") for i in result.split(";")[2:-1]}
120
        expected = {
121
            '"2" -> "3"',
122
            '"1" -> "7"',
123
            '"7" -> "8"',
124
            '"1" -> "2"',
125
            '"8" -> "3"',
126
        }
127
        self.assertEqual(actual, expected)
128

129

130
if __name__ == "__main__":
131
    run_tests()
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.