1
# Owner(s): ["oncall: package/deploy"]
3
from torch.package._digraph import DiGraph
4
from torch.testing._internal.common_utils import run_tests
8
from .common import PackageTestCase
10
# Support the case where we run this file directly.
11
from common import PackageTestCase
14
class TestDiGraph(PackageTestCase):
    """Test the DiGraph structure we use to represent dependencies in PackageExporter"""

    # NOTE(review): every test builds its own graph so the tests stay
    # independent. Fixture statements that were absent from the reviewed text
    # were reconstructed from the assertions they feed; confirm them against
    # the DiGraph implementation before relying on them.

    def test_successors(self):
        """Successors of a node are the targets of its outgoing edges."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        # "qux" must be registered explicitly: test_successor_not_in_graph
        # expects a ValueError for nodes the graph has never seen.
        g.add_node("qux")

        self.assertIn("bar", list(g.successors("foo")))
        self.assertIn("baz", list(g.successors("foo")))
        self.assertEqual(len(list(g.successors("qux"))), 0)

    def test_predecessors(self):
        """Predecessors of a node are the sources of its incoming edges."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        # Isolated node, see test_predecessor_not_in_graph for why it has to
        # be added before it can be queried.
        g.add_node("qux")

        self.assertIn("foo", list(g.predecessors("bar")))
        self.assertIn("foo", list(g.predecessors("baz")))
        self.assertEqual(len(list(g.predecessors("qux"))), 0)

    def test_successor_not_in_graph(self):
        """Asking for the successors of an unknown node raises ValueError."""
        g = DiGraph()
        with self.assertRaises(ValueError):
            g.successors("not in graph")

    def test_predecessor_not_in_graph(self):
        """Asking for the predecessors of an unknown node raises ValueError."""
        g = DiGraph()
        with self.assertRaises(ValueError):
            g.predecessors("not in graph")

    def test_node_attrs(self):
        """Keyword arguments to add_node are stored as node attributes."""
        g = DiGraph()
        g.add_node("foo", my_attr=1, other_attr=2)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)
        self.assertEqual(g.nodes["foo"]["other_attr"], 2)

    def test_node_attr_update(self):
        """Re-adding an existing node overwrites the given attribute."""
        g = DiGraph()
        g.add_node("foo", my_attr=1)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)

        g.add_node("foo", my_attr="different")
        self.assertEqual(g.nodes["foo"]["my_attr"], "different")

    def test_edges(self):
        """The edges attribute yields every (source, target) pair once."""
        g = DiGraph()
        # Exactly the four edges asserted below.
        g.add_edge(1, 2)
        g.add_edge(2, 3)
        g.add_edge(1, 3)
        g.add_edge(4, 5)

        edge_list = list(g.edges)
        self.assertEqual(len(edge_list), 4)

        self.assertIn((1, 2), edge_list)
        self.assertIn((2, 3), edge_list)
        self.assertIn((1, 3), edge_list)
        self.assertIn((4, 5), edge_list)

    def test_iter(self):
        """Iterating over a graph yields its nodes."""
        g = DiGraph()
        for node in (1, 2, 3):
            g.add_node(node)

        nodes = set(g)
        self.assertEqual(nodes, {1, 2, 3})

    def test_contains(self):
        """The `in` operator reports whether a node is part of the graph."""
        g = DiGraph()
        g.add_node("yup")

        self.assertIn("yup", g)
        self.assertNotIn("nup", g)

    def test_contains_non_hashable(self):
        """A membership test with an unhashable object is expected to be False."""
        g = DiGraph()
        self.assertNotIn([1, 2, 3], g)

    def test_forward_closure(self):
        """The forward closure holds the start node and all it can reach."""
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("2", "3")
        # NOTE(review): "5" -> "4" is a reconstructed edge; it checks that the
        # closure of "4" does not pick up nodes that merely point at it.
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        self.assertEqual(g.forward_transitive_closure("1"), {"1", "2", "3"})
        self.assertEqual(g.forward_transitive_closure("4"), {"4", "3"})

    def test_all_paths(self):
        """all_paths reports only edges that lie on a path from src to dst."""
        g = DiGraph()
        # Two routes from "1" to "3": via "2", and via "7" then "8".
        g.add_edge("1", "2")
        g.add_edge("1", "7")
        g.add_edge("7", "8")
        g.add_edge("8", "3")
        g.add_edge("2", "3")
        # This branch reaches "3" but is not reachable from "1".
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        result = g.all_paths("1", "3")
        # to get rid of indeterminism
        # The result is split on ";"; the first two pieces and the last one
        # are dropped, leaving one piece per edge in no particular order.
        actual = {i.strip("\n") for i in result.split(";")[2:-1]}
        # NOTE(review): the quoting below assumes edges are rendered as
        # '"src" -> "dst"' -- confirm against DiGraph's string output.
        expected = {
            '"2" -> "3"',
            '"1" -> "7"',
            '"7" -> "8"',
            '"1" -> "2"',
            '"8" -> "3"',
        }
        self.assertEqual(actual, expected)
130
if __name__ == "__main__":