3 """This is an augmented interval tree for storing ranges and identifying overlaps as
4 described by: https://en.wikipedia.org/wiki/Interval_tree.
7 from __future__ import annotations
9 from functools import total_ordering
10 from typing import Any, Optional, Union
12 from overrides import overrides
14 from pyutils.collectionz import bst
# Endpoint type accepted by NumericRange: either an int or a float.
Numeric = Union[int, float]
@total_ordering
class NumericRange:
    """A closed numeric interval ``low..high`` stored in an AugmentedIntervalTree.

    Both endpoints are inclusive: two ranges that merely touch at an endpoint
    are reported as overlapping.  Besides its endpoints, each range carries
    ``highest_in_subtree``, a cached maximum that AugmentedIntervalTree keeps
    up to date for the tree node holding this range.
    """

    def __init__(self, low: Numeric, high: Numeric):
        """Create the range ``low..high``.

        Args:
            low: the inclusive lower endpoint.
            high: the inclusive upper endpoint.  Callers should pass
                ``low <= high``; this is not validated here.
        """
        self.low: Numeric = low
        self.high: Numeric = high
        # Starts at this range's own high; the tree raises it to the largest
        # high found anywhere in the subtree of the node holding this range.
        self.highest_in_subtree: Numeric = high

    def __lt__(self, other: object) -> bool:
        """Order ranges by ``low``, breaking ties on ``high``.

        Returns NotImplemented for non-NumericRange operands so Python raises
        the usual TypeError for unsupported comparisons.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        # Comparing ``low`` alone would disagree with __eq__ (which looks at
        # both endpoints): total_ordering would then derive a ``>`` that is
        # true in both directions for ranges sharing the same ``low``.
        return (self.low, self.high) < (other.low, other.high)

    def __eq__(self, other: object) -> bool:
        """Two ranges are equal when both endpoints match."""
        if not isinstance(other, NumericRange):
            return NotImplemented
        return self.low == other.low and self.high == other.high

    def overlaps_with(self, other: NumericRange) -> bool:
        """Return True if this range and ``other`` share at least one point."""
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"{self.low}..{self.high}"
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of NumericRange values that can report overlaps.

    Each node's ``value.highest_in_subtree`` is maintained at (or above) the
    largest ``high`` of any range stored in that node's subtree.
    ``find_overlaps`` uses it to skip subtrees that end before a query starts.
    """

    @staticmethod
    def assert_value_must_be_range(value: Any) -> None:
        """Raise TypeError unless ``value`` is a NumericRange.

        Raises:
            TypeError: ``value`` is not a NumericRange.
        """
        if not isinstance(value, NumericRange):
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )

    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Insert hook: validate ``new`` and push its ``high`` up the tree.

        Presumably invoked by bst.BinarySearchTree once ``new`` is linked in
        (TODO confirm).  Every node returned by ``parent_path(new)`` whose
        cached maximum is below ``new.value.high`` is raised to it.
        """
        AugmentedIntervalTree.assert_value_must_be_range(new.value)
        new_high = new.value.high
        for ancestor in self.parent_path(new):
            if ancestor is None:
                continue
            if new_high > ancestor.value.highest_in_subtree:
                ancestor.value.highest_in_subtree = new_high

    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Delete hook: repair ``highest_in_subtree`` after ``deleted`` leaves.

        ``parent`` is recomputed from the ranges actually beneath it (both of
        its children, not just the removed node's), then every node on
        ``parent_path(parent)`` is recomputed, deepest first, so no cached
        maximum is left below the true value.  An underestimate would make
        ``find_overlaps`` prune subtrees that still hold overlapping ranges.
        """
        if parent is None:
            # Nothing sits above the removed node, so no cached maximum
            # refers to it.
            return
        path = [node for node in self.parent_path(parent) if node is not None]
        # Each node must be recomputed after its children, so process the path
        # deepest-first.  The order parent_path yields in is not assumed: if it
        # starts at the root, flip it.
        if path and path[0] is self.root:
            path.reverse()
        if not any(node is parent for node in path):
            path.insert(0, parent)
        for node in path:
            node.value.highest_in_subtree = self._highest_below(node, deleted)

    @staticmethod
    def _highest_below(node: bst.Node, deleted: bst.Node) -> Numeric:
        """Return the largest ``high`` in ``node``'s subtree from cached child maxima.

        If ``node`` still links to ``deleted`` (in case the hook fires before
        the node is unlinked; NOTE(review): confirm against bst), ``deleted``
        is skipped and its children are counted in its place.
        """
        children = []
        for child in (node.left, node.right):
            if child is deleted:
                children.extend((deleted.left, deleted.right))
            else:
                children.append(child)
        return max(
            [node.value.high]
            + [c.value.highest_in_subtree for c in children if c is not None]
        )

    def find_overlaps(self, x: NumericRange):
        """Yields ranges previously added to the tree that x overlaps with.

        Touching at an endpoint counts as overlapping.  Results come out in
        pre-order: a node before its left subtree, and the left subtree before
        the right one.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        1..30
        16..28
        21..27
        """
        if self.root is None:
            return
        yield from self._find_overlaps(self.root, x)

    def _find_overlaps(self, root: bst.Node, x: NumericRange):
        """Yield the ranges in ``root``'s subtree that overlap ``x``, in pre-order.

        Intervals A and B overlap exactly when A.low <= B.high and
        A.high >= B.low.  So the search can skip:

        * the right subtree of any node whose low is past the end of ``x``
          (ranges there start no earlier than that node), and
        * any left subtree whose highest_in_subtree is below the start of ``x``
          (nothing in it reaches ``x``).

        An explicit stack replaces recursion, so deep trees cannot exceed
        Python's recursion limit.
        """
        pending = [root]
        while pending:
            node = pending.pop()
            if node.value.overlaps_with(x):
                yield node.value
            # Push right before left: the stack then drains the whole left
            # subtree before the right one, matching recursive pre-order.
            if node.right is not None and node.value.low <= x.high:
                pending.append(node.right)
            if (
                node.left is not None
                and node.left.value.highest_in_subtree >= x.low
            ):
                pending.append(node.left)
134 if __name__ == "__main__":