Adds IntervalTree.
[pyutils.git] / src / pyutils / collectionz / bst.py
index 2e5e3ce95599811aecb553fe5b18e2ebd97c1434..1efed52838cb852259f7881270c208d3fb0f50ad 100644 (file)
@@ -2,16 +2,20 @@
 
 # © Copyright 2021-2022, Scott Gasch
 
-"""A binary search tree."""
+"""A binary search tree implementation."""
 
 from typing import Any, Generator, List, Optional
 
 
 class Node(object):
     def __init__(self, value: Any) -> None:
-        """Note that value can be anything as long as it is
-        comparable.  Check out @functools.total_ordering.
+        """
+        A BST node.  Note that value can be anything as long as it
+        is comparable.  Check out :meth:`functools.total_ordering`
+        (https://docs.python.org/3/library/functools.html#functools.total_ordering)
 
+        Args:
+            value: a reference to the value of the node.
         """
         self.left: Optional[Node] = None
         self.right: Optional[Node] = None
@@ -25,14 +29,24 @@ class BinarySearchTree(object):
         self.traverse = None
 
     def get_root(self) -> Optional[Node]:
-        """:returns the root of the BST."""
+        """
+        Returns:
+            The root of the BST
+        """
 
         return self.root
 
-    def insert(self, value: Any):
+    def _on_insert(self, parent: Optional[Node], new: Node) -> None:
+        """This is called immediately _after_ a new node is inserted."""
+        pass
+
+    def insert(self, value: Any) -> None:
         """
         Insert something into the tree.
 
+        Args:
+            value: the value to be inserted.
+
         >>> t = BinarySearchTree()
         >>> t.insert(10)
         >>> t.insert(20)
@@ -47,6 +61,7 @@ class BinarySearchTree(object):
         if self.root is None:
             self.root = Node(value)
             self.count = 1
+            self._on_insert(None, self.root)
         else:
             self._insert(value, self.root)
 
@@ -58,12 +73,14 @@ class BinarySearchTree(object):
             else:
                 node.left = Node(value)
                 self.count += 1
+                self._on_insert(node, node.left)
         else:
             if node.right is not None:
                 self._insert(value, node.right)
             else:
                 node.right = Node(value)
                 self.count += 1
+                self._on_insert(node, node.right)
 
     def __getitem__(self, value: Any) -> Optional[Node]:
         """
@@ -99,6 +116,7 @@ class BinarySearchTree(object):
     def _parent_path(
         self, current: Optional[Node], target: Node
     ) -> List[Optional[Node]]:
+        """Internal helper"""
         if current is None:
             return [None]
         ret: List[Optional[Node]] = [current]
@@ -113,11 +131,20 @@ class BinarySearchTree(object):
             return ret
 
     def parent_path(self, node: Node) -> List[Optional[Node]]:
-        """Return a list of nodes representing the path from
-        the tree's root to the node argument.  If the node does
-        not exist in the tree for some reason, the last element
-        on the path will be None but the path will indicate the
-        ancestor path of that node were it inserted.
+        """Get a node's parent path.
+
+        Args:
+            node: the node to check
+
+        Returns:
+            a list of nodes representing the path from
+            the tree's root to the node.
+
+        .. note::
+
+            If the node does not exist in the tree, the last element
+            on the path will be None but the path will indicate the
+            ancestor path of that node were it to be inserted.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -162,6 +189,13 @@ class BinarySearchTree(object):
         """
         Delete an item from the tree and preserve the BST property.
 
+        Args:
+            value: the value of the node to be deleted.
+
+        Returns:
+            True if the value was found and its associated node was
+            successfully deleted and False otherwise.
+
         >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
@@ -237,11 +271,18 @@ class BinarySearchTree(object):
             return ret
         return False
 
+    def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
+        """This is called just before deleted is deleted --
+        i.e. before the tree changes."""
+        pass
+
     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
         """Delete helper"""
         if node.value == value:
+
             # Deleting a leaf node
             if node.left is None and node.right is None:
+                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = None
@@ -253,6 +294,7 @@ class BinarySearchTree(object):
             # Node only has a right.
             elif node.left is None:
                 assert node.right is not None
+                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = node.right
@@ -264,6 +306,7 @@ class BinarySearchTree(object):
             # Node only has a left.
             elif node.right is None:
                 assert node.left is not None
+                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = node.left
@@ -288,7 +331,8 @@ class BinarySearchTree(object):
 
     def __len__(self):
         """
-        Returns the count of items in the tree.
+        Returns:
+            The count of items in the tree.
 
         >>> t = BinarySearchTree()
         >>> len(t)
@@ -314,7 +358,8 @@ class BinarySearchTree(object):
 
     def __contains__(self, value: Any) -> bool:
         """
-        Returns True if the item is in the tree; False otherwise.
+        Returns:
+            True if the item is in the tree; False otherwise.
         """
         return self.__getitem__(value) is not None
 
@@ -341,7 +386,9 @@ class BinarySearchTree(object):
 
     def iterate_preorder(self):
         """
-        Yield the tree's items in a preorder traversal sequence.
+        Returns:
+            A Generator that yields the tree's items in a
+            preorder traversal sequence.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -366,7 +413,9 @@ class BinarySearchTree(object):
 
     def iterate_inorder(self):
         """
-        Yield the tree's items in a preorder traversal sequence.
+        Returns:
+            A Generator that yield the tree's items in a preorder
+            traversal sequence.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -401,7 +450,9 @@ class BinarySearchTree(object):
 
     def iterate_postorder(self):
         """
-        Yield the tree's items in a preorder traversal sequence.
+        Returns:
+            A Generator that yield the tree's items in a preorder
+            traversal sequence.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -434,7 +485,9 @@ class BinarySearchTree(object):
 
     def iterate_leaves(self):
         """
-        Iterate only the leaf nodes in the tree.
+        Returns:
+            A Gemerator that yielde only the leaf nodes in the
+            tree.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -465,7 +518,12 @@ class BinarySearchTree(object):
 
     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
         """
-        Iterate only the leaf nodes in the tree.
+        Args:
+            depth: the desired depth
+
+        Returns:
+            A Generator that yields nodes at the prescribed depth in
+            the tree.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -490,7 +548,11 @@ class BinarySearchTree(object):
 
     def get_next_node(self, node: Node) -> Node:
         """
-        Given a tree node, get the next greater node in the tree.
+        Args:
+            node: the node whose next greater successor is desired
+
+        Returns:
+            Given a tree node, returns the next greater node in the tree.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -547,8 +609,9 @@ class BinarySearchTree(object):
 
     def depth(self) -> int:
         """
-        Returns the max height (depth) of the tree in plies (edge distance
-        from root).
+        Returns:
+            The max height (depth) of the tree in plies (edge distance
+            from root).
 
         >>> t = BinarySearchTree()
         >>> t.depth()
@@ -588,11 +651,11 @@ class BinarySearchTree(object):
         has_right_sibling: bool,
     ) -> str:
         if node is not None:
-            viz = f'\n{padding}{pointer}{node.value}'
+            viz = f"\n{padding}{pointer}{node.value}"
             if has_right_sibling:
                 padding += "│  "
             else:
-                padding += '   '
+                padding += "   "
 
             pointer_right = "└──"
             if node.right is not None:
@@ -609,7 +672,8 @@ class BinarySearchTree(object):
 
     def __repr__(self):
         """
-        Draw the tree in ASCII.
+        Returns:
+            An ASCII string representation of the tree.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -631,7 +695,7 @@ class BinarySearchTree(object):
         if self.root is None:
             return ""
 
-        ret = f'{self.root.value}'
+        ret = f"{self.root.value}"
         pointer_right = "└──"
         if self.root.right is None:
             pointer_left = "└──"
@@ -639,7 +703,13 @@ class BinarySearchTree(object):
             pointer_left = "├──"
 
         ret += self.repr_traverse(
-            '', pointer_left, self.root.left, self.root.left is not None
+            "", pointer_left, self.root.left, self.root.left is not None
         )
-        ret += self.repr_traverse('', pointer_right, self.root.right, False)
+        ret += self.repr_traverse("", pointer_right, self.root.right, False)
         return ret
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()