Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / collect / bst.py
index 94570f49be8490b4656d2b4ea12185a44c636212..d39419494d3f482712f17e13a5ff6ce1e7c2ebcf 100644 (file)
@@ -1,16 +1,24 @@
 #!/usr/bin/env python3
 
-from typing import Any, Optional
+# © Copyright 2021-2022, Scott Gasch
+
+"""Binary search tree."""
+
+from typing import Any, Generator, List, Optional
 
 
 class Node(object):
     def __init__(self, value: Any) -> None:
-        self.left = None
-        self.right = None
+        """
+        Note: value can be anything as long as it is comparable.
+        Check out @functools.total_ordering.
+        """
+        self.left: Optional[Node] = None
+        self.right: Optional[Node] = None
         self.value = value
 
 
-class BinaryTree(object):
+class BinarySearchTree(object):
     def __init__(self):
         self.root = None
         self.count = 0
@@ -23,7 +31,7 @@ class BinaryTree(object):
         """
         Insert something into the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(10)
         >>> t.insert(20)
         >>> t.insert(5)
@@ -60,7 +68,7 @@ class BinaryTree(object):
         Find an item in the tree and return its Node.  Returns
         None if the item is not in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t[99]
 
         >>> t.insert(10)
@@ -80,28 +88,77 @@ class BinaryTree(object):
         """Find helper"""
         if value == node.value:
             return node
-        elif (value < node.value and node.left is not None):
+        elif value < node.value and node.left is not None:
             return self._find(value, node.left)
-        else:
-            assert value > node.value
-            if node.right is not None:
-                return self._find(value, node.right)
+        elif value > node.value and node.right is not None:
+            return self._find(value, node.right)
         return None
 
+    def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]:
+        if current is None:
+            return [None]
+        ret: List[Optional[Node]] = [current]
+        if target.value == current.value:
+            return ret
+        elif target.value < current.value:
+            ret.extend(self._parent_path(current.left, target))
+            return ret
+        else:
+            assert target.value > current.value
+            ret.extend(self._parent_path(current.right, target))
+            return ret
+
+    def parent_path(self, node: Node) -> List[Optional[Node]]:
+        """Return a list of nodes representing the path from
+        the tree's root to the node argument.  If the node does
+        not exist in the tree for some reason, the last element
+        on the path will be None but the path will indicate the
+        ancestor path of that node were it inserted.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(12)
+        >>> t.insert(33)
+        >>> t.insert(4)
+        >>> t.insert(88)
+        >>> t
+        50
+        ├──25
+        │  ├──12
+        │  │  └──4
+        │  └──33
+        └──75
+           └──88
+
+        >>> n = t[4]
+        >>> for x in t.parent_path(n):
+        ...     print(x.value)
+        50
+        25
+        12
+        4
+
+        >>> del t[4]
+        >>> for x in t.parent_path(n):
+        ...     if x is not None:
+        ...         print(x.value)
+        ...     else:
+        ...         print(x)
+        50
+        25
+        12
+        None
+
+        """
+        return self._parent_path(self.root, node)
+
     def __delitem__(self, value: Any) -> bool:
         """
         Delete an item from the tree and preserve the BST property.
 
-                            50
-                           /  \
-                         25    75
-                        /     /  \
-                      22    66    85
-                     /
-                   13
-
-
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -109,6 +166,14 @@ class BinaryTree(object):
         >>> t.insert(22)
         >>> t.insert(13)
         >>> t.insert(85)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     └──13
+        └──75
+           ├──66
+           └──85
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
@@ -120,8 +185,8 @@ class BinaryTree(object):
         75
         85
 
-        >>> t.__delitem__(22)
-        True
+        >>> del t[22]  # Note: bool result is discarded
+
         >>> for value in t.iterate_inorder():
         ...     print(value)
         13
@@ -149,6 +214,11 @@ class BinaryTree(object):
         50
         66
         85
+        >>> t
+        50
+        ├──25
+        └──85
+           └──66
 
         >>> t.__delitem__(99)
         False
