Adds IntervalTree.
[pyutils.git] / src / pyutils / collectionz / bst.py
index 4c0bacdd051374a3f700ceba33b4beaad143b956..1efed52838cb852259f7881270c208d3fb0f50ad 100644 (file)
@@ -36,6 +36,10 @@ class BinarySearchTree(object):
 
         return self.root
 
+    def _on_insert(self, parent: Optional[Node], new: Node) -> None:
+        """This is called immediately _after_ a new node is inserted."""
+        pass
+
     def insert(self, value: Any) -> None:
         """
         Insert something into the tree.
@@ -57,6 +61,7 @@ class BinarySearchTree(object):
         if self.root is None:
             self.root = Node(value)
             self.count = 1
+            self._on_insert(None, self.root)
         else:
             self._insert(value, self.root)
 
@@ -68,12 +73,14 @@ class BinarySearchTree(object):
             else:
                 node.left = Node(value)
                 self.count += 1
+                self._on_insert(node, node.left)
         else:
             if node.right is not None:
                 self._insert(value, node.right)
             else:
                 node.right = Node(value)
                 self.count += 1
+                self._on_insert(node, node.right)
 
     def __getitem__(self, value: Any) -> Optional[Node]:
         """
@@ -264,11 +271,18 @@ class BinarySearchTree(object):
             return ret
         return False
 
+    def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
+        """This is called just before deleted is deleted --
+        i.e. before the tree changes."""
+        pass
+
     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
         """Delete helper"""
         if node.value == value:
+
             # Deleting a leaf node
             if node.left is None and node.right is None:
+                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = None
@@ -280,6 +294,7 @@ class BinarySearchTree(object):
             # Node only has a right.
             elif node.left is None:
                 assert node.right is not None
+                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = node.right
@@ -291,6 +306,7 @@ class BinarySearchTree(object):
             # Node only has a left.
             elif node.right is None:
                 assert node.left is not None
+                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = node.left
@@ -635,11 +651,11 @@ class BinarySearchTree(object):
         has_right_sibling: bool,
     ) -> str:
         if node is not None:
-            viz = f'\n{padding}{pointer}{node.value}'
+            viz = f"\n{padding}{pointer}{node.value}"
             if has_right_sibling:
                 padding += "│  "
             else:
-                padding += '   '
+                padding += "   "
 
             pointer_right = "└──"
             if node.right is not None:
@@ -679,7 +695,7 @@ class BinarySearchTree(object):
         if self.root is None:
             return ""
 
-        ret = f'{self.root.value}'
+        ret = f"{self.root.value}"
         pointer_right = "└──"
         if self.root.right is None:
             pointer_left = "└──"
@@ -687,13 +703,13 @@ class BinarySearchTree(object):
             pointer_left = "├──"
 
         ret += self.repr_traverse(
-            '', pointer_left, self.root.left, self.root.left is not None
+            "", pointer_left, self.root.left, self.root.left is not None
         )
-        ret += self.repr_traverse('', pointer_right, self.root.right, False)
+        ret += self.repr_traverse("", pointer_right, self.root.right, False)
         return ret
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import doctest
 
     doctest.testmod()