Fix interval_tree so it actually works. Add unittests.
[pyutils.git] / src / pyutils / collectionz / bst.py
index 4c0bacdd051374a3f700ceba33b4beaad143b956..aaefc1e52a065f3f38a088f0f91089ae75c8c228 100644 (file)
@@ -36,6 +36,10 @@ class BinarySearchTree(object):
 
         return self.root
 
+    def _on_insert(self, parent: Optional[Node], new: Node) -> None:
+        """This is called immediately _after_ a new node is inserted."""
+        pass
+
     def insert(self, value: Any) -> None:
         """
         Insert something into the tree.
@@ -57,6 +61,7 @@ class BinarySearchTree(object):
         if self.root is None:
             self.root = Node(value)
             self.count = 1
+            self._on_insert(None, self.root)
         else:
             self._insert(value, self.root)
 
@@ -68,12 +73,14 @@ class BinarySearchTree(object):
             else:
                 node.left = Node(value)
                 self.count += 1
+                self._on_insert(node, node.left)
         else:
             if node.right is not None:
                 self._insert(value, node.right)
             else:
                 node.right = Node(value)
                 self.count += 1
+                self._on_insert(node, node.right)
 
     def __getitem__(self, value: Any) -> Optional[Node]:
         """
@@ -264,9 +271,14 @@ class BinarySearchTree(object):
             return ret
         return False
 
+    def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
+        """This is called just after deleted was deleted from the tree"""
+        pass
+
     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
         """Delete helper"""
         if node.value == value:
+
             # Deleting a leaf node
             if node.left is None and node.right is None:
                 if parent is not None:
@@ -275,6 +287,7 @@ class BinarySearchTree(object):
                     else:
                         assert parent.right == node
                         parent.right = None
+                self._on_delete(parent, node)
                 return True
 
             # Node only has a right.
@@ -286,6 +299,7 @@ class BinarySearchTree(object):
                     else:
                         assert parent.right == node
                         parent.right = node.right
+                self._on_delete(parent, node)
                 return True
 
             # Node only has a left.
@@ -297,6 +311,7 @@ class BinarySearchTree(object):
                     else:
                         assert parent.right == node
                         parent.right = node.left
+                self._on_delete(parent, node)
                 return True
 
             # Node has both a left and right.
@@ -635,11 +650,11 @@ class BinarySearchTree(object):
         has_right_sibling: bool,
     ) -> str:
         if node is not None:
-            viz = f'\n{padding}{pointer}{node.value}'
+            viz = f"\n{padding}{pointer}{node.value}"
             if has_right_sibling:
                 padding += "│  "
             else:
-                padding += '   '
+                padding += "   "
 
             pointer_right = "└──"
             if node.right is not None:
@@ -679,7 +694,7 @@ class BinarySearchTree(object):
         if self.root is None:
             return ""
 
-        ret = f'{self.root.value}'
+        ret = f"{self.root.value}"
         pointer_right = "└──"
         if self.root.right is None:
             pointer_left = "└──"
@@ -687,13 +702,13 @@ class BinarySearchTree(object):
             pointer_left = "├──"
 
         ret += self.repr_traverse(
-            '', pointer_left, self.root.left, self.root.left is not None
+            "", pointer_left, self.root.left, self.root.left is not None
         )
-        ret += self.repr_traverse('', pointer_right, self.root.right, False)
+        ret += self.repr_traverse("", pointer_right, self.root.right, False)
         return ret
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import doctest
 
     doctest.testmod()