Adds IntervalTree.
authorScott Gasch <[email protected]>
Tue, 13 Dec 2022 22:36:23 +0000 (14:36 -0800)
committerScott Gasch <[email protected]>
Tue, 13 Dec 2022 22:36:23 +0000 (14:36 -0800)
src/pyutils/collectionz/bst.py
src/pyutils/collectionz/interval_tree.py [new file with mode: 0644]

index 4c0bacdd051374a3f700ceba33b4beaad143b956..1efed52838cb852259f7881270c208d3fb0f50ad 100644 (file)
@@ -36,6 +36,10 @@ class BinarySearchTree(object):
 
         return self.root
 
+    def _on_insert(self, parent: Optional[Node], new: Node) -> None:
+        """This is called immediately _after_ a new node is inserted."""
+        pass
+
     def insert(self, value: Any) -> None:
         """
         Insert something into the tree.
@@ -57,6 +61,7 @@ class BinarySearchTree(object):
         if self.root is None:
             self.root = Node(value)
             self.count = 1
+            self._on_insert(None, self.root)
         else:
             self._insert(value, self.root)
 
@@ -68,12 +73,14 @@ class BinarySearchTree(object):
             else:
                 node.left = Node(value)
                 self.count += 1
+                self._on_insert(node, node.left)
         else:
             if node.right is not None:
                 self._insert(value, node.right)
             else:
                 node.right = Node(value)
                 self.count += 1
+                self._on_insert(node, node.right)
 
     def __getitem__(self, value: Any) -> Optional[Node]:
         """
@@ -264,11 +271,18 @@ class BinarySearchTree(object):
             return ret
         return False
 
+    def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
+        """This is called just before deleted is deleted --
+        i.e. before the tree changes."""
+        pass
+
     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
         """Delete helper"""
         if node.value == value:
+
             # Deleting a leaf node
             if node.left is None and node.right is None:
+                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = None
@@ -280,6 +294,7 @@ class BinarySearchTree(object):
             # Node only has a right.
             elif node.left is None:
                 assert node.right is not None
+                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = node.right
@@ -291,6 +306,7 @@ class BinarySearchTree(object):
             # Node only has a left.
             elif node.right is None:
                 assert node.left is not None
+                self._on_delete(parent, node)
                 if parent is not None:
                     if parent.left == node:
                         parent.left = node.left
@@ -635,11 +651,11 @@ class BinarySearchTree(object):
         has_right_sibling: bool,
     ) -> str:
         if node is not None:
-            viz = f'\n{padding}{pointer}{node.value}'
+            viz = f"\n{padding}{pointer}{node.value}"
             if has_right_sibling:
                 padding += "│  "
             else:
-                padding += '   '
+                padding += "   "
 
             pointer_right = "└──"
             if node.right is not None:
@@ -679,7 +695,7 @@ class BinarySearchTree(object):
         if self.root is None:
             return ""
 
-        ret = f'{self.root.value}'
+        ret = f"{self.root.value}"
         pointer_right = "└──"
         if self.root.right is None:
             pointer_left = "└──"
@@ -687,13 +703,13 @@ class BinarySearchTree(object):
             pointer_left = "├──"
 
         ret += self.repr_traverse(
-            '', pointer_left, self.root.left, self.root.left is not None
+            "", pointer_left, self.root.left, self.root.left is not None
         )
-        ret += self.repr_traverse('', pointer_right, self.root.right, False)
+        ret += self.repr_traverse("", pointer_right, self.root.right, False)
         return ret
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import doctest
 
     doctest.testmod()
