pytorch

Форк
0
/
_digraph.py 
173 строки · 5.5 Кб
1
from collections import deque
2
from typing import List, Set
3

4

5
class DiGraph:
    """Really simple unweighted directed graph data structure to track dependencies.

    The API is pretty much the same as networkx so if you add something just
    copy their API.
    """

    def __init__(self):
        # Dict of node -> dict of arbitrary attributes
        self._node = {}
        # Nested dict of node -> successor node -> nothing.
        # (didn't implement edge data)
        self._succ = {}
        # Nested dict of node -> predecessor node -> nothing.
        self._pred = {}

        # Keep track of the order in which nodes are added to
        # the graph (node -> insertion index); used by ``first_path``.
        self._node_order = {}
        self._insertion_idx = 0

    def add_node(self, n, **kwargs):
        """Add a node to the graph.

        Args:
            n: the node. Can be any object that is a valid dict key.
            **kwargs: any attributes you want to attach to the node. If the
                node already exists, these are merged into its existing
                attribute dict.
        """
        if n not in self._node:
            self._node[n] = kwargs
            self._succ[n] = {}
            self._pred[n] = {}
            self._node_order[n] = self._insertion_idx
            self._insertion_idx += 1
        else:
            self._node[n].update(kwargs)

    def add_edge(self, u, v):
        """Add an edge to graph between nodes ``u`` and ``v``

        ``u`` and ``v`` will be created if they do not already exist.
        Adding an edge that already exists is a no-op.
        """
        # add nodes
        self.add_node(u)
        self.add_node(v)

        # add the edge
        self._succ[u][v] = True
        self._pred[v][u] = True

    def successors(self, n):
        """Returns an iterator over successor nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._succ[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    def predecessors(self, n):
        """Returns an iterator over predecessors nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._pred[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    @property
    def edges(self):
        """Returns an iterator over all edges (u, v) in the graph"""
        for n, successors in self._succ.items():
            for succ in successors:
                yield n, succ

    @property
    def nodes(self):
        """Returns a dictionary of all nodes to their attributes."""
        return self._node

    def __iter__(self):
        """Iterate over the nodes."""
        return iter(self._node)

    def __contains__(self, n):
        """Returns True if ``n`` is a node in the graph, False otherwise."""
        try:
            return n in self._node
        except TypeError:
            # Unhashable objects can never be nodes.
            return False

    def _transitive_closure(self, src, neighbors) -> Set[str]:
        """Breadth-first search from ``src`` via ``neighbors``; returns all visited nodes (including ``src``)."""
        # Seed with the node itself. ``set(src)`` / ``deque(src)`` would
        # iterate over the characters of a string node name instead.
        result = {src}
        working_set = deque([src])
        while working_set:
            cur = working_set.popleft()
            for n in neighbors(cur):
                if n not in result:
                    result.add(n)
                    working_set.append(n)
        return result

    def forward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src (src included).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._transitive_closure(src, self.successors)

    def backward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src in reverse direction (src included).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._transitive_closure(src, self.predecessors)

    def all_paths(self, src: str, dst: str) -> str:
        """Returns the dot representation of the subgraph of all paths from src to dst.

        The subgraph holds every edge ``(u, v)`` such that ``u`` is reachable
        from ``src`` and ``dst`` is reachable from ``v``. If ``dst`` is not
        reachable from ``src``, the dot text of an empty graph is returned
        (always a ``str``, never a ``DiGraph``).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        result_graph = DiGraph()
        # First compute forward transitive closure of src (all things reachable from src).
        forward_reachable_from_src = self.forward_transitive_closure(src)

        if dst not in forward_reachable_from_src:
            return result_graph.to_dot()

        # Second walk the reverse dependencies of dst, adding each edge to
        # the output graph iff its source is also in forward_reachable_from_src.
        # We don't use backward_transitive_closure for optimization purposes.
        visited = {dst}
        working_set = deque([dst])
        while working_set:
            cur = working_set.popleft()
            for n in self.predecessors(cur):
                if n in forward_reachable_from_src:
                    result_graph.add_edge(n, cur)
                    # Only explore each node once: without this, cycles loop
                    # forever and diamond-shaped DAGs are re-walked exponentially.
                    if n not in visited:
                        visited.add(n)
                        working_set.append(n)

        return result_graph.to_dot()

    def first_path(self, dst: str) -> List[str]:
        """Returns a list of nodes that show the first path that resulted in dst being added to the graph.

        Starting at ``dst``, repeatedly step to the not-yet-visited predecessor
        with the smallest insertion index, until no such predecessor remains.
        The result runs from that earliest ancestor down to ``dst``.

        Raises:
            KeyError: if ``dst`` (or a visited node) is not in the graph.
        """
        path = []
        seen = set()
        cur = dst
        while True:
            path.append(cur)
            seen.add(cur)
            # Ignore already-visited predecessors so a dependency cycle cannot
            # make this walk run forever.
            candidates = [p for p in self._pred[cur] if p not in seen]
            if not candidates:
                break
            cur = min(candidates, key=self._node_order.__getitem__)

        return list(reversed(path))

    def to_dot(self) -> str:
        """Returns the dot representation of the graph.

        Only edges are emitted; nodes without any edge do not appear.

        Returns:
            A dot representation of the graph.
        """
        edges = "\n".join(f'"{f}" -> "{t}";' for f, t in self.edges)
        return f"""\
digraph G {{
rankdir = LR;
node [shape=box];
{edges}
}}
"""
174

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.