I guess it's 2023 now...
[pyutils.git] / src / pyutils / graph.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A simple graph class that can be optionally directed and weighted and
6 some operations on it."""
7
8
9 import math
10 from typing import Dict, Generator, List, Optional, Set, Tuple
11
12 from pyutils import list_utils
13 from pyutils.types.simple import Numeric
14
15
16 class Graph(object):
17     def __init__(self, directed: bool = False):
18         """Constructs a new Graph object.
19
20         Args:
21             directed: are we modeling a directed graph?  See :meth:`add_edge`.
22
23         """
24         self.directed = directed
25         self.graph: Dict[str, Dict[str, Numeric]] = {}
26         self.dijkstra: Optional[Tuple[str, Dict[str, str], Dict[str, Numeric]]] = None
27
28     def add_vertex(self, vertex_id: str) -> bool:
29         """Adds a new vertex to the graph.
30
31         Args:
32             vertex_id: the unique identifier of the new vertex.
33
34         Returns:
35             True unless vertex_id is already in the graph.
36
37         >>> g = Graph()
38         >>> g.add_vertex('a')
39         True
40         >>> g.add_vertex('b')
41         True
42         >>> g.add_vertex('a')
43         False
44         >>> len(g.get_vertices())
45         2
46
47         """
48         if vertex_id not in self.graph:
49             self.graph[vertex_id] = {}
50             self.dijkstra = None
51             return True
52         return False
53
54     def add_edge(self, src: str, dest: str, weight: Numeric = 1) -> None:
55         """Adds a new (optionally weighted) edge between src and dest
56         vertexes.  If the graph is not directed (see c'tor) this also
57         adds a reciprocal edge with the same weight back from dest to
58         src too.
59
60         .. note::
61
62             If either or both of src and dest are not already added to
63             the graph, they are implicitly added by adding this edge.
64
65         Args:
66             src: the source vertex id
67             dest: the destination vertex id
68             weight: optionally, the weight of the edge(s) added
69
70         >>> g = Graph()
71         >>> g.add_edge('a', 'b')
72         >>> g.add_edge('b', 'c', weight=2)
73         >>> len(g.get_vertices())
74         3
75         >>> g.get_edges()
76         {'a': {'b': 1}, 'b': {'a': 1, 'c': 2}, 'c': {'b': 2}}
77
78         """
79         self.add_vertex(src)
80         self.add_vertex(dest)
81         self.graph[src][dest] = weight
82         if not self.directed:
83             self.graph[dest][src] = weight
84         self.dijkstra = None
85
86     def remove_edge(self, source: str, dest: str):
87         """Remove a previously added edge in the graph.  If the graph is
88         not directed (see :meth:`__init__`), also removes the reciprocal
89         edge from dest back to source.
90
91         .. note::
92
93             This method does not remove vertexes (unlinked or otherwise).
94
95         Args:
96             source: the source vertex of the edge to remove
97             dest: the destination vertex of the edge to remove
98
99         >>> g = Graph()
100         >>> g.add_edge('A', 'B')
101         >>> g.add_edge('B', 'C')
102         >>> g.get_edges()
103         {'A': {'B': 1}, 'B': {'A': 1, 'C': 1}, 'C': {'B': 1}}
104         >>> g.remove_edge('A', 'B')
105         >>> g.get_edges()
106         {'B': {'C': 1}, 'C': {'B': 1}}
107         """
108         del self.graph[source][dest]
109         if len(self.graph[source]) == 0:
110             del self.graph[source]
111         if not self.directed:
112             del self.graph[dest][source]
113             if len(self.graph[dest]) == 0:
114                 del self.graph[dest]
115         self.dijkstra = None
116
117     def get_vertices(self) -> List[str]:
118         """
119         Returns:
120             a list of the vertex ids in the graph.
121
122         >>> g = Graph()
123         >>> g.add_vertex('a')
124         True
125         >>> g.add_edge('b', 'c')
126         >>> g.get_vertices()
127         ['a', 'b', 'c']
128         """
129         return list(self.graph.keys())
130
131     def get_edges(self) -> Dict[str, Dict[str, Numeric]]:
132         """
133         Returns:
134             A dict whose keys are source vertexes and values
135             are dicts of destination vertexes with values describing the
136             weight of the edge from source to destination.
137
138         >>> g = Graph(directed=True)
139         >>> g.add_edge('a', 'b')
140         >>> g.add_edge('b', 'c', weight=2)
141         >>> len(g.get_vertices())
142         3
143         >>> g.get_edges()
144         {'a': {'b': 1}, 'b': {'c': 2}, 'c': {}}
145         """
146         return self.graph
147
148     def _dfs(self, vertex: str, visited: Set[str]):
149         yield vertex
150         visited.add(vertex)
151         for neighbor in self.graph[vertex]:
152             if neighbor not in visited:
153                 yield from self._dfs(neighbor, visited)
154
155     def dfs(
156         self, starting_vertex: str, target: Optional[str] = None
157     ) -> Generator[str, None, None]:
158         """Performs a depth first traversal of the graph.
159
160         Args:
161             starting_vertex: The DFS starting point.
162             target: The vertex that, if found, indicates to halt.
163
164         Returns:
165             An ordered sequence of vertex ids visited by the traversal.
166
167         .. graphviz::
168
169             graph g {
170                 node [shape=record];
171                 A -- B -- D;
172                 A -- C -- D -- E -- F;
173                 F -- F;
174                 E -- G;
175             }
176
177         >>> g = Graph()
178         >>> g.add_edge('A', 'B')
179         >>> g.add_edge('A', 'C')
180         >>> g.add_edge('B', 'D')
181         >>> g.add_edge('C', 'D')
182         >>> g.add_edge('D', 'E')
183         >>> g.add_edge('E', 'F')
184         >>> g.add_edge('E', 'G')
185         >>> g.add_edge('F', 'F')
186         >>> for node in g.dfs('A'):
187         ...     print(node)
188         A
189         B
190         D
191         C
192         E
193         F
194         G
195
196         >>> for node in g.dfs('F', 'B'):
197         ...     print(node)
198         F
199         E
200         D
201         B
202         """
203         visited: Set[str] = set()
204         for node in self._dfs(starting_vertex, visited):
205             yield node
206             if node == target:
207                 return
208
209     def bfs(
210         self, starting_vertex: str, target: Optional[str] = None
211     ) -> Generator[str, None, None]:
212         """Performs a breadth first traversal of the graph.
213
214         Args:
215             starting_vertex: The BFS starting point.
216             target: The vertex that, if found, we should halt the search.
217
218         Returns:
219             An ordered sequence of vertex ids visited by the traversal.
220
221         .. graphviz::
222
223             graph g {
224                 node [shape=record];
225                 A -- B -- D;
226                 A -- C -- D -- E -- F;
227                 F -- F;
228                 E -- G;
229             }
230
231         >>> g = Graph()
232         >>> g.add_edge('A', 'B')
233         >>> g.add_edge('A', 'C')
234         >>> g.add_edge('B', 'D')
235         >>> g.add_edge('C', 'D')
236         >>> g.add_edge('D', 'E')
237         >>> g.add_edge('E', 'F')
238         >>> g.add_edge('E', 'G')
239         >>> g.add_edge('F', 'F')
240         >>> for node in g.bfs('A'):
241         ...     print(node)
242         A
243         B
244         C
245         D
246         E
247         F
248         G
249
250         >>> for node in g.bfs('F', 'G'):
251         ...     print(node)
252         F
253         E
254         D
255         G
256         """
257         todo = []
258         visited = set()
259
260         todo.append(starting_vertex)
261         visited.add(starting_vertex)
262
263         while todo:
264             vertex = todo.pop(0)
265             yield vertex
266             if vertex == target:
267                 return
268
269             neighbors = self.graph[vertex]
270             for neighbor in neighbors:
271                 if neighbor not in visited:
272                     todo.append(neighbor)
273                     visited.add(neighbor)
274
275     def _generate_dijkstra(self, source: str) -> None:
276         """Internal helper that runs Dijkstra's from source"""
277         unvisited_nodes = self.get_vertices()
278
279         shortest_path: Dict[str, Numeric] = {}
280         for node in unvisited_nodes:
281             shortest_path[node] = math.inf
282         shortest_path[source] = 0
283
284         previous_nodes: Dict[str, str] = {}
285         while unvisited_nodes:
286             current_min_node = None
287             for node in unvisited_nodes:
288                 if current_min_node is None:
289                     current_min_node = node
290                 elif shortest_path[node] < shortest_path[current_min_node]:
291                     current_min_node = node
292
293             assert current_min_node
294             neighbors = self.graph[current_min_node]
295             for neighbor in neighbors:
296                 tentative_value = (
297                     shortest_path[current_min_node]
298                     + self.graph[current_min_node][neighbor]
299                 )
300                 if tentative_value < shortest_path[neighbor]:
301                     shortest_path[neighbor] = tentative_value
302                     previous_nodes[neighbor] = current_min_node
303             unvisited_nodes.remove(current_min_node)
304         self.dijkstra = (source, previous_nodes, shortest_path)
305
306     def minimum_path_between(self, source: str, dest: str) -> Tuple[Numeric, List[str]]:
307         """Compute the minimum path (lowest cost path) between source
308         and dest.
309
310         .. note::
311
312             This method runs Dijkstra's algorithm
313             (https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm)
314             internally and caches the results.  Subsequent calls made with
315             the same source node before modifying the graph are less
316             expensive due to these cached intermediate results.
317
318         Returns:
319             A tuple containing the minimum distance of the path and the path itself.
320             If there is no path between the requested nodes, returns (None, []).
321
322         .. graphviz::
323
324             graph g {
325                 node [shape=record];
326                 A -- B [weight=3];
327                 B -- D;
328                 A -- C [weight=2];
329                 C -- D -- E -- F;
330                 F -- F;
331                 E -- G;
332                 H;
333             }
334
335         >>> g = Graph()
336         >>> g.add_edge('A', 'B', 3)
337         >>> g.add_edge('A', 'C', 2)
338         >>> g.add_edge('B', 'D')
339         >>> g.add_edge('C', 'D')
340         >>> g.add_edge('D', 'E')
341         >>> g.add_edge('E', 'F')
342         >>> g.add_edge('E', 'G')
343         >>> g.add_edge('F', 'F')
344         >>> g.add_vertex('H')
345         True
346         >>> g.minimum_path_between('A', 'D')
347         (3, ['A', 'C', 'D'])
348         >>> g.minimum_path_between('A', 'H')
349         (None, [])
350
351         """
352         if self.dijkstra is None or self.dijkstra[0] != source:
353             self._generate_dijkstra(source)
354
355         assert self.dijkstra
356         path = []
357         node = dest
358         while node != source:
359             path.append(node)
360             node = self.dijkstra[1].get(node, None)
361             if node is None:
362                 return (None, [])
363         path.append(source)
364         path = path[::-1]
365
366         cost: Numeric = 0
367         for (a, b) in list_utils.ngrams(path, 2):
368             cost += self.graph[a][b]
369         return (cost, path)
370
371
372 if __name__ == "__main__":
373     import doctest
374
375     doctest.testmod()