Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / collect / bst.py
index b4d25b34a627797660362a16d6430fdcd6d7eceb..d39419494d3f482712f17e13a5ff6ce1e7c2ebcf 100644 (file)
@@ -1,12 +1,20 @@
 #!/usr/bin/env python3
 
-from typing import Any, Optional, List
+# © Copyright 2021-2022, Scott Gasch
+
+"""Binary search tree."""
+
+from typing import Any, Generator, List, Optional
 
 
 class Node(object):
     def __init__(self, value: Any) -> None:
-        self.left = None
-        self.right = None
+        """
+        Note: value can be anything as long as it is comparable.
+        Check out @functools.total_ordering.
+        """
+        self.left: Optional[Node] = None
+        self.right: Optional[Node] = None
         self.value = value
 
 
@@ -80,18 +88,16 @@ class BinarySearchTree(object):
         """Find helper"""
         if value == node.value:
             return node
-        elif (value < node.value and node.left is not None):
+        elif value < node.value and node.left is not None:
             return self._find(value, node.left)
-        else:
-            assert value > node.value
-            if node.right is not None:
-                return self._find(value, node.right)
+        elif value > node.value and node.right is not None:
+            return self._find(value, node.right)
         return None
 
-    def _parent_path(self, current: Node, target: Node):
+    def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]:
         if current is None:
             return [None]
-        ret = [current]
+        ret: List[Optional[Node]] = [current]
         if target.value == current.value:
             return ret
         elif target.value < current.value:
@@ -102,9 +108,12 @@ class BinarySearchTree(object):
             ret.extend(self._parent_path(current.right, target))
             return ret
 
-    def parent_path(self, node: Node) -> Optional[List[Node]]:
+    def parent_path(self, node: Node) -> List[Optional[Node]]:
         """Return a list of nodes representing the path from
-        the tree's root to the node argument.
+        the tree's root to the node argument.  If the node is not
+        in the tree, the last element of the path is None and the
+        elements before it are the ancestors the node would have
+        if it were inserted.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -131,6 +140,17 @@ class BinarySearchTree(object):
         12
         4
 
+        >>> del t[4]
+        >>> for x in t.parent_path(n):
+        ...     if x is not None:
+        ...         print(x.value)
+        ...     else:
+        ...         print(x)
+        50
+        25
+        12
+        None
+
         """
         return self._parent_path(self.root, node)
 
@@ -138,15 +158,6 @@ class BinarySearchTree(object):
         """
         Delete an item from the tree and preserve the BST property.
 
-                            50
-                           /  \
-                         25    75
-                        /     /  \
-                      22    66    85
-                     /
-                   13
-
-
         >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
@@ -155,6 +166,14 @@ class BinarySearchTree(object):
         >>> t.insert(22)
         >>> t.insert(13)
         >>> t.insert(85)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     └──13
+        └──75
+           ├──66
+           └──85
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
@@ -195,6 +214,11 @@ class BinarySearchTree(object):
         50
         66
         85
+        >>> t
+        50
+        ├──25
+        └──85
+           └──66
 
         >>> t.__delitem__(99)
         False
@@ -436,7 +460,7 @@ class BinarySearchTree(object):
             if node.right is not None:
                 yield from self._iterate_by_depth(node.right, depth - 1)
 
-    def iterate_nodes_by_depth(self, depth: int):
+    def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
         """
         Iterate only the leaf nodes in the tree.
 
@@ -498,13 +522,16 @@ class BinarySearchTree(object):
             return x
 
         path = self.parent_path(node)
+        assert path[-1] is not None
         assert path[-1] == node
         path = path[:-1]
         path.reverse()
         for ancestor in path:
+            assert ancestor is not None
             if node != ancestor.right:
                 return ancestor
             node = ancestor
+        raise Exception()
 
     def _depth(self, node: Node, sofar: int) -> int:
         depth_left = sofar + 1
@@ -549,9 +576,15 @@ class BinarySearchTree(object):
     def height(self):
         return self.depth()
 
-    def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
+    def repr_traverse(
+        self,
+        padding: str,
+        pointer: str,
+        node: Optional[Node],
+        has_right_sibling: bool,
+    ) -> str:
         if node is not None:
-            self.viz += f'\n{padding}{pointer}{node.value}'
+            viz = f'\n{padding}{pointer}{node.value}'
             if has_right_sibling:
                 padding += "│  "
             else:
@@ -563,8 +596,10 @@ class BinarySearchTree(object):
             else:
                 pointer_left = "└──"
 
-            self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
-            self.repr_traverse(padding, pointer_right, node.right, False)
+            viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
+            viz += self.repr_traverse(padding, pointer_right, node.right, False)
+            return viz
+        return ""
 
     def __repr__(self):
         """
@@ -590,18 +625,19 @@ class BinarySearchTree(object):
         if self.root is None:
             return ""
 
-        self.viz = f'{self.root.value}'
+        ret = f'{self.root.value}'
         pointer_right = "└──"
         if self.root.right is None:
             pointer_left = "└──"
         else:
             pointer_left = "├──"
 
-        self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
-        self.repr_traverse('', pointer_right, self.root.right, False)
-        return self.viz
+        ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
+        ret += self.repr_traverse('', pointer_right, self.root.right, False)
+        return ret
 
 
 if __name__ == '__main__':
     import doctest
+
     doctest.testmod()