Adds a __repr__ to graph.
[pyutils.git] / src / pyutils / graph.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A simple graph class that can be optionally directed and weighted and
6 some operations on it."""
7
8
9 import math
10 from typing import Dict, Generator, List, Optional, Set, Tuple
11
12 from pyutils import list_utils
13 from pyutils.typez.typing import Numeric
14
15
16 class Graph(object):
17     def __init__(self, directed: bool = False):
18         """Constructs a new Graph object.
19
20         Args:
21             directed: are we modeling a directed graph?  See :meth:`add_edge`.
22
23         """
24         self.directed = directed
25         self.graph: Dict[str, Dict[str, Numeric]] = {}
26         self.dijkstra: Optional[Tuple[str, Dict[str, str], Dict[str, Numeric]]] = None
27
28     def add_vertex(self, vertex_id: str) -> bool:
29         """Adds a new vertex to the graph.
30
31         Args:
32             vertex_id: the unique identifier of the new vertex.
33
34         Returns:
35             True unless vertex_id is already in the graph.
36
37         >>> g = Graph()
38         >>> g.add_vertex('a')
39         True
40         >>> g.add_vertex('b')
41         True
42         >>> g.add_vertex('a')
43         False
44         >>> len(g.get_vertices())
45         2
46
47         """
48         if vertex_id not in self.graph:
49             self.graph[vertex_id] = {}
50             self.dijkstra = None
51             return True
52         return False
53
54     def add_edge(self, src: str, dest: str, weight: Numeric = 1) -> None:
55         """Adds a new (optionally weighted) edge between src and dest
56         vertexes.  If the graph is not directed (see c'tor) this also
57         adds a reciprocal edge with the same weight back from dest to
58         src too.
59
60         .. note::
61
62             If either or both of src and dest are not already added to
63             the graph, they are implicitly added by adding this edge.
64
65         Args:
66             src: the source vertex id
67             dest: the destination vertex id
68             weight: optionally, the weight of the edge(s) added
69
70         >>> g = Graph()
71         >>> g.add_edge('a', 'b')
72         >>> g.add_edge('b', 'c', weight=2)
73         >>> len(g.get_vertices())
74         3
75         >>> g.get_edges()
76         {'a': {'b': 1}, 'b': {'a': 1, 'c': 2}, 'c': {'b': 2}}
77
78         """
79         self.add_vertex(src)
80         self.add_vertex(dest)
81         self.graph[src][dest] = weight
82         if not self.directed:
83             self.graph[dest][src] = weight
84         self.dijkstra = None
85
86     def remove_edge(self, source: str, dest: str):
87         """Remove a previously added edge in the graph.  If the graph is
88         not directed (see :meth:`__init__`), also removes the reciprocal
89         edge from dest back to source.
90
91         .. note::
92
93             This method does not remove vertexes (unlinked or otherwise).
94
95         Args:
96             source: the source vertex of the edge to remove
97             dest: the destination vertex of the edge to remove
98
99         >>> g = Graph()
100         >>> g.add_edge('A', 'B')
101         >>> g.add_edge('B', 'C')
102         >>> g.get_edges()
103         {'A': {'B': 1}, 'B': {'A': 1, 'C': 1}, 'C': {'B': 1}}
104         >>> g.remove_edge('A', 'B')
105         >>> g.get_edges()
106         {'B': {'C': 1}, 'C': {'B': 1}}
107         """
108         del self.graph[source][dest]
109         if len(self.graph[source]) == 0:
110             del self.graph[source]
111         if not self.directed:
112             del self.graph[dest][source]
113             if len(self.graph[dest]) == 0:
114                 del self.graph[dest]
115         self.dijkstra = None
116
117     def get_vertices(self) -> List[str]:
118         """
119         Returns:
120             a list of the vertex ids in the graph.
121
122         >>> g = Graph()
123         >>> g.add_vertex('a')
124         True
125         >>> g.add_edge('b', 'c')
126         >>> g.get_vertices()
127         ['a', 'b', 'c']
128         """
129         return list(self.graph.keys())
130
131     def get_edges(self) -> Dict[str, Dict[str, Numeric]]:
132         """
133         Returns:
134             A dict whose keys are source vertexes and values
135             are dicts of destination vertexes with values describing the
136             weight of the edge from source to destination.
137
138         >>> g = Graph(directed=True)
139         >>> g.add_edge('a', 'b')
140         >>> g.add_edge('b', 'c', weight=2)
141         >>> len(g.get_vertices())
142         3
143         >>> g.get_edges()
144         {'a': {'b': 1}, 'b': {'c': 2}, 'c': {}}
145         """
146         return self.graph
147
148     def __repr__(self) -> str:
149         """
150         Returns:
151             A string representation of the graph in GraphViz format.
152
153         >>> g = Graph(directed=True)
154         >>> g.add_edge('a', 'b', weight=2)
155         >>> g.add_edge('b', 'a', weight=4)
156         >>> g.add_edge('a', 'c', weight=10)
157         >>> print(g)
158         digraph G {
159             node [shape=record];
160             a -> b  [weight=2]
161             a -> c  [weight=10]
162             b -> a  [weight=4]
163         }
164
165         >>> h = Graph(directed=False)
166         >>> h.add_edge('A', 'B')
167         >>> h.add_edge('B', 'C')
168         >>> h.add_edge('B', 'D')
169         >>> h.add_edge('D', 'A')
170         >>> print(h)
171         graph G {
172             node [shape=record];
173             A -- B  [weight=1]
174             A -- D  [weight=1]
175             B -- A  [weight=1]
176             B -- C  [weight=1]
177             B -- D  [weight=1]
178             C -- B  [weight=1]
179             D -- B  [weight=1]
180             D -- A  [weight=1]
181         }
182         """
183         if self.directed:
184             edge = '->'
185             out = 'digraph G {\n'
186         else:
187             edge = '--'
188             out = 'graph G {\n'
189         out += '    node [shape=record];\n'
190         edges = self.get_edges()
191         for src, dests in edges.items():
192             for dest, weight in dests.items():
193                 out += f'    {src} {edge} {dest}  [weight={weight}]\n'
194         out += '}\n'
195         return out.strip()
196
197     def _dfs(self, vertex: str, visited: Set[str]):
198         yield vertex
199         visited.add(vertex)
200         for neighbor in self.graph[vertex]:
201             if neighbor not in visited:
202                 yield from self._dfs(neighbor, visited)
203
204     def dfs(
205         self, starting_vertex: str, target: Optional[str] = None
206     ) -> Generator[str, None, None]:
207         """Performs a depth first traversal of the graph.
208
209         Args:
210             starting_vertex: The DFS starting point.
211             target: The vertex that, if found, indicates to halt.
212
213         Returns:
214             An ordered sequence of vertex ids visited by the traversal.
215
216         .. graphviz::
217
218             graph g {
219                 node [shape=record];
220                 A -- B -- D;
221                 A -- C -- D -- E -- F;
222                 F -- F;
223                 E -- G;
224             }
225
226         >>> g = Graph()
227         >>> g.add_edge('A', 'B')
228         >>> g.add_edge('A', 'C')
229         >>> g.add_edge('B', 'D')
230         >>> g.add_edge('C', 'D')
231         >>> g.add_edge('D', 'E')
232         >>> g.add_edge('E', 'F')
233         >>> g.add_edge('E', 'G')
234         >>> g.add_edge('F', 'F')
235         >>> for node in g.dfs('A'):
236         ...     print(node)
237         A
238         B
239         D
240         C
241         E
242         F
243         G
244
245         >>> for node in g.dfs('F', 'B'):
246         ...     print(node)
247         F
248         E
249         D
250         B
251         """
252         visited: Set[str] = set()
253         for node in self._dfs(starting_vertex, visited):
254             yield node
255             if node == target:
256                 return
257
258     def bfs(
259         self, starting_vertex: str, target: Optional[str] = None
260     ) -> Generator[str, None, None]:
261         """Performs a breadth first traversal of the graph.
262
263         Args:
264             starting_vertex: The BFS starting point.
265             target: The vertex that, if found, we should halt the search.
266
267         Returns:
268             An ordered sequence of vertex ids visited by the traversal.
269
270         .. graphviz::
271
272             graph g {
273                 node [shape=record];
274                 A -- B -- D;
275                 A -- C -- D -- E -- F;
276                 F -- F;
277                 E -- G;
278             }
279
280         >>> g = Graph()
281         >>> g.add_edge('A', 'B')
282         >>> g.add_edge('A', 'C')
283         >>> g.add_edge('B', 'D')
284         >>> g.add_edge('C', 'D')
285         >>> g.add_edge('D', 'E')
286         >>> g.add_edge('E', 'F')
287         >>> g.add_edge('E', 'G')
288         >>> g.add_edge('F', 'F')
289         >>> for node in g.bfs('A'):
290         ...     print(node)
291         A
292         B
293         C
294         D
295         E
296         F
297         G
298
299         >>> for node in g.bfs('F', 'G'):
300         ...     print(node)
301         F
302         E
303         D
304         G
305         """
306         todo = []
307         visited = set()
308
309         todo.append(starting_vertex)
310         visited.add(starting_vertex)
311
312         while todo:
313             vertex = todo.pop(0)
314             yield vertex
315             if vertex == target:
316                 return
317
318             neighbors = self.graph[vertex]
319             for neighbor in neighbors:
320                 if neighbor not in visited:
321                     todo.append(neighbor)
322                     visited.add(neighbor)
323
324     def _generate_dijkstra(self, source: str) -> None:
325         """Internal helper that runs Dijkstra's from source"""
326         unvisited_nodes = self.get_vertices()
327
328         shortest_path: Dict[str, Numeric] = {}
329         for node in unvisited_nodes:
330             shortest_path[node] = math.inf
331         shortest_path[source] = 0
332
333         previous_nodes: Dict[str, str] = {}
334         while unvisited_nodes:
335             current_min_node = None
336             for node in unvisited_nodes:
337                 if current_min_node is None:
338                     current_min_node = node
339                 elif shortest_path[node] < shortest_path[current_min_node]:
340                     current_min_node = node
341
342             assert current_min_node
343             neighbors = self.graph[current_min_node]
344             for neighbor in neighbors:
345                 tentative_value = (
346                     shortest_path[current_min_node]
347                     + self.graph[current_min_node][neighbor]
348                 )
349                 if tentative_value < shortest_path[neighbor]:
350                     shortest_path[neighbor] = tentative_value
351                     previous_nodes[neighbor] = current_min_node
352             unvisited_nodes.remove(current_min_node)
353         self.dijkstra = (source, previous_nodes, shortest_path)
354
355     def minimum_path_between(
356         self, source: str, dest: str
357     ) -> Tuple[Optional[Numeric], List[str]]:
358         """Compute the minimum path (lowest cost path) between source
359         and dest.
360
361         .. note::
362
363             This method runs Dijkstra's algorithm
364             (https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm)
365             internally and caches the results.  Subsequent calls made with
366             the same source node before modifying the graph are less
367             expensive due to these cached intermediate results.
368
369         Returns:
370             A tuple containing the minimum distance of the path and the path itself.
371             If there is no path between the requested nodes, returns (None, []).
372
373         .. graphviz::
374
375             graph g {
376                 node [shape=record];
377                 A -- B [weight=3];
378                 B -- D;
379                 A -- C [weight=2];
380                 C -- D -- E -- F;
381                 F -- F;
382                 E -- G;
383                 H;
384             }
385
386         >>> g = Graph()
387         >>> g.add_edge('A', 'B', 3)
388         >>> g.add_edge('A', 'C', 2)
389         >>> g.add_edge('B', 'D')
390         >>> g.add_edge('C', 'D')
391         >>> g.add_edge('D', 'E')
392         >>> g.add_edge('E', 'F')
393         >>> g.add_edge('E', 'G')
394         >>> g.add_edge('F', 'F')
395         >>> g.add_vertex('H')
396         True
397         >>> g.minimum_path_between('A', 'D')
398         (3, ['A', 'C', 'D'])
399         >>> g.minimum_path_between('A', 'H')
400         (None, [])
401
402         """
403         if self.dijkstra is None or self.dijkstra[0] != source:
404             self._generate_dijkstra(source)
405
406         assert self.dijkstra
407         path = []
408         node: Optional[str] = dest
409         while node != source:
410             assert node
411             path.append(node)
412             node = self.dijkstra[1].get(node, None)
413             if node is None:
414                 return (None, [])
415         path.append(source)
416         path = path[::-1]
417
418         cost: Numeric = 0
419         for (a, b) in list_utils.ngrams(path, 2):
420             cost += self.graph[a][b]
421         return (cost, path)
422
423
424 if __name__ == "__main__":
425     import doctest
426
427     doctest.testmod()