Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / collect / bst.py
index d3231eecaa22060d0fc47aca073e9e4f98a09310..d39419494d3f482712f17e13a5ff6ce1e7c2ebcf 100644 (file)
@@ -1,16 +1,24 @@
 #!/usr/bin/env python3
 
 #!/usr/bin/env python3
 
-from typing import Any, List, Optional
+# © Copyright 2021-2022, Scott Gasch
+
+"""Binary search tree."""
+
+from typing import Any, Generator, List, Optional
 
 
 class Node(object):
     def __init__(self, value: Any) -> None:
 
 
 class Node(object):
     def __init__(self, value: Any) -> None:
-        self.left = None
-        self.right = None
+        """
+        Note: value can be anything as long as it is comparable.
+        Check out @functools.total_ordering.
+        """
+        self.left: Optional[Node] = None
+        self.right: Optional[Node] = None
         self.value = value
 
 
         self.value = value
 
 
-class BinaryTree(object):
+class BinarySearchTree(object):
     def __init__(self):
         self.root = None
         self.count = 0
     def __init__(self):
         self.root = None
         self.count = 0
@@ -23,7 +31,7 @@ class BinaryTree(object):
         """
         Insert something into the tree.
 
         """
         Insert something into the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(10)
         >>> t.insert(20)
         >>> t.insert(5)
         >>> t.insert(10)
         >>> t.insert(20)
         >>> t.insert(5)
@@ -60,7 +68,7 @@ class BinaryTree(object):
         Find an item in the tree and return its Node.  Returns
         None if the item is not in the tree.
 
         Find an item in the tree and return its Node.  Returns
         None if the item is not in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t[99]
 
         >>> t.insert(10)
         >>> t[99]
 
         >>> t.insert(10)
@@ -80,28 +88,77 @@ class BinaryTree(object):
         """Find helper"""
         if value == node.value:
             return node
         """Find helper"""
         if value == node.value:
             return node
-        elif (value < node.value and node.left is not None):
+        elif value < node.value and node.left is not None:
             return self._find(value, node.left)
             return self._find(value, node.left)
-        else:
-            assert value > node.value
-            if node.right is not None:
-                return self._find(value, node.right)
+        elif value > node.value and node.right is not None:
+            return self._find(value, node.right)
         return None
 
         return None
 
+    def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]:
+        if current is None:
+            return [None]
+        ret: List[Optional[Node]] = [current]
+        if target.value == current.value:
+            return ret
+        elif target.value < current.value:
+            ret.extend(self._parent_path(current.left, target))
+            return ret
+        else:
+            assert target.value > current.value
+            ret.extend(self._parent_path(current.right, target))
+            return ret
+
+    def parent_path(self, node: Node) -> List[Optional[Node]]:
+        """Return a list of nodes representing the path from
+        the tree's root to the node argument.  If the node does
+        not exist in the tree for some reason, the last element
+        on the path will be None but the path will indicate the
+        ancestor path of that node were it inserted.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(12)
+        >>> t.insert(33)
+        >>> t.insert(4)
+        >>> t.insert(88)
+        >>> t
+        50
+        ├──25
+        │  ├──12
+        │  │  └──4
+        │  └──33
+        └──75
+           └──88
+
+        >>> n = t[4]
+        >>> for x in t.parent_path(n):
+        ...     print(x.value)
+        50
+        25
+        12
+        4
+
+        >>> del t[4]
+        >>> for x in t.parent_path(n):
+        ...     if x is not None:
+        ...         print(x.value)
+        ...     else:
+        ...         print(x)
+        50
+        25
+        12
+        None
+
+        """
+        return self._parent_path(self.root, node)
+
     def __delitem__(self, value: Any) -> bool:
         """
         Delete an item from the tree and preserve the BST property.
 
     def __delitem__(self, value: Any) -> bool:
         """
         Delete an item from the tree and preserve the BST property.
 
-                            50
-                           /  \
-                         25    75
-                        /     /  \
-                      22    66    85
-                     /
-                   13
-
-
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -109,6 +166,14 @@ class BinaryTree(object):
         >>> t.insert(22)
         >>> t.insert(13)
         >>> t.insert(85)
         >>> t.insert(22)
         >>> t.insert(13)
         >>> t.insert(85)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     └──13
+        └──75
+           ├──66
+           └──85
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
@@ -149,6 +214,11 @@ class BinaryTree(object):
         50
         66
         85
         50
         66
         85
+        >>> t
+        50
+        ├──25
+        └──85
+           └──66
 
         >>> t.__delitem__(99)
         False
 
         >>> t.__delitem__(99)
         False
@@ -216,7 +286,7 @@ class BinaryTree(object):
         """
         Returns the count of items in the tree.
 
         """
         Returns the count of items in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> len(t)
         0
         >>> t.insert(50)
         >>> len(t)
         0
         >>> t.insert(50)
