From 3811faee504f65afcc70bd7b6e922c4f49dbdaf2 Mon Sep 17 00:00:00 2001 From: Scott Gasch Date: Thu, 15 Dec 2022 15:16:21 -0800 Subject: [PATCH] Fix interval_tree so it actually works. Add unittests. --- src/pyutils/collectionz/bst.py | 9 ++- src/pyutils/collectionz/interval_tree.py | 89 +++++++++++++++--------- 2 files changed, 62 insertions(+), 36 deletions(-) diff --git a/src/pyutils/collectionz/bst.py b/src/pyutils/collectionz/bst.py index 1efed52..aaefc1e 100644 --- a/src/pyutils/collectionz/bst.py +++ b/src/pyutils/collectionz/bst.py @@ -272,8 +272,7 @@ class BinarySearchTree(object): return False def _on_delete(self, parent: Optional[Node], deleted: Node) -> None: - """This is called just before deleted is deleted -- - i.e. before the tree changes.""" + """This is called just after deleted was deleted from the tree""" pass def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool: @@ -282,37 +281,37 @@ class BinarySearchTree(object): # Deleting a leaf node if node.left is None and node.right is None: - self._on_delete(parent, node) if parent is not None: if parent.left == node: parent.left = None else: assert parent.right == node parent.right = None + self._on_delete(parent, node) return True # Node only has a right. elif node.left is None: assert node.right is not None - self._on_delete(parent, node) if parent is not None: if parent.left == node: parent.left = node.right else: assert parent.right == node parent.right = node.right + self._on_delete(parent, node) return True # Node only has a left. elif node.right is None: assert node.left is not None - self._on_delete(parent, node) if parent is not None: if parent.left == node: parent.left = node.left else: assert parent.right == node parent.right = node.left + self._on_delete(parent, node) return True # Node has both a left and right. diff --git a/src/pyutils/collectionz/interval_tree.py b/src/pyutils/collectionz/interval_tree.py index 733aea0..c78465c 100644 --- a/src/pyutils/collectionz/interval_tree.py +++ b/src/pyutils/collectionz/interval_tree.py @@ -66,19 +66,48 @@ class AugmentedIntervalTree(bst.BinarySearchTree): @overrides def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None: if parent: - new_highest_candidates = [] - if deleted.left: - new_highest_candidates.append(deleted.left.value.highest_in_subtree) - if deleted.right: - new_highest_candidates.append(deleted.right.value.highest_in_subtree) - if len(new_highest_candidates): - parent.value.highest_in_subtree = max( - parent.value.high, max(new_highest_candidates) - ) - else: - parent.value.highest_in_subtree = parent.value.high - - def find_overlaps(self, x: NumericRange): + new_highest_candidates = [parent.value.high] + if parent.left: + new_highest_candidates.append(parent.left.value.highest_in_subtree) + if parent.right: + new_highest_candidates.append(parent.right.value.highest_in_subtree) + parent.value.highest_in_subtree = max(new_highest_candidates) + + def find_one_overlap(self, x: NumericRange): + """Identify and return one overlapping node from the tree. + + >>> tree = AugmentedIntervalTree() + >>> tree.insert(NumericRange(20, 24)) + >>> tree.insert(NumericRange(18, 22)) + >>> tree.insert(NumericRange(14, 16)) + >>> tree.insert(NumericRange(1, 30)) + >>> tree.insert(NumericRange(25, 30)) + >>> tree.insert(NumericRange(29, 33)) + >>> tree.insert(NumericRange(5, 12)) + >>> tree.insert(NumericRange(1, 6)) + >>> tree.insert(NumericRange(13, 18)) + >>> tree.insert(NumericRange(16, 28)) + >>> tree.insert(NumericRange(21, 27)) + >>> tree.find_one_overlap(NumericRange(6, 7)) + 1..30 + """ + return self._find_one_overlap(self.root, x) + + def _find_one_overlap(self, root: bst.Node, x: NumericRange): + if root is None: + return + + if root.value.overlaps_with(x): + return root.value + + if root.left: + if root.left.value.highest_in_subtree >= x.low: + return self._find_one_overlap(root.left, x) + + if root.right: + return self._find_one_overlap(root.right, x) + + def find_all_overlaps(self, x: NumericRange): """Yields ranges previously added to the tree that x overlaps with. >>> tree = AugmentedIntervalTree() @@ -93,29 +122,27 @@ class AugmentedIntervalTree(bst.BinarySearchTree): >>> tree.insert(NumericRange(13, 18)) >>> tree.insert(NumericRange(16, 28)) >>> tree.insert(NumericRange(21, 27)) - >>> for x in tree.find_overlaps(NumericRange(19, 21)): + >>> for x in tree.find_all_overlaps(NumericRange(19, 21)): ... print(x) 20..24 18..22 1..30 16..28 21..27 + + >>> del tree[NumericRange(1, 30)] + >>> for x in tree.find_all_overlaps(NumericRange(19, 21)): + ... print(x) + 20..24 + 18..22 + 16..28 + 21..27 """ if self.root is None: return - yield from self._find_overlaps(self.root, x) - - def _find_overlaps(self, root: bst.Node, x: NumericRange): - """It's known that two intervals A and B overlap only - when both A.low <= B.high and A.high >= B.low. When - searching the trees for nodes overlapping with a given - interval, we can immediately skip: - - * all nodes to the right of nodes whose low value is past - the end of the given interval and - * all nodes that have their maximum high value below the - start of the given interval. - """ + yield from self._find_all_overlaps(self.root, x) + + def _find_all_overlaps(self, root: bst.Node, x: NumericRange): if root is None: return @@ -124,11 +151,11 @@ class AugmentedIntervalTree(bst.BinarySearchTree): if root.left: if root.left.value.highest_in_subtree >= x.low: - yield from self._find_overlaps(root.left, x) + yield from self._find_all_overlaps(root.left, x) - if root.value.low <= x.high: - if root.right: - yield from self._find_overlaps(root.right, x) + if root.right: + if root.right.value.highest_in_subtree >= x.low: + yield from self._find_all_overlaps(root.right, x) if __name__ == "__main__": -- 2.45.2