Make subdirs type clean too.
[python_utils.git] / collect / bst.py
index 72a3b7738b981878b9b08eaba67bca2b33314f4f..9d6525946e8131728896d86f3400c38c5ba528e7 100644 (file)
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
-from typing import Any, Optional, List
+from typing import Any, Generator, List, Optional
 
 
 class Node(object):
@@ -9,8 +9,8 @@ class Node(object):
         Note: value can be anything as long as it is comparable.
         Check out @functools.total_ordering.
         """
-        self.left = None
-        self.right = None
+        self.left: Optional[Node] = None
+        self.right: Optional[Node] = None
         self.value = value
 
 
@@ -84,16 +84,18 @@ class BinarySearchTree(object):
         """Find helper"""
         if value == node.value:
             return node
-        elif (value < node.value and node.left is not None):
+        elif value < node.value and node.left is not None:
             return self._find(value, node.left)
-        elif (value > node.value and node.right is not None):
+        elif value > node.value and node.right is not None:
             return self._find(value, node.right)
         return None
 
-    def _parent_path(self, current: Node, target: Node):
+    def _parent_path(
+        self, current: Optional[Node], target: Node
+    ) -> List[Optional[Node]]:
         if current is None:
             return [None]
-        ret = [current]
+        ret: List[Optional[Node]] = [current]
         if target.value == current.value:
             return ret
         elif target.value < current.value:
@@ -104,7 +106,7 @@ class BinarySearchTree(object):
             ret.extend(self._parent_path(current.right, target))
             return ret
 
-    def parent_path(self, node: Node) -> Optional[List[Node]]:
+    def parent_path(self, node: Node) -> List[Optional[Node]]:
         """Return a list of nodes representing the path from
         the tree's root to the node argument.  If the node does
         not exist in the tree for some reason, the last element
@@ -456,7 +458,7 @@ class BinarySearchTree(object):
             if node.right is not None:
                 yield from self._iterate_by_depth(node.right, depth - 1)
 
-    def iterate_nodes_by_depth(self, depth: int):
+    def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
         """
         Iterate only the leaf nodes in the tree.
 
@@ -518,13 +520,16 @@ class BinarySearchTree(object):
             return x
 
         path = self.parent_path(node)
+        assert path[-1]
         assert path[-1] == node
         path = path[:-1]
         path.reverse()
         for ancestor in path:
+            assert ancestor
             if node != ancestor.right:
                 return ancestor
             node = ancestor
+        raise Exception()
 
     def _depth(self, node: Node, sofar: int) -> int:
         depth_left = sofar + 1
@@ -569,7 +574,9 @@ class BinarySearchTree(object):
     def height(self):
         return self.depth()
 
-    def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
+    def repr_traverse(
+        self, padding: str, pointer: str, node: Optional[Node], has_right_sibling: bool
+    ) -> str:
         if node is not None:
             viz = f'\n{padding}{pointer}{node.value}'
             if has_right_sibling:
@@ -583,7 +590,9 @@ class BinarySearchTree(object):
             else:
                 pointer_left = "└──"
 
-            viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
+            viz += self.repr_traverse(
+                padding, pointer_left, node.left, node.right is not None
+            )
             viz += self.repr_traverse(padding, pointer_right, node.right, False)
             return viz
         return ""
@@ -619,11 +628,14 @@ class BinarySearchTree(object):
         else:
             pointer_left = "├──"
 
-        ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
+        ret += self.repr_traverse(
+            '', pointer_left, self.root.left, self.root.left is not None
+        )
         ret += self.repr_traverse('', pointer_right, self.root.right, False)
         return ret
 
 
 if __name__ == '__main__':
     import doctest
+
     doctest.testmod()