3 # © Copyright 2021-2023, Scott Gasch
5 """This is an augmented interval tree for storing ranges and identifying overlaps as
6 described by: https://en.wikipedia.org/wiki/Interval_tree.
9 from __future__ import annotations
11 from functools import total_ordering
12 from typing import Any, Generator, Optional
14 from overrides import overrides
16 from pyutils.collectionz import bst
17 from pyutils.typez.simple import Numeric
@total_ordering
class NumericRange(bst.Comparable):
    """Essentially a tuple of numbers denoting a range with some added
    helper methods on it.

    Both endpoints are inclusive and ranges order lexicographically by
    ``(low, high)``.  Besides ``low`` and ``high`` each instance carries
    ``highest_in_subtree``: bookkeeping that :class:`AugmentedIntervalTree`
    maintains to hold the largest ``high`` found in the tree subtree rooted
    at the node storing this range.

    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable.
    """

    def __init__(self, low: Numeric, high: Numeric):
        """Creates a NumericRange.

        Args:
            low: the lowest point in the range (inclusive).
            high: the highest point in the range (inclusive).

        .. warning::

            If low > high this code swaps the parameters and keeps the range
            rather than raising an error.
        """
        if low > high:
            low, high = high, low
        self.low: Numeric = low
        self.high: Numeric = high
        # A lone range is its own subtree; AugmentedIntervalTree raises this
        # as ranges are added beneath the node that holds this one.
        self.highest_in_subtree: Numeric = high

    @overrides
    def __lt__(self, other: NumericRange) -> bool:
        """
        Returns:
            True if this range is less than other, i.e. it has a lower low
            or, when the lows are equal, a lower high; else False.
            NotImplemented if other is not a NumericRange, letting Python
            try the reflected comparison (and raise TypeError if that fails
            too) instead of crashing on a missing attribute.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        if self.low != other.low:
            return self.low < other.low
        return self.high < other.high

    @overrides
    def __eq__(self, other: object) -> bool:
        """
        Returns:
            True if this is the same range as other (equal low and equal
            high), else False.  NotImplemented for non-NumericRange operands;
            Python then falls back to its default handling, which normally
            makes ``==`` evaluate to False.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        return self.low == other.low and self.high == other.high

    @overrides
    def __le__(self, other: object) -> bool:
        """
        Returns:
            True if this range is less than or equal to other, else False.
            NotImplemented if other is not a NumericRange, so that
            comparisons against unrelated types raise TypeError rather than
            silently answering False.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        return self < other or self == other

    def overlaps_with(self, other: NumericRange) -> bool:
        """
        Returns:
            True if this NumericRange overlaps with other, else False.
            Endpoints are inclusive, so ranges that merely touch
            (e.g. [1..5] and [5..9]) overlap.
        """
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"[{self.low}..{self.high}]"
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of :class:`NumericRange` values augmented to
    answer overlap queries quickly.

    Each stored range's ``highest_in_subtree`` holds the largest ``high`` of
    any range in the subtree rooted at its node.  The ``_on_insert`` and
    ``_on_delete`` hooks maintain that invariant, which lets the overlap
    searches skip whole subtrees.
    """

    @staticmethod
    def _assert_value_must_be_range(value: Any) -> NumericRange:
        """Return value unchanged after checking that it is a NumericRange.

        Raises:
            TypeError: if value is not a NumericRange.
        """
        if not isinstance(value, NumericRange):
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )
        return value

    @staticmethod
    def _recompute_highest(node: bst.Node) -> NumericRange:
        """Recompute node's ``highest_in_subtree`` from its own ``high`` and
        its children's ``highest_in_subtree`` (which must already be current).

        Returns:
            The node's range, with ``highest_in_subtree`` refreshed.
        """
        value = AugmentedIntervalTree._assert_value_must_be_range(node.value)
        highest = value.high
        for child in (node.left, node.right):
            if child is not None:
                child_value = AugmentedIntervalTree._assert_value_must_be_range(
                    child.value
                )
                highest = max(highest, child_value.highest_in_subtree)
        value.highest_in_subtree = highest
        return value

    def _recompute_all(self) -> None:
        """Recompute ``highest_in_subtree`` for every node in the tree,
        children before parents.  Only used as a fallback by ``_on_delete``.
        """
        visited: list[bst.Node] = []
        stack: list[bst.Node] = [] if self.root is None else [self.root]
        while stack:
            node = stack.pop()
            visited.append(node)
            stack.extend(c for c in (node.left, node.right) if c is not None)
        # Every node was appended before any of its descendants, so walking
        # the list backwards refreshes children before their parents.
        for node in reversed(visited):
            AugmentedIntervalTree._recompute_highest(node)

    @overrides
    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Hook invoked by the base tree for a newly inserted node.

        First refreshes ``new``'s own ``highest_in_subtree``; its range object
        may carry a stale value if it was stored in a tree before.  Then it
        raises the value of every node on ``new``'s parent path that the new
        range now exceeds.
        """
        nv = AugmentedIntervalTree._recompute_highest(new)
        for ancestor in self.parent_path(new):
            if ancestor is None:
                # Defensive: nothing to update for a placeholder entry.
                continue
            av = AugmentedIntervalTree._assert_value_must_be_range(ancestor.value)
            if nv.highest_in_subtree > av.highest_in_subtree:
                av.highest_in_subtree = nv.highest_in_subtree

    @overrides
    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Hook invoked by the base tree after ``deleted`` is removed from
        under ``parent``.

        Removing a node can lower ``highest_in_subtree`` for ``parent`` and
        for every node above it, so all of them are recomputed bottom-up along
        the root-to-parent path.  Refreshing only ``parent`` would leave
        too-high values further up.  Those can steer
        :meth:`find_one_overlap` into a subtree with no overlap and make it
        miss one elsewhere.
        """
        if parent is None:
            # NOTE(review): presumably the removed node was the root, whose
            # replacement keeps its own (unchanged) subtree; nothing to do.
            return
        path: list[bst.Node] = []
        node: Optional[bst.Node] = self.root
        while node is not None:
            path.append(node)
            if node is parent:
                break
            # Ordinary BST descent by value; assumes equal ranges sit right.
            node = node.left if parent.value < node.value else node.right
        if node is not parent:
            # parent was not where the descent expected it (e.g. duplicate
            # ranges placed differently); fall back to refreshing everything.
            self._recompute_all()
            return
        for ancestor in reversed(path):
            AugmentedIntervalTree._recompute_highest(ancestor)

    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
        """Identify and return one overlapping range from the tree.

        Args:
            to_find: the interval with which to find an overlap.

        Returns:
            An overlapping range from the tree or None if no such
            ranges exist in the tree at present.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> tree.find_one_overlap(NumericRange(6, 7))
        [1..30]

        Deleting a range deep in the tree must not leave stale bookkeeping
        above it:

        >>> t = AugmentedIntervalTree()
        >>> for r in [(50, 51), (20, 21), (60, 160), (30, 31), (40, 200)]:
        ...     t.insert(NumericRange(*r))
        >>> del t[NumericRange(40, 200)]
        >>> t.find_one_overlap(NumericRange(150, 155))
        [60..160]
        """
        return self._find_one_overlap(self.root, to_find)

    def _find_one_overlap(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Optional[NumericRange]:
        """Return some range at or under root that overlaps x, or None.

        Written as a loop rather than recursion so that very deep trees
        cannot hit Python's recursion limit.
        """
        as_range = AugmentedIntervalTree._assert_value_must_be_range
        node = root
        while node is not None:
            value = as_range(node.value)
            if value.overlaps_with(x):
                return value
            left = node.left
            if left is not None and as_range(left.value).highest_in_subtree >= x.low:
                # Some range on the left ends at or after x.low.  If no left
                # range overlaps x, that one must start after x.high, and so
                # does everything on the right (it sorts later).  So the left
                # subtree is the only place worth looking.
                node = left
            elif x.high < value.low:
                # Everything to the right starts at or after value.low.
                return None
            else:
                node = node.right
        return None

    def find_all_overlaps(
        self, to_find: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yields ranges previously added to the tree that overlap with
        to_find.

        Args:
            to_find: the interval with which to find all overlaps.

        Returns:
            A (potentially empty) sequence of all ranges in the tree
            that overlap with the argument.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        [20..24]
        [18..22]
        [1..30]
        [16..28]
        [21..27]

        >>> del tree[NumericRange(1, 30)]
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        [20..24]
        [18..22]
        [16..28]
        [21..27]
        """
        if self.root is None:
            return
        yield from self._find_all_overlaps(self.root, to_find)

    def _find_all_overlaps(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yield every range at or under root that overlaps x, in pre-order
        (a node, then its left subtree, then its right subtree).

        Uses an explicit stack rather than recursion so that very deep trees
        cannot hit Python's recursion limit.
        """
        as_range = AugmentedIntervalTree._assert_value_must_be_range
        pending: list[bst.Node] = [] if root is None else [root]
        while pending:
            node = pending.pop()
            value = as_range(node.value)
            if value.overlaps_with(x):
                yield value
            # Push the right child first so the left subtree is popped, and
            # therefore reported, before it.
            right = node.right
            if (
                right is not None
                # Right-hand ranges all start at or after value.low, so none
                # of them can overlap once x ends before that.
                and x.high >= value.low
                and as_range(right.value).highest_in_subtree >= x.low
            ):
                pending.append(right)
            left = node.left
            if left is not None and as_range(left.value).highest_in_subtree >= x.low:
                pending.append(left)
236 if __name__ == "__main__":