Tiebreak ordering of ranges with the same lower bound using upper bound.
[pyutils.git] / src / pyutils / collectionz / interval_tree.py
index 92c975d35a277f8a2f266c2ce221b3401e8a7c60..a8278a2dc8ea835a501951e3abddb9727d405930 100644 (file)
@@ -22,6 +22,18 @@ class NumericRange(object):
     helper methods on it."""
 
     def __init__(self, low: Numeric, high: Numeric):
+        """Creates a NumericRange.
+
+        Args:
+            low: the lowest point in the range (inclusive).
+            high: the highest point in the range (inclusive).
+
+        .. warning::
+
+            If low > high this code swaps the parameters and keeps the range
+            rather than raising.
+
+        """
         if low > high:
             temp: Numeric = low
             low = high
@@ -31,15 +43,30 @@ class NumericRange(object):
         self.highest_in_subtree: Numeric = high
 
     def __lt__(self, other: NumericRange) -> bool:
-        return self.low < other.low
+        """
+        Returns:
+            True is this range is less than (lower low) other, else False.
+        """
+        if self.low != other.low:
+            return self.low < other.low
+        else:
+            return self.high < other.high
 
     @overrides
     def __eq__(self, other: object) -> bool:
+        """
+        Returns:
+            True if this is the same range as other, else False.
+        """
         if not isinstance(other, NumericRange):
             return False
         return self.low == other.low and self.high == other.high
 
     def overlaps_with(self, other: NumericRange) -> bool:
+        """
+        Returns:
+            True if this NumericRange overlaps with other, else False.
+        """
         return self.low <= other.high and self.high >= other.low
 
     def __repr__(self) -> str:
@@ -73,9 +100,16 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
                 new_highest_candidates.append(parent.right.value.highest_in_subtree)
             parent.value.highest_in_subtree = max(new_highest_candidates)
 
-    def find_one_overlap(self, x: NumericRange):
+    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
         """Identify and return one overlapping node from the tree.
 
+        Args:
+            to_find: the interval with which to find an overlap.
+
+        Returns:
+            An overlapping range from the tree or None if no such
+            ranges exist in the tree at present.
+
         >>> tree = AugmentedIntervalTree()
         >>> tree.insert(NumericRange(20, 24))
         >>> tree.insert(NumericRange(18, 22))
@@ -90,8 +124,9 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
         >>> tree.insert(NumericRange(21, 27))
         >>> tree.find_one_overlap(NumericRange(6, 7))
         1..30
+
         """
-        return self._find_one_overlap(self.root, x)
+        return self._find_one_overlap(self.root, to_find)
 
     def _find_one_overlap(
         self, root: bst.Node, x: NumericRange
@@ -110,8 +145,18 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
             return self._find_one_overlap(root.right, x)
         return None
 
-    def find_all_overlaps(self, x: NumericRange):
-        """Yields ranges previously added to the tree that x overlaps with.
+    def find_all_overlaps(
+        self, to_find: NumericRange
+    ) -> Generator[NumericRange, None, None]:
+        """Yields ranges previously added to the tree that overlaps with
+        to_find argument.
+
+        Args:
+            to_find: the interval with which to find all overlaps.
+
+        Returns:
+            A (potentially empty) sequence of all ranges in the tree
+            that overlap with the argument.
 
         >>> tree = AugmentedIntervalTree()
         >>> tree.insert(NumericRange(20, 24))
@@ -140,10 +185,11 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
         18..22
         16..28
         21..27
+
         """
         if self.root is None:
             return
-        yield from self._find_all_overlaps(self.root, x)
+        yield from self._find_all_overlaps(self.root, to_find)
 
     def _find_all_overlaps(
         self, root: bst.Node, x: NumericRange