Fix interval_tree so it actually works. Add unittests.
author Scott Gasch <[email protected]>
Thu, 15 Dec 2022 23:16:21 +0000 (15:16 -0800)
committer Scott Gasch <[email protected]>
Thu, 15 Dec 2022 23:16:21 +0000 (15:16 -0800)
src/pyutils/collectionz/bst.py
src/pyutils/collectionz/interval_tree.py

index 1efed52838cb852259f7881270c208d3fb0f50ad..aaefc1e52a065f3f38a088f0f91089ae75c8c228 100644 (file)
@@ -272,8 +272,7 @@ class BinarySearchTree(object):
         return False
 
     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
-        """This is called just before deleted is deleted --
-        i.e. before the tree changes."""
+        """This is called just after deleted was deleted from the tree"""
         pass
 
     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
@@ -282,37 +281,37 @@ class BinarySearchTree(object):
 
             # Deleting a leaf node
             if node.left is None and node.right is None:
-                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = None
                     else:
                         assert parent.right == node
                         parent.right = None
+                self._on_delete(parent, node)
                 return True
 
             # Node only has a right.
             elif node.left is None:
                 assert node.right is not None
-                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = node.right
                     else:
                         assert parent.right == node
                         parent.right = node.right
+                self._on_delete(parent, node)
                 return True
 
             # Node only has a left.
             elif node.right is None:
                 assert node.left is not None
-                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = node.left
                     else:
                         assert parent.right == node
                         parent.right = node.left
+                self._on_delete(parent, node)
                 return True
 
             # Node has both a left and right.
index 733aea07801c68e8762da37d7b7d6067f84be9f0..c78465c65fdf72c581df22c17c057330bdce0113 100644 (file)
@@ -66,19 +66,48 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
     @overrides
     def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
         if parent:
-            new_highest_candidates = []
-            if deleted.left:
-                new_highest_candidates.append(deleted.left.value.highest_in_subtree)
-            if deleted.right:
-                new_highest_candidates.append(deleted.right.value.highest_in_subtree)
-            if len(new_highest_candidates):
-                parent.value.highest_in_subtree = max(
-                    parent.value.high, max(new_highest_candidates)
-                )
-            else:
-                parent.value.highest_in_subtree = parent.value.high
-
-    def find_overlaps(self, x: NumericRange):
+            new_highest_candidates = [parent.value.high]
+            if parent.left:
+                new_highest_candidates.append(parent.left.value.highest_in_subtree)
+            if parent.right:
+                new_highest_candidates.append(parent.right.value.highest_in_subtree)
+            parent.value.highest_in_subtree = max(new_highest_candidates)
+
+    def find_one_overlap(self, x: NumericRange):
+        """Identify and return one overlapping node from the tree.
+
+        >>> tree = AugmentedIntervalTree()
+        >>> tree.insert(NumericRange(20, 24))
+        >>> tree.insert(NumericRange(18, 22))
+        >>> tree.insert(NumericRange(14, 16))
+        >>> tree.insert(NumericRange(1, 30))
+        >>> tree.insert(NumericRange(25, 30))
+        >>> tree.insert(NumericRange(29, 33))
+        >>> tree.insert(NumericRange(5, 12))
+        >>> tree.insert(NumericRange(1, 6))
+        >>> tree.insert(NumericRange(13, 18))
+        >>> tree.insert(NumericRange(16, 28))
+        >>> tree.insert(NumericRange(21, 27))
+        >>> tree.find_one_overlap(NumericRange(6, 7))
+        1..30
+        """
+        return self._find_one_overlap(self.root, x)
+
+    def _find_one_overlap(self, root: bst.Node, x: NumericRange):
+        if root is None:
+            return
+
+        if root.value.overlaps_with(x):
+            return root.value
+
+        if root.left:
+            if root.left.value.highest_in_subtree >= x.low:
+                return self._find_one_overlap(root.left, x)
+
+        if root.right:
+            return self._find_one_overlap(root.right, x)
+
+    def find_all_overlaps(self, x: NumericRange):
         """Yields ranges previously added to the tree that x overlaps with.
 
         >>> tree = AugmentedIntervalTree()
@@ -93,29 +122,27 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
         >>> tree.insert(NumericRange(13, 18))
         >>> tree.insert(NumericRange(16, 28))
         >>> tree.insert(NumericRange(21, 27))
-        >>> for x in tree.find_overlaps(NumericRange(19, 21)):
+        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
         ...     print(x)
         20..24
         18..22
         1..30
         16..28
         21..27
+
+        >>> del tree[NumericRange(1, 30)]
+        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
+        ...     print(x)
+        20..24
+        18..22
+        16..28
+        21..27
         """
         if self.root is None:
             return
-        yield from self._find_overlaps(self.root, x)
-
-    def _find_overlaps(self, root: bst.Node, x: NumericRange):
-        """It's known that two intervals A and B overlap only
-        when both A.low <= B.high and A.high >= B.low.  When
-        searching the trees for nodes overlapping with a given
-        interval, we can immediately skip:
-
-            * all nodes to the right of nodes whose low value is past
-              the end of the given interval and
-            * all nodes that have their maximum high value below the
-              start of the given interval.
-        """
+        yield from self._find_all_overlaps(self.root, x)
+
+    def _find_all_overlaps(self, root: bst.Node, x: NumericRange):
         if root is None:
             return
 
@@ -124,11 +151,11 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
 
         if root.left:
             if root.left.value.highest_in_subtree >= x.low:
-                yield from self._find_overlaps(root.left, x)
+                yield from self._find_all_overlaps(root.left, x)
 
-        if root.value.low <= x.high:
-            if root.right:
-                yield from self._find_overlaps(root.right, x)
+        if root.right:
+            if root.right.value.highest_in_subtree >= x.low:
+                yield from self._find_all_overlaps(root.right, x)
 
 
 if __name__ == "__main__":