Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / collect / bst.py
index 72a3b7738b981878b9b08eaba67bca2b33314f4f..d39419494d3f482712f17e13a5ff6ce1e7c2ebcf 100644 (file)
@@ -1,6 +1,10 @@
 #!/usr/bin/env python3
 
-from typing import Any, Optional, List
+# © Copyright 2021-2022, Scott Gasch
+
+"""Binary search tree."""
+
+from typing import Any, Generator, List, Optional
 
 
 class Node(object):
@@ -9,8 +13,8 @@ class Node(object):
         Note: value can be anything as long as it is comparable.
         Check out @functools.total_ordering.
         """
-        self.left = None
-        self.right = None
+        self.left: Optional[Node] = None
+        self.right: Optional[Node] = None
         self.value = value
 
 
@@ -84,16 +88,16 @@ class BinarySearchTree(object):
         """Find helper"""
         if value == node.value:
             return node
-        elif (value < node.value and node.left is not None):
+        elif value < node.value and node.left is not None:
             return self._find(value, node.left)
-        elif (value > node.value and node.right is not None):
+        elif value > node.value and node.right is not None:
             return self._find(value, node.right)
         return None
 
-    def _parent_path(self, current: Node, target: Node):
+    def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]:
         if current is None:
             return [None]
-        ret = [current]
+        ret: List[Optional[Node]] = [current]
         if target.value == current.value:
             return ret
         elif target.value < current.value:
@@ -104,7 +108,7 @@ class BinarySearchTree(object):
             ret.extend(self._parent_path(current.right, target))
             return ret
 
-    def parent_path(self, node: Node) -> Optional[List[Node]]:
+    def parent_path(self, node: Node) -> List[Optional[Node]]:
         """Return a list of nodes representing the path from
         the tree's root to the node argument.  If the node does
         not exist in the tree for some reason, the last element
@@ -456,7 +460,7 @@ class BinarySearchTree(object):
             if node.right is not None:
                 yield from self._iterate_by_depth(node.right, depth - 1)
 
-    def iterate_nodes_by_depth(self, depth: int):
+    def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
         """
         Iterate only the leaf nodes in the tree.
 
@@ -518,13 +522,16 @@ class BinarySearchTree(object):
             return x
 
         path = self.parent_path(node)
+        assert path[-1] is not None
         assert path[-1] == node
         path = path[:-1]
         path.reverse()
         for ancestor in path:
+            assert ancestor is not None
             if node != ancestor.right:
                 return ancestor
             node = ancestor
+        raise Exception()
 
     def _depth(self, node: Node, sofar: int) -> int:
         depth_left = sofar + 1
@@ -569,7 +576,13 @@ class BinarySearchTree(object):
     def height(self):
         return self.depth()
 
-    def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
+    def repr_traverse(
+        self,
+        padding: str,
+        pointer: str,
+        node: Optional[Node],
+        has_right_sibling: bool,
+    ) -> str:
         if node is not None:
             viz = f'\n{padding}{pointer}{node.value}'
             if has_right_sibling:
@@ -626,4 +639,5 @@ class BinarySearchTree(object):
 
 if __name__ == '__main__':
     import doctest
+
     doctest.testmod()