@@ -270,7 +340,7 @@ class BinaryTree(object):
         """
         Yield the tree's items in a preorder traversal sequence.
 
         """
         Yield the tree's items in a preorder traversal sequence.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -295,18 +365,28 @@ class BinaryTree(object):
         """
         Yield the tree's items in a preorder traversal sequence.
 
         """
         Yield the tree's items in a preorder traversal sequence.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(66)
         >>> t.insert(22)
         >>> t.insert(13)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(66)
         >>> t.insert(22)
         >>> t.insert(13)
+        >>> t.insert(24)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     ├──13
+        │     └──24
+        └──75
+           └──66
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
         13
         22
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
         13
         22
+        24
         25
         50
         66
         25
         50
         66
@@ -320,7 +400,7 @@ class BinaryTree(object):
         """
         Yield the tree's items in a preorder traversal sequence.
 
         """
         Yield the tree's items in a preorder traversal sequence.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -353,7 +433,7 @@ class BinaryTree(object):
         """
         Iterate only the leaf nodes in the tree.
 
         """
         Iterate only the leaf nodes in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -380,11 +460,11 @@ class BinaryTree(object):
             if node.right is not None:
                 yield from self._iterate_by_depth(node.right, depth - 1)
 
             if node.right is not None:
                 yield from self._iterate_by_depth(node.right, depth - 1)
 
-    def iterate_nodes_by_depth(self, depth: int):
+    def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
         """
         Iterate only the leaf nodes in the tree.
 
         """
         Iterate only the leaf nodes in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -405,6 +485,54 @@ class BinaryTree(object):
         if self.root is not None:
             yield from self._iterate_by_depth(self.root, depth)
 
         if self.root is not None:
             yield from self._iterate_by_depth(self.root, depth)
 
+    def get_next_node(self, node: Node) -> Node:
+        """
+        Given a tree node, get the next greater node in the tree.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+        >>> t.insert(23)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     ├──13
+        │     └──23
+        └──75
+           └──66
+
+        >>> n = t[23]
+        >>> t.get_next_node(n).value
+        25
+
+        >>> n = t[50]
+        >>> t.get_next_node(n).value
+        66
+
+        """
+        if node.right is not None:
+            x = node.right
+            while x.left is not None:
+                x = x.left
+            return x
+
+        path = self.parent_path(node)
+        assert path[-1] is not None
+        assert path[-1] == node
+        path = path[:-1]
+        path.reverse()
+        for ancestor in path:
+            assert ancestor is not None
+            if node != ancestor.right:
+                return ancestor
+            node = ancestor
+        raise Exception()
+
     def _depth(self, node: Node, sofar: int) -> int:
         depth_left = sofar + 1
         depth_right = sofar + 1
     def _depth(self, node: Node, sofar: int) -> int:
         depth_left = sofar + 1
         depth_right = sofar + 1
@@ -419,7 +547,7 @@ class BinaryTree(object):
         Returns the max height (depth) of the tree in plies (edge distance
         from root).
 
         Returns the max height (depth) of the tree in plies (edge distance
         from root).
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.depth()
         0
 
         >>> t.depth()
         0
 
@@ -448,9 +576,15 @@ class BinaryTree(object):
     def height(self):
         return self.depth()
 
     def height(self):
         return self.depth()
 
-    def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
+    def repr_traverse(
+        self,
+        padding: str,
+        pointer: str,
+        node: Optional[Node],
+        has_right_sibling: bool,
+    ) -> str:
         if node is not None:
         if node is not None:
-            self.viz += f'\n{padding}{pointer}{node.value}'
+            viz = f'\n{padding}{pointer}{node.value}'
             if has_right_sibling:
                 padding += "│  "
             else:
             if has_right_sibling:
                 padding += "│  "
             else:
@@ -462,14 +596,16 @@ class BinaryTree(object):
             else:
                 pointer_left = "└──"
 
             else:
                 pointer_left = "└──"
 
-            self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
-            self.repr_traverse(padding, pointer_right, node.right, False)
+            viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
+            viz += self.repr_traverse(padding, pointer_right, node.right, False)
+            return viz
+        return ""
 
     def __repr__(self):
         """
         Draw the tree in ASCII.
 
 
     def __repr__(self):
         """
         Draw the tree in ASCII.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(25)
         >>> t.insert(75)
         >>> t.insert(50)
         >>> t.insert(25)
         >>> t.insert(75)
@@ -489,18 +625,19 @@ class BinaryTree(object):
         if self.root is None:
             return ""
 
         if self.root is None:
             return ""
 
-        self.viz = f'{self.root.value}'
+        ret = f'{self.root.value}'
         pointer_right = "└──"
         if self.root.right is None:
             pointer_left = "└──"
         else:
             pointer_left = "├──"
 
         pointer_right = "└──"
         if self.root.right is None:
             pointer_left = "└──"
         else:
             pointer_left = "├──"
 
-        self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
-        self.repr_traverse('', pointer_right, self.root.right, False)
-        return self.viz
+        ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
+        ret += self.repr_traverse('', pointer_right, self.root.right, False)
+        return ret
 
 
 if __name__ == '__main__':
     import doctest
 
 
 if __name__ == '__main__':
     import doctest
+
     doctest.testmod()
     doctest.testmod()