3 """This is an augmented interval tree for storing ranges and identifying overlaps as
4 described by: https://en.wikipedia.org/wiki/Interval_tree.
7 from __future__ import annotations
9 from functools import total_ordering
10 from typing import Any, Generator, Optional, Union
12 from overrides import overrides
14 from pyutils.collectionz import bst
Numeric = Union[int, float]


@total_ordering
class NumericRange(object):
    """Essentially a tuple of numbers denoting a closed range, with some
    added helper methods on it.

    Ordering (``<`` and the comparisons that ``functools.total_ordering``
    derives from it) looks only at ``low``; equality and hashing look at
    both ``low`` and ``high``.
    """

    def __init__(self, low: Numeric, high: Numeric):
        """Creates a NumericRange.

        Args:
            low: the lowest point in the range (inclusive).
            high: the highest point in the range (inclusive).

        .. warning::

            If low > high this code swaps the parameters and keeps the range
            rather than raising.
        """
        if low > high:
            low, high = high, low
        self.low: Numeric = low
        self.high: Numeric = high
        # Bookkeeping slot used by AugmentedIntervalTree: the largest `high`
        # seen in the subtree rooted at the node holding this range.  It
        # starts out as this range's own `high`.
        self.highest_in_subtree: Numeric = high

    def __lt__(self, other: NumericRange) -> bool:
        """Orders ranges by their low endpoint only."""
        return self.low < other.low

    def __eq__(self, other: object) -> bool:
        """Two ranges are equal when both endpoints match; anything that is
        not a NumericRange is never equal to one."""
        if not isinstance(other, NumericRange):
            return False
        return self.low == other.low and self.high == other.high

    def __hash__(self) -> int:
        """Hashes on the same fields __eq__ compares.

        Defining __eq__ without __hash__ makes instances unhashable, which
        would keep ranges out of sets and dict keys.
        """
        return hash((self.low, self.high))

    def overlaps_with(self, other: NumericRange) -> bool:
        """Returns True if this range and other share at least one point
        (both ranges are inclusive, so touching endpoints count)."""
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"{self.low}..{self.high}"
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of NumericRange values in which every stored
    range also tracks ``highest_in_subtree``, the largest ``high`` seen in
    the subtree below it.  Overlap queries use that value to skip subtrees
    that cannot contain a match.
    """

    @staticmethod
    def _assert_value_must_be_range(value: Any) -> None:
        """Raises if value is not a NumericRange.

        Args:
            value: the candidate value being stored in the tree.

        Raises:
            TypeError: if value is not a NumericRange.
        """
        if not isinstance(value, NumericRange):
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )

    @overrides
    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Insertion hook: validates the new value and raises the
        ``highest_in_subtree`` of every node on its parent path if needed.

        Args:
            parent: the new node's parent (unused here).
            new: the node that was just inserted.
        """
        AugmentedIntervalTree._assert_value_must_be_range(new.value)
        for ancestor in self.parent_path(new):
            # NOTE(review): parent_path is defined in bst; skip any
            # placeholder entries it may return -- confirm against bst.
            if ancestor is None:
                continue
            if new.value.high > ancestor.value.highest_in_subtree:
                ancestor.value.highest_in_subtree = new.value.high

    @overrides
    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Deletion hook: recomputes the parent's ``highest_in_subtree`` from
        its own ``high`` and whatever children it has left.

        NOTE(review): only the immediate parent is refreshed, so nodes
        further up may keep a value that is too high.  Overlap searches
        therefore must not treat a miss in one subtree as proof that the
        other subtree has no match.

        Args:
            parent: the parent of the deleted node, or None.
            deleted: the node that was removed (unused here).
        """
        if parent is None:
            return
        new_highest_candidates = [parent.value.high]
        if parent.left is not None:
            new_highest_candidates.append(parent.left.value.highest_in_subtree)
        if parent.right is not None:
            new_highest_candidates.append(parent.right.value.highest_in_subtree)
        parent.value.highest_in_subtree = max(new_highest_candidates)

    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
        """Identify and return one overlapping node from the tree.

        Args:
            to_find: the interval with which to find an overlap.

        Returns:
            An overlapping range from the tree or None if no such
            ranges exist in the tree at present.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> tree.find_one_overlap(NumericRange(6, 7))
        1..30

        """
        return self._find_one_overlap(self.root, to_find)

    def _find_one_overlap(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Optional[NumericRange]:
        """Recursive helper for find_one_overlap.

        Args:
            root: the subtree to search (None means an empty subtree).
            x: the interval with which to find an overlap.

        Returns:
            One range in the subtree that overlaps x, or None.
        """
        if root is None:
            return None

        if root.value.overlaps_with(x):
            return root.value

        left = root.left
        if left is not None and left.value.highest_in_subtree >= x.low:
            found = self._find_one_overlap(left, x)
            # highest_in_subtree can be stale (too high) after a deletion,
            # so a miss on the left does not rule out the right subtree.
            if found is not None:
                return found

        # Ranges are ordered by `low`; if this node already starts beyond
        # x.high then everything to its right does too and cannot overlap.
        # (Assumes bst keeps values that are not less than a node on its
        # right -- the usual binary search tree invariant.)
        if root.right is not None and root.value.low <= x.high:
            return self._find_one_overlap(root.right, x)
        return None

    def find_all_overlaps(
        self, to_find: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yields ranges previously added to the tree that overlap with
        the to_find argument.

        Args:
            to_find: the interval with which to find all overlaps.

        Returns:
            A (potentially empty) sequence of all ranges in the tree
            that overlap with the argument.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        1..30
        16..28
        21..27

        >>> del tree[NumericRange(1, 30)]
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        16..28
        21..27

        """
        if self.root is None:
            return
        yield from self._find_all_overlaps(self.root, to_find)

    def _find_all_overlaps(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Recursive helper for find_all_overlaps.

        Visits the node, then its left subtree, then its right subtree,
        skipping any subtree whose highest endpoint lies below x.low.

        Args:
            root: the subtree to search (None means an empty subtree).
            x: the interval with which to find all overlaps.
        """
        if root is None:
            return

        if root.value.overlaps_with(x):
            yield root.value

        if root.left is not None:
            if root.left.value.highest_in_subtree >= x.low:
                yield from self._find_all_overlaps(root.left, x)

        if root.right is not None:
            if root.right.value.highest_in_subtree >= x.low:
                yield from self._find_all_overlaps(root.right, x)
197 if __name__ == "__main__":