@@ -216,7 +286,7 @@ class BinaryTree(object):
         """
         Returns the count of items in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> len(t)
         0
         >>> t.insert(50)
@@ -270,7 +340,7 @@ class BinaryTree(object):
         """
         Yield the tree's items in a preorder traversal sequence.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -295,18 +365,28 @@ class BinaryTree(object):
         """
         Yield the tree's items in a preorder traversal sequence.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(66)
         >>> t.insert(22)
         >>> t.insert(13)
+        >>> t.insert(24)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     ├──13
+        │     └──24
+        └──75
+           └──66
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
         13
         22
+        24
         25
         50
         66
@@ -320,7 +400,7 @@ class BinaryTree(object):
         """
         Yield the tree's items in a preorder traversal sequence.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -353,7 +433,7 @@ class BinaryTree(object):
         """
         Iterate only the leaf nodes in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -380,11 +460,11 @@ class BinaryTree(object):
             if node.right is not None:
                 yield from self._iterate_by_depth(node.right, depth - 1)
 
-    def iterate_nodes_by_depth(self, depth: int):
+    def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
         """
         Iterate only the leaf nodes in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -405,6 +485,54 @@ class BinaryTree(object):
         if self.root is not None:
             yield from self._iterate_by_depth(self.root, depth)
 
+    def get_next_node(self, node: Node) -> Node:
+        """
+        Given a tree node, get the next greater node in the tree.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+        >>> t.insert(23)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     ├──13
+        │     └──23
+        └──75
+           └──66
+
+        >>> n = t[23]
+        >>> t.get_next_node(n).value
+        25
+
+        >>> n = t[50]
+        >>> t.get_next_node(n).value
+        66
+
+        """
+        if node.right is not None:
+            x = node.right
+            while x.left is not None:
+                x = x.left
+            return x
+
+        path = self.parent_path(node)
+        assert path[-1] is not None
+        assert path[-1] == node
+        path = path[:-1]
+        path.reverse()
+        for ancestor in path:
+            assert ancestor is not None
+            if node != ancestor.right:
+                return ancestor
+            node = ancestor
+        raise Exception()
+
     def _depth(self, node: Node, sofar: int) -> int:
         depth_left = sofar + 1
         depth_right = sofar + 1
@@ -419,7 +547,7 @@ class BinaryTree(object):
         Returns the max height (depth) of the tree in plies (edge distance
         from root).
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.depth()
         0
 
@@ -448,7 +576,68 @@ class BinaryTree(object):
     def height(self):
         return self.depth()
 
+    def repr_traverse(
+        self,
+        padding: str,
+        pointer: str,
+        node: Optional[Node],
+        has_right_sibling: bool,
+    ) -> str:
+        if node is not None:
+            viz = f'\n{padding}{pointer}{node.value}'
+            if has_right_sibling:
+                padding += "│  "
+            else:
+                padding += '   '
+
+            pointer_right = "└──"
+            if node.right is not None:
+                pointer_left = "├──"
+            else:
+                pointer_left = "└──"
+
+            viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
+            viz += self.repr_traverse(padding, pointer_right, node.right, False)
+            return viz
+        return ""
+
+    def __repr__(self):
+        """
+        Draw the tree in ASCII.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(25)
+        >>> t.insert(75)
+        >>> t.insert(12)
+        >>> t.insert(33)
+        >>> t.insert(88)
+        >>> t.insert(55)
+        >>> t
+        50
+        ├──25
+        │  ├──12
+        │  └──33
+        └──75
+           ├──55
+           └──88
+        """
+        if self.root is None:
+            return ""
+
+        ret = f'{self.root.value}'
+        pointer_right = "└──"
+        if self.root.right is None:
+            pointer_left = "└──"
+        else:
+            pointer_left = "├──"
+
+        ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
+        ret += self.repr_traverse('', pointer_right, self.root.right, False)
+        return ret
+
 
 if __name__ == '__main__':
     import doctest
+
     doctest.testmod()