More spring cleaning.
[pyutils.git] / src / pyutils / graph.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A simple graph class that can be optionally directed and weighted and
6 some operations on it."""
7
8
9 import math
10 from typing import Dict, Generator, List, Optional, Set, Tuple
11
12 from pyutils import list_utils
13 from pyutils.types.simple import Numeric
14
15
16 class Graph(object):
17     def __init__(self, directed: bool = False):
18         """Constructs a new Graph object.
19
20         Args:
21             directed: are we modeling a directed graph?  See :meth:`add_edge`.
22
23         """
24         self.directed = directed
25         self.graph: Dict[str, Dict[str, Numeric]] = {}
26         self.dijkstra: Optional[Tuple[str, Dict[str, str], Dict[str, Numeric]]] = None
27
28     def add_vertex(self, vertex_id: str) -> bool:
29         """Adds a new vertex to the graph.
30
31         Args:
32             vertex_id: the unique identifier of the new vertex.
33
34         Returns:
35             True unless vertex_id is already in the graph.
36
37         >>> g = Graph()
38         >>> g.add_vertex('a')
39         True
40         >>> g.add_vertex('b')
41         True
42         >>> g.add_vertex('a')
43         False
44         >>> len(g.get_vertices())
45         2
46
47         """
48         if vertex_id not in self.graph:
49             self.graph[vertex_id] = {}
50             self.dijkstra = None
51             return True
52         return False
53
54     def add_edge(self, src: str, dest: str, weight: Numeric = 1) -> None:
55         """Adds a new (optionally weighted) edge between src and dest
56         vertexes.  If the graph is not directed (see c'tor) this also
57         adds a reciprocal edge with the same weight back from dest to
58         src too.
59
60         .. note::
61
62             If either or both of src and dest are not already added to
63             the graph, they are implicitly added by adding this edge.
64
65         Args:
66             src: the source vertex id
67             dest: the destination vertex id
68             weight: optionally, the weight of the edge(s) added
69
70         >>> g = Graph()
71         >>> g.add_edge('a', 'b')
72         >>> g.add_edge('b', 'c', weight=2)
73         >>> len(g.get_vertices())
74         3
75         >>> g.get_edges()
76         {'a': {'b': 1}, 'b': {'a': 1, 'c': 2}, 'c': {'b': 2}}
77
78         """
79         self.add_vertex(src)
80         self.add_vertex(dest)
81         self.graph[src][dest] = weight
82         if not self.directed:
83             self.graph[dest][src] = weight
84         self.dijkstra = None
85
86     def remove_edge(self, source: str, dest: str):
87         """Remove a previously added edge in the graph.  If the graph is
88         not directed (see :meth:`__init__`), also removes the reciprocal
89         edge from dest back to source.
90
91         .. note::
92
93             This method does not remove vertexes (unlinked or otherwise).
94
95         Args:
96             source: the source vertex of the edge to remove
97             dest: the destination vertex of the edge to remove
98
99         >>> g = Graph()
100         >>> g.add_edge('A', 'B')
101         >>> g.add_edge('B', 'C')
102         >>> g.get_edges()
103         {'A': {'B': 1}, 'B': {'A': 1, 'C': 1}, 'C': {'B': 1}}
104         >>> g.remove_edge('A', 'B')
105         >>> g.get_edges()
106         {'B': {'C': 1}, 'C': {'B': 1}}
107         """
108         del self.graph[source][dest]
109         if len(self.graph[source]) == 0:
110             del self.graph[source]
111         if not self.directed:
112             del self.graph[dest][source]
113             if len(self.graph[dest]) == 0:
114                 del self.graph[dest]
115         self.dijkstra = None
116
117     def get_vertices(self) -> List[str]:
118         """
119         Returns:
120             a list of the vertex ids in the graph.
121
122         >>> g = Graph()
123         >>> g.add_vertex('a')
124         True
125         >>> g.add_edge('b', 'c')
126         >>> g.get_vertices()
127         ['a', 'b', 'c']
128         """
129         return list(self.graph.keys())
130
131     def get_edges(self) -> Dict[str, Dict[str, Numeric]]:
132         """
133         Returns:
134             A dict whose keys are source vertexes and values
135             are dicts of destination vertexes with values describing the
136             weight of the edge from source to destination.
137
138         >>> g = Graph(directed=True)
139         >>> g.add_edge('a', 'b')
140         >>> g.add_edge('b', 'c', weight=2)
141         >>> len(g.get_vertices())
142         3
143         >>> g.get_edges()
144         {'a': {'b': 1}, 'b': {'c': 2}, 'c': {}}
145         """
146         return self.graph
147
148     def _dfs(self, vertex: str, visited: Set[str]):
149         yield vertex
150         visited.add(vertex)
151         for neighbor in self.graph[vertex]:
152             if neighbor not in visited:
153                 yield from self._dfs(neighbor, visited)
154
155     def dfs(
156         self, starting_vertex: str, target: Optional[str] = None
157     ) -> Generator[str, None, None]:
158         """Performs a depth first traversal of the graph.
159
160         Args:
161             starting_vertex: The DFS starting point.
162             target: The vertex that, if found, indicates to halt.
163
164         Returns:
165             An ordered sequence of vertex ids visited by the traversal.
166
167         .. graphviz::
168
169             graph g {
170                 node [shape=record];
171                 A -- B -- D;
172                 A -- C -- D -- E -- F;
173                 F -- F;
174                 E -- G;
175             }
176
177         >>> g = Graph()
178         >>> g.add_edge('A', 'B')
179         >>> g.add_edge('A', 'C')
180         >>> g.add_edge('B', 'D')
181         >>> g.add_edge('C', 'D')
182         >>> g.add_edge('D', 'E')
183         >>> g.add_edge('E', 'F')
184         >>> g.add_edge('E', 'G')
185         >>> g.add_edge('F', 'F')
186         >>> for node in g.dfs('A'):
187         ...     print(node)
188         A
189         B
190         D
191         C
192         E
193         F
194         G
195
196         >>> for node in g.dfs('F', 'B'):
197         ...     print(node)
198         F
199         E
200         D
201         B
202         """
203         visited: Set[str] = set()
204         for node in self._dfs(starting_vertex, visited):
205             yield node
206             if node == target:
207                 return
208
209     def bfs(
210         self, starting_vertex: str, target: Optional[str] = None
211     ) -> Generator[str, None, None]:
212         """Performs a breadth first traversal of the graph.
213
214         Args:
215             starting_vertex: The BFS starting point.
216             target: The vertex that, if found, we should halt the search.
217
218         Returns:
219             An ordered sequence of vertex ids visited by the traversal.
220
221         .. graphviz::
222
223             graph g {
224                 node [shape=record];
225                 A -- B -- D;
226                 A -- C -- D -- E -- F;
227                 F -- F;
228                 E -- G;
229             }
230
231         >>> g = Graph()
232         >>> g.add_edge('A', 'B')
233         >>> g.add_edge('A', 'C')
234         >>> g.add_edge('B', 'D')
235         >>> g.add_edge('C', 'D')
236         >>> g.add_edge('D', 'E')
237         >>> g.add_edge('E', 'F')
238         >>> g.add_edge('E', 'G')
239         >>> g.add_edge('F', 'F')
240         >>> for node in g.bfs('A'):
241         ...     print(node)
242         A
243         B
244         C
245         D
246         E
247         F
248         G
249
250         >>> for node in g.bfs('F', 'G'):
251         ...     print(node)
252         F
253         E
254         D
255         G
256         """
257         todo = []
258         visited = set()
259
260         todo.append(starting_vertex)
261         visited.add(starting_vertex)
262
263         while todo:
264             vertex = todo.pop(0)
265             yield vertex
266             if vertex == target:
267                 return
268
269             neighbors = self.graph[vertex]
270             for neighbor in neighbors:
271                 if neighbor not in visited:
272                     todo.append(neighbor)
273                     visited.add(neighbor)
274
275     def _generate_dijkstra(self, source: str) -> None:
276         """Internal helper that runs Dijkstra's from source"""
277         unvisited_nodes = self.get_vertices()
278
279         shortest_path: Dict[str, Numeric] = {}
280         for node in unvisited_nodes:
281             shortest_path[node] = math.inf
282         shortest_path[source] = 0
283
284         previous_nodes: Dict[str, str] = {}
285         while unvisited_nodes:
286             current_min_node = None
287             for node in unvisited_nodes:
288                 if current_min_node is None:
289                     current_min_node = node
290                 elif shortest_path[node] < shortest_path[current_min_node]:
291                     current_min_node = node
292
293             assert current_min_node
294             neighbors = self.graph[current_min_node]
295             for neighbor in neighbors:
296                 tentative_value = (
297                     shortest_path[current_min_node]
298                     + self.graph[current_min_node][neighbor]
299                 )
300                 if tentative_value < shortest_path[neighbor]:
301                     shortest_path[neighbor] = tentative_value
302                     previous_nodes[neighbor] = current_min_node
303             unvisited_nodes.remove(current_min_node)
304         self.dijkstra = (source, previous_nodes, shortest_path)
305
306     def minimum_path_between(
307         self, source: str, dest: str
308     ) -> Tuple[Optional[Numeric], List[str]]:
309         """Compute the minimum path (lowest cost path) between source
310         and dest.
311
312         .. note::
313
314             This method runs Dijkstra's algorithm
315             (https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm)
316             internally and caches the results.  Subsequent calls made with
317             the same source node before modifying the graph are less
318             expensive due to these cached intermediate results.
319
320         Returns:
321             A tuple containing the minimum distance of the path and the path itself.
322             If there is no path between the requested nodes, returns (None, []).
323
324         .. graphviz::
325
326             graph g {
327                 node [shape=record];
328                 A -- B [weight=3];
329                 B -- D;
330                 A -- C [weight=2];
331                 C -- D -- E -- F;
332                 F -- F;
333                 E -- G;
334                 H;
335             }
336
337         >>> g = Graph()
338         >>> g.add_edge('A', 'B', 3)
339         >>> g.add_edge('A', 'C', 2)
340         >>> g.add_edge('B', 'D')
341         >>> g.add_edge('C', 'D')
342         >>> g.add_edge('D', 'E')
343         >>> g.add_edge('E', 'F')
344         >>> g.add_edge('E', 'G')
345         >>> g.add_edge('F', 'F')
346         >>> g.add_vertex('H')
347         True
348         >>> g.minimum_path_between('A', 'D')
349         (3, ['A', 'C', 'D'])
350         >>> g.minimum_path_between('A', 'H')
351         (None, [])
352
353         """
354         if self.dijkstra is None or self.dijkstra[0] != source:
355             self._generate_dijkstra(source)
356
357         assert self.dijkstra
358         path = []
359         node: Optional[str] = dest
360         while node != source:
361             assert node
362             path.append(node)
363             node = self.dijkstra[1].get(node, None)
364             if node is None:
365                 return (None, [])
366         path.append(source)
367         path = path[::-1]
368
369         cost: Numeric = 0
370         for (a, b) in list_utils.ngrams(path, 2):
371             cost += self.graph[a][b]
372         return (cost, path)
373
374
375 if __name__ == "__main__":
376     import doctest
377
378     doctest.testmod()