Adds a __repr__ to graph.
[pyutils.git] / src / pyutils / graph.py
index 903be145f153fb967a81a6f731ddd1e594d04415..3e08159dee861f37c88a2d00f353f496045f75e9 100644 (file)
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
 #!/usr/bin/env python3
 
-# © Copyright 2021-2022, Scott Gasch
+# © Copyright 2021-2023, Scott Gasch
 
 """A simple graph class that can be optionally directed and weighted and
 some operations on it."""
 
 """A simple graph class that can be optionally directed and weighted and
 some operations on it."""
@@ -10,7 +10,7 @@ import math
 from typing import Dict, Generator, List, Optional, Set, Tuple
 
 from pyutils import list_utils
 from typing import Dict, Generator, List, Optional, Set, Tuple
 
 from pyutils import list_utils
-from pyutils.types.simple import Numeric
+from pyutils.typez.typing import Numeric
 
 
 class Graph(object):
 
 
 class Graph(object):
@@ -145,6 +145,55 @@ class Graph(object):
         """
         return self.graph
 
         """
         return self.graph
 
+    def __repr__(self) -> str:
+        """
+        Returns:
+            A string representation of the graph in GraphViz format.
+
+        >>> g = Graph(directed=True)
+        >>> g.add_edge('a', 'b', weight=2)
+        >>> g.add_edge('b', 'a', weight=4)
+        >>> g.add_edge('a', 'c', weight=10)
+        >>> print(g)
+        digraph G {
+            node [shape=record];
+            a -> b  [weight=2]
+            a -> c  [weight=10]
+            b -> a  [weight=4]
+        }
+
+        >>> h = Graph(directed=False)
+        >>> h.add_edge('A', 'B')
+        >>> h.add_edge('B', 'C')
+        >>> h.add_edge('B', 'D')
+        >>> h.add_edge('D', 'A')
+        >>> print(h)
+        graph G {
+            node [shape=record];
+            A -- B  [weight=1]
+            A -- D  [weight=1]
+            B -- A  [weight=1]
+            B -- C  [weight=1]
+            B -- D  [weight=1]
+            C -- B  [weight=1]
+            D -- B  [weight=1]
+            D -- A  [weight=1]
+        }
+        """
+        if self.directed:
+            edge = '->'
+            out = 'digraph G {\n'
+        else:
+            edge = '--'
+            out = 'graph G {\n'
+        out += '    node [shape=record];\n'
+        edges = self.get_edges()
+        for src, dests in edges.items():
+            for dest, weight in dests.items():
+                out += f'    {src} {edge} {dest}  [weight={weight}]\n'
+        out += '}\n'
+        return out.strip()
+
     def _dfs(self, vertex: str, visited: Set[str]):
         yield vertex
         visited.add(vertex)
     def _dfs(self, vertex: str, visited: Set[str]):
         yield vertex
         visited.add(vertex)
@@ -303,7 +352,9 @@ class Graph(object):
             unvisited_nodes.remove(current_min_node)
         self.dijkstra = (source, previous_nodes, shortest_path)
 
             unvisited_nodes.remove(current_min_node)
         self.dijkstra = (source, previous_nodes, shortest_path)
 
-    def minimum_path_between(self, source: str, dest: str) -> Tuple[Numeric, List[str]]:
+    def minimum_path_between(
+        self, source: str, dest: str
+    ) -> Tuple[Optional[Numeric], List[str]]:
         """Compute the minimum path (lowest cost path) between source
         and dest.
 
         """Compute the minimum path (lowest cost path) between source
         and dest.
 
@@ -354,8 +405,9 @@ class Graph(object):
 
         assert self.dijkstra
         path = []
 
         assert self.dijkstra
         path = []
-        node = dest
+        node: Optional[str] = dest
         while node != source:
         while node != source:
+            assert node
             path.append(node)
             node = self.dijkstra[1].get(node, None)
             if node is None:
             path.append(node)
             node = self.dijkstra[1].get(node, None)
             if node is None: