Some binary tree methods to support the unscramble progam's sparsefile
[python_utils.git] / collect / bst.py
index d3231eecaa22060d0fc47aca073e9e4f98a09310..b4d25b34a627797660362a16d6430fdcd6d7eceb 100644 (file)
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
 #!/usr/bin/env python3
 
-from typing import Any, List, Optional
+from typing import Any, Optional, List
 
 
 class Node(object):
 
 
 class Node(object):
@@ -10,7 +10,7 @@ class Node(object):
         self.value = value
 
 
         self.value = value
 
 
-class BinaryTree(object):
+class BinarySearchTree(object):
     def __init__(self):
         self.root = None
         self.count = 0
     def __init__(self):
         self.root = None
         self.count = 0
@@ -23,7 +23,7 @@ class BinaryTree(object):
         """
         Insert something into the tree.
 
         """
         Insert something into the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(10)
         >>> t.insert(20)
         >>> t.insert(5)
         >>> t.insert(10)
         >>> t.insert(20)
         >>> t.insert(5)
@@ -60,7 +60,7 @@ class BinaryTree(object):
         Find an item in the tree and return its Node.  Returns
         None if the item is not in the tree.
 
         Find an item in the tree and return its Node.  Returns
         None if the item is not in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t[99]
 
         >>> t.insert(10)
         >>> t[99]
 
         >>> t.insert(10)
@@ -88,6 +88,52 @@ class BinaryTree(object):
                 return self._find(value, node.right)
         return None
 
                 return self._find(value, node.right)
         return None
 
+    def _parent_path(self, current: Node, target: Node):
+        if current is None:
+            return [None]
+        ret = [current]
+        if target.value == current.value:
+            return ret
+        elif target.value < current.value:
+            ret.extend(self._parent_path(current.left, target))
+            return ret
+        else:
+            assert target.value > current.value
+            ret.extend(self._parent_path(current.right, target))
+            return ret
+
+    def parent_path(self, node: Node) -> Optional[List[Node]]:
+        """Return a list of nodes representing the path from
+        the tree's root to the node argument.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(12)
+        >>> t.insert(33)
+        >>> t.insert(4)
+        >>> t.insert(88)
+        >>> t
+        50
+        ├──25
+        │  ├──12
+        │  │  └──4
+        │  └──33
+        └──75
+           └──88
+
+        >>> n = t[4]
+        >>> for x in t.parent_path(n):
+        ...     print(x.value)
+        50
+        25
+        12
+        4
+
+        """
+        return self._parent_path(self.root, node)
+
     def __delitem__(self, value: Any) -> bool:
         """
         Delete an item from the tree and preserve the BST property.
     def __delitem__(self, value: Any) -> bool:
         """
         Delete an item from the tree and preserve the BST property.
@@ -101,7 +147,7 @@ class BinaryTree(object):
                    13
 
 
                    13
 
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -216,7 +262,7 @@ class BinaryTree(object):
         """
         Returns the count of items in the tree.
 
         """
         Returns the count of items in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> len(t)
         0
         >>> t.insert(50)
         >>> len(t)
         0
         >>> t.insert(50)
@@ -270,7 +316,7 @@ class BinaryTree(object):
         """
         Yield the tree's items in a preorder traversal sequence.
 
         """
         Yield the tree's items in a preorder traversal sequence.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -295,18 +341,28 @@ class BinaryTree(object):
         """
         Yield the tree's items in a preorder traversal sequence.
 
         """
         Yield the tree's items in a preorder traversal sequence.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(66)
         >>> t.insert(22)
         >>> t.insert(13)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(66)
         >>> t.insert(22)
         >>> t.insert(13)
+        >>> t.insert(24)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     ├──13
+        │     └──24
+        └──75
+           └──66
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
         13
         22
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
         13
         22
+        24
         25
         50
         66
         25
         50
         66
@@ -320,7 +376,7 @@ class BinaryTree(object):
         """
         Yield the tree's items in a preorder traversal sequence.
 
         """
         Yield the tree's items in a preorder traversal sequence.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -353,7 +409,7 @@ class BinaryTree(object):
         """
         Iterate only the leaf nodes in the tree.
 
         """
         Iterate only the leaf nodes in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -384,7 +440,7 @@ class BinaryTree(object):
         """
         Iterate only the leaf nodes in the tree.
 
         """
         Iterate only the leaf nodes in the tree.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
         >>> t.insert(50)
         >>> t.insert(75)
         >>> t.insert(25)
@@ -405,6 +461,51 @@ class BinaryTree(object):
         if self.root is not None:
             yield from self._iterate_by_depth(self.root, depth)
 
         if self.root is not None:
             yield from self._iterate_by_depth(self.root, depth)
 
+    def get_next_node(self, node: Node) -> Node:
+        """
+        Given a tree node, get the next greater node in the tree.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+        >>> t.insert(23)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     ├──13
+        │     └──23
+        └──75
+           └──66
+
+        >>> n = t[23]
+        >>> t.get_next_node(n).value
+        25
+
+        >>> n = t[50]
+        >>> t.get_next_node(n).value
+        66
+
+        """
+        if node.right is not None:
+            x = node.right
+            while x.left is not None:
+                x = x.left
+            return x
+
+        path = self.parent_path(node)
+        assert path[-1] == node
+        path = path[:-1]
+        path.reverse()
+        for ancestor in path:
+            if node != ancestor.right:
+                return ancestor
+            node = ancestor
+
     def _depth(self, node: Node, sofar: int) -> int:
         depth_left = sofar + 1
         depth_right = sofar + 1
     def _depth(self, node: Node, sofar: int) -> int:
         depth_left = sofar + 1
         depth_right = sofar + 1
@@ -419,7 +520,7 @@ class BinaryTree(object):
         Returns the max height (depth) of the tree in plies (edge distance
         from root).
 
         Returns the max height (depth) of the tree in plies (edge distance
         from root).
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.depth()
         0
 
         >>> t.depth()
         0
 
@@ -469,7 +570,7 @@ class BinaryTree(object):
         """
         Draw the tree in ASCII.
 
         """
         Draw the tree in ASCII.
 
-        >>> t = BinaryTree()
+        >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(25)
         >>> t.insert(75)
         >>> t.insert(50)
         >>> t.insert(25)
         >>> t.insert(75)