diff --git a/src/pyutils/collectionz/interval_tree.py b/src/pyutils/collectionz/interval_tree.py
new file mode 100644 (file)
index 0000000..733aea0
--- /dev/null
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+
+"""This is an augmented interval tree for storing ranges and identifying overlaps as
+described by: https://en.wikipedia.org/wiki/Interval_tree.
+"""
+
+from __future__ import annotations
+
+from functools import total_ordering
+from typing import Any, Optional, Union
+
+from overrides import overrides
+
+from pyutils.collectionz import bst
+
+Numeric = Union[int, float]
+
+
+@total_ordering
+class NumericRange(object):
+    def __init__(self, low: Numeric, high: Numeric):
+        if low > high:
+            temp: Numeric = low
+            low = high
+            high = temp
+        self.low: Numeric = low
+        self.high: Numeric = high
+        self.highest_in_subtree: Numeric = high
+
+    def __lt__(self, other: NumericRange) -> bool:
+        return self.low < other.low
+
+    @overrides
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, NumericRange):
+            return False
+        return self.low == other.low and self.high == other.high
+
+    def overlaps_with(self, other: NumericRange) -> bool:
+        return self.low <= other.high and self.high >= other.low
+
+    def __repr__(self) -> str:
+        return f"{self.low}..{self.high}"
+
+
+class AugmentedIntervalTree(bst.BinarySearchTree):
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def assert_value_must_be_range(value: Any) -> None:
+        if not isinstance(value, NumericRange):
+            raise Exception(
+                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
+                + "general purpose tree usable for other types."
+            )
+
+    @overrides
+    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
+        AugmentedIntervalTree.assert_value_must_be_range(new.value)
+        for ancestor in self.parent_path(new):
+            assert ancestor
+            if new.value.high > ancestor.value.highest_in_subtree:
+                ancestor.value.highest_in_subtree = new.value.high
+
+    @overrides
+    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
+        if parent:
+            new_highest_candidates = []
+            if deleted.left:
+                new_highest_candidates.append(deleted.left.value.highest_in_subtree)
+            if deleted.right:
+                new_highest_candidates.append(deleted.right.value.highest_in_subtree)
+            if len(new_highest_candidates):
+                parent.value.highest_in_subtree = max(
+                    parent.value.high, max(new_highest_candidates)
+                )
+            else:
+                parent.value.highest_in_subtree = parent.value.high
+
+    def find_overlaps(self, x: NumericRange):
+        """Yields ranges previously added to the tree that x overlaps with.
+
+        >>> tree = AugmentedIntervalTree()
+        >>> tree.insert(NumericRange(20, 24))
+        >>> tree.insert(NumericRange(18, 22))
+        >>> tree.insert(NumericRange(14, 16))
+        >>> tree.insert(NumericRange(1, 30))
+        >>> tree.insert(NumericRange(25, 30))
+        >>> tree.insert(NumericRange(29, 33))
+        >>> tree.insert(NumericRange(5, 12))
+        >>> tree.insert(NumericRange(1, 6))
+        >>> tree.insert(NumericRange(13, 18))
+        >>> tree.insert(NumericRange(16, 28))
+        >>> tree.insert(NumericRange(21, 27))
+        >>> for x in tree.find_overlaps(NumericRange(19, 21)):
+        ...     print(x)
+        20..24
+        18..22
+        1..30
+        16..28
+        21..27
+        """
+        if self.root is None:
+            return
+        yield from self._find_overlaps(self.root, x)
+
+    def _find_overlaps(self, root: bst.Node, x: NumericRange):
+        """It's known that two intervals A and B overlap only
+        when both A.low <= B.high and A.high >= B.low.  When
+        searching the trees for nodes overlapping with a given
+        interval, we can immediately skip:
+
+            * all nodes to the right of nodes whose low value is past
+              the end of the given interval and
+            * all nodes that have their maximum high value below the
+              start of the given interval.
+        """
+        if root is None:
+            return
+
+        if root.value.overlaps_with(x):
+            yield root.value
+
+        if root.left:
+            if root.left.value.highest_in_subtree >= x.low:
+                yield from self._find_overlaps(root.left, x)
+
+        if root.value.low <= x.high:
+            if root.right:
+                yield from self._find_overlaps(root.right, x)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()