3 # © Copyright 2021-2023, Scott Gasch
5 """This is an augmented interval tree for storing ranges and identifying overlaps as
6 described by: https://en.wikipedia.org/wiki/Interval_tree.
9 from __future__ import annotations
11 from functools import total_ordering
12 from typing import Any, Generator, Optional
14 from overrides import overrides
16 from pyutils.collectionz import bst
17 from pyutils.typez.simple import Numeric
@total_ordering
class NumericRange(object):
    """Essentially a tuple of numbers denoting a range with some added
    helper methods on it.

    Both endpoints are inclusive.  Ranges order by ``low`` and then by
    ``high``; the remaining comparison operators are derived by
    :func:`functools.total_ordering`.  Because ``__eq__`` is defined without
    a ``__hash__``, instances are unhashable.

    Each instance also carries ``highest_in_subtree``, bookkeeping that
    :class:`AugmentedIntervalTree` updates while the range is stored in it.
    """

    def __init__(self, low: Numeric, high: Numeric):
        """Creates a NumericRange.

        Args:
            low: the lowest point in the range (inclusive).
            high: the highest point in the range (inclusive).

        Note:
            If low > high this code swaps the parameters and keeps the range
            rather than raising an error.
        """
        if low > high:
            low, high = high, low
        self.low: Numeric = low
        self.high: Numeric = high
        # Until this range is placed in a tree, its "subtree" is just itself.
        self.highest_in_subtree: Numeric = high

    def __lt__(self, other: NumericRange) -> bool:
        """
        Returns:
            True if this range is less than other (lower low, or equal lows
            and a lower high), else False.  For operands that are not
            NumericRanges this returns NotImplemented, so ``<`` raises
            TypeError instead of failing on a missing attribute.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        if self.low != other.low:
            return self.low < other.low
        return self.high < other.high

    def __eq__(self, other: object) -> bool:
        """
        Returns:
            True if this is the same range as other, else False.  For
            operands that are not NumericRanges this returns NotImplemented,
            letting Python fall back to its default (unequal) result.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        return self.low == other.low and self.high == other.high

    def overlaps_with(self, other: NumericRange) -> bool:
        """
        Returns:
            True if this NumericRange overlaps with other, else False.
            Endpoints are inclusive, so ranges that merely touch
            (e.g. 1..5 and 5..9) overlap.
        """
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"{self.low}..{self.high}"
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of :class:`NumericRange` values augmented for
    fast overlap queries, as described by
    https://en.wikipedia.org/wiki/Interval_tree.

    Invariant: each stored range's ``highest_in_subtree`` equals the largest
    ``high`` endpoint anywhere in the subtree rooted at the node holding that
    range.  The insert/delete hooks maintain it; the ``find_*`` queries use
    it to skip subtrees that cannot contain an overlap.
    """

    @staticmethod
    def _assert_value_must_be_range(value: Any) -> None:
        """Reject values that are not NumericRanges.

        Args:
            value: the value being stored in the tree.

        Raises:
            TypeError: ``value`` is not a :class:`NumericRange`.
        """
        if not isinstance(value, NumericRange):
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )

    @staticmethod
    def _recompute_highest_in_subtree(node: bst.Node) -> None:
        """Reset ``node``'s subtree maximum from its own range and its
        children's stored maxima (which must already be correct)."""
        candidates = [node.value.high]
        for child in (node.left, node.right):
            if child is not None:
                candidates.append(child.value.highest_in_subtree)
        node.value.highest_in_subtree = max(candidates)

    @overrides
    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Hook run for a newly inserted node; propagates its maximum upward.

        Args:
            parent: the node ``new`` was attached under (unused here).
            new: the node holding the inserted value.

        Raises:
            TypeError: the inserted value is not a NumericRange.
        """
        AugmentedIntervalTree._assert_value_must_be_range(new.value)
        # NOTE(review): parent_path(new) finds ``new`` in the tree, so this hook
        # runs after linking; a rejected value is presumably left in place --
        # confirm against bst.insert.
        #
        # highest_in_subtree lives on the range object itself, so a range that
        # was stored before (e.g. deleted and re-inserted) may carry a stale
        # maximum.  Reset it from new's own subtree before propagating.
        AugmentedIntervalTree._recompute_highest_in_subtree(new)
        highest = new.value.highest_in_subtree
        for ancestor in self.parent_path(new):
            if ancestor is None:
                continue
            if highest > ancestor.value.highest_in_subtree:
                ancestor.value.highest_in_subtree = highest

    @overrides
    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Hook run when a node is deleted; repairs subtree maxima.

        Args:
            parent: the deleted node's parent, or None.
            deleted: the node that was removed (unused here).
        """
        # Every node from the root down to ``parent`` had the removed range in
        # its subtree, so any of their maxima may now be too high.  Fixing only
        # ``parent`` leaves stale values further up, which can send
        # find_one_overlap into a subtree with no overlap and miss a real one.
        start = parent if parent is not None else self.root
        if start is None:
            return  # the tree is empty; nothing to repair
        path = [node for node in self.parent_path(start) if node is not None]
        if path and path[0] is not self.root:
            path.reverse()  # normalise to root-first order
        # Repair bottom-up so each node is recomputed from corrected children.
        # ``start`` is fixed first defensively; recomputing twice is harmless.
        AugmentedIntervalTree._recompute_highest_in_subtree(start)
        for node in reversed(path):
            AugmentedIntervalTree._recompute_highest_in_subtree(node)

    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
        """Identify and return one overlapping range from the tree.

        Args:
            to_find: the interval with which to find an overlap.

        Returns:
            An overlapping range from the tree or None if no such
            ranges exist in the tree at present.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> tree.find_one_overlap(NumericRange(6, 7))
        1..30
        """
        return self._find_one_overlap(self.root, to_find)

    def _find_one_overlap(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Optional[NumericRange]:
        """Return some range in the subtree at ``root`` overlapping ``x``.

        Follows a single downward path.  If the left subtree reaches at least
        ``x.low`` but holds no overlap, then the range there that reaches
        ``x.low`` must start after ``x.high``.  Everything to the right sorts
        after it and starts later still, so only the left side needs
        searching.  Otherwise only the right side can hold an overlap.  The
        loop replaces recursion so deep trees cannot hit the recursion limit.
        """
        node = root
        while node is not None:
            if node.value.overlaps_with(x):
                return node.value
            left = node.left
            if left is not None and left.value.highest_in_subtree >= x.low:
                node = left
            else:
                node = node.right
        return None

    def find_all_overlaps(
        self, to_find: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yields ranges previously added to the tree that overlap with
        to_find.

        Args:
            to_find: the interval with which to find all overlaps.

        Returns:
            A (potentially empty) sequence of all ranges in the tree
            that overlap with the argument.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        1..30
        16..28
        21..27

        >>> del tree[NumericRange(1, 30)]
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        16..28
        21..27
        """
        if self.root is None:
            return
        yield from self._find_all_overlaps(self.root, to_find)

    def _find_all_overlaps(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yield every range in the subtree at ``root`` that overlaps ``x``.

        Visits nodes in pre-order (node, left subtree, right subtree) using an
        explicit stack, so deep trees cannot hit the recursion limit.
        """
        stack = [] if root is None else [root]
        while stack:
            node = stack.pop()
            value = node.value
            if value.overlaps_with(x):
                yield value
            # A subtree whose highest endpoint is below x.low cannot overlap.
            # Ranges right of this node sort after it, so they start at or
            # after value.low; if that is past x.high none of them overlap.
            # Right is pushed before left so the left subtree is visited first.
            right = node.right
            if (
                right is not None
                and value.low <= x.high
                and right.value.highest_in_subtree >= x.low
            ):
                stack.append(right)
            left = node.left
            if left is not None and left.value.highest_in_subtree >= x.low:
                stack.append(left)
if __name__ == "__main__":
    # Executing this module directly runs the doctest examples embedded in
    # the docstrings above as a self-test.
    import doctest

    doctest.testmod()