More docs...
[pyutils.git] / src / pyutils / collectionz / interval_tree.py
index 733aea07801c68e8762da37d7b7d6067f84be9f0..7ba190a5984814eb35ff7058181cdc3a79342dd6 100644 (file)
@@ -7,7 +7,7 @@ described by: https://en.wikipedia.org/wiki/Interval_tree.
 from __future__ import annotations
 
 from functools import total_ordering
-from typing import Any, Optional, Union
+from typing import Any, Generator, Optional, Union
 
 from overrides import overrides
 
@@ -18,7 +18,22 @@ Numeric = Union[int, float]
 
 @total_ordering
 class NumericRange(object):
+    """Essentially a tuple of numbers denoting a range with some added
+    helper methods on it."""
+
     def __init__(self, low: Numeric, high: Numeric):
+        """Creates a NumericRange.
+
+        Args:
+            low: the lowest point in the range (inclusive).
+            high: the highest point in the range (inclusive).
+
+        .. warning::
+
+            If low > high this code swaps the parameters and keeps the range
+            rather than raising.
+
+        """
         if low > high:
             temp: Numeric = low
             low = high
@@ -28,15 +43,27 @@ class NumericRange(object):
         self.highest_in_subtree: Numeric = high
 
     def __lt__(self, other: NumericRange) -> bool:
+        """
+        Returns:
+            True if this range is less than other (i.e. has a lower low), else False.
+        """
         return self.low < other.low
 
     @overrides
     def __eq__(self, other: object) -> bool:
+        """
+        Returns:
+            True if this is the same range as other, else False.
+        """
         if not isinstance(other, NumericRange):
             return False
         return self.low == other.low and self.high == other.high
 
     def overlaps_with(self, other: NumericRange) -> bool:
+        """
+        Returns:
+            True if this NumericRange overlaps with other, else False.
+        """
         return self.low <= other.high and self.high >= other.low
 
     def __repr__(self) -> str:
@@ -44,11 +71,8 @@ class NumericRange(object):
 
 
 class AugmentedIntervalTree(bst.BinarySearchTree):
-    def __init__(self):
-        super().__init__()
-
     @staticmethod
-    def assert_value_must_be_range(value: Any) -> None:
+    def _assert_value_must_be_range(value: Any) -> None:
         if not isinstance(value, NumericRange):
             raise Exception(
                 "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
@@ -57,7 +81,7 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
 
     @overrides
     def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
-        AugmentedIntervalTree.assert_value_must_be_range(new.value)
+        AugmentedIntervalTree._assert_value_must_be_range(new.value)
         for ancestor in self.parent_path(new):
             assert ancestor
             if new.value.high > ancestor.value.highest_in_subtree:
@@ -66,20 +90,22 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
     @overrides
     def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
         if parent:
-            new_highest_candidates = []
-            if deleted.left:
-                new_highest_candidates.append(deleted.left.value.highest_in_subtree)
-            if deleted.right:
-                new_highest_candidates.append(deleted.right.value.highest_in_subtree)
-            if len(new_highest_candidates):
-                parent.value.highest_in_subtree = max(
-                    parent.value.high, max(new_highest_candidates)
-                )
-            else:
-                parent.value.highest_in_subtree = parent.value.high
-
-    def find_overlaps(self, x: NumericRange):
-        """Yields ranges previously added to the tree that x overlaps with.
+            new_highest_candidates = [parent.value.high]
+            if parent.left:
+                new_highest_candidates.append(parent.left.value.highest_in_subtree)
+            if parent.right:
+                new_highest_candidates.append(parent.right.value.highest_in_subtree)
+            parent.value.highest_in_subtree = max(new_highest_candidates)
+
+    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
+        """Identify and return one overlapping node from the tree.
+
+        Args:
+            to_find: the interval with which to find an overlap.
+
+        Returns:
+            An overlapping range from the tree or None if no such
+            ranges exist in the tree at present.
 
         >>> tree = AugmentedIntervalTree()
         >>> tree.insert(NumericRange(20, 24))
@@ -93,42 +119,91 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
         >>> tree.insert(NumericRange(13, 18))
         >>> tree.insert(NumericRange(16, 28))
         >>> tree.insert(NumericRange(21, 27))
-        >>> for x in tree.find_overlaps(NumericRange(19, 21)):
+        >>> tree.find_one_overlap(NumericRange(6, 7))
+        1..30
+
+        """
+        return self._find_one_overlap(self.root, to_find)
+
+    def _find_one_overlap(
+        self, root: bst.Node, x: NumericRange
+    ) -> Optional[NumericRange]:
+        if root is None:
+            return None
+
+        if root.value.overlaps_with(x):
+            return root.value
+
+        if root.left:
+            if root.left.value.highest_in_subtree >= x.low:
+                return self._find_one_overlap(root.left, x)
+
+        if root.right:
+            return self._find_one_overlap(root.right, x)
+        return None
+
+    def find_all_overlaps(
+        self, to_find: NumericRange
+    ) -> Generator[NumericRange, None, None]:
+        """Yields ranges previously added to the tree that overlap with
+        the to_find argument.
+
+        Args:
+            to_find: the interval with which to find all overlaps.
+
+        Returns:
+            A (potentially empty) sequence of all ranges in the tree
+            that overlap with the argument.
+
+        >>> tree = AugmentedIntervalTree()
+        >>> tree.insert(NumericRange(20, 24))
+        >>> tree.insert(NumericRange(18, 22))
+        >>> tree.insert(NumericRange(14, 16))
+        >>> tree.insert(NumericRange(1, 30))
+        >>> tree.insert(NumericRange(25, 30))
+        >>> tree.insert(NumericRange(29, 33))
+        >>> tree.insert(NumericRange(5, 12))
+        >>> tree.insert(NumericRange(1, 6))
+        >>> tree.insert(NumericRange(13, 18))
+        >>> tree.insert(NumericRange(16, 28))
+        >>> tree.insert(NumericRange(21, 27))
+        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
         ...     print(x)
         20..24
         18..22
         1..30
         16..28
         21..27
+
+        >>> del tree[NumericRange(1, 30)]
+        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
+        ...     print(x)
+        20..24
+        18..22
+        16..28
+        21..27
+
         """
         if self.root is None:
             return
-        yield from self._find_overlaps(self.root, x)
-
-    def _find_overlaps(self, root: bst.Node, x: NumericRange):
-        """It's known that two intervals A and B overlap only
-        when both A.low <= B.high and A.high >= B.low.  When
-        searching the trees for nodes overlapping with a given
-        interval, we can immediately skip:
-
-            * all nodes to the right of nodes whose low value is past
-              the end of the given interval and
-            * all nodes that have their maximum high value below the
-              start of the given interval.
-        """
+        yield from self._find_all_overlaps(self.root, to_find)
+
+    def _find_all_overlaps(
+        self, root: bst.Node, x: NumericRange
+    ) -> Generator[NumericRange, None, None]:
         if root is None:
-            return
+            return None
 
         if root.value.overlaps_with(x):
             yield root.value
 
         if root.left:
             if root.left.value.highest_in_subtree >= x.low:
-                yield from self._find_overlaps(root.left, x)
+                yield from self._find_all_overlaps(root.left, x)
 
-        if root.value.low <= x.high:
-            if root.right:
-                yield from self._find_overlaps(root.right, x)
+        if root.right:
+            if root.right.value.highest_in_subtree >= x.low:
+                yield from self._find_all_overlaps(root.right, x)
 
 
 if __name__ == "__main__":