Tiebreak ordering of ranges with the same lower bound using upper bound.
[pyutils.git] / src / pyutils / collectionz / interval_tree.py
1 #!/usr/bin/env python3
2
3 """This is an augmented interval tree for storing ranges and identifying overlaps as
4 described by: https://en.wikipedia.org/wiki/Interval_tree.
5 """
6
7 from __future__ import annotations
8
9 from functools import total_ordering
10 from typing import Any, Generator, Optional, Union
11
12 from overrides import overrides
13
14 from pyutils.collectionz import bst
15
# Endpoint type used by NumericRange for low / high / highest_in_subtree.
Numeric = Union[int, float]
17
18
@total_ordering
class NumericRange(object):
    """An inclusive numeric interval ``[low, high]`` plus a few helpers.

    Ranges sort by ``low`` first and, when two ranges share the same
    ``low``, by ``high``.  Equality requires both endpoints to match.
    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable.

    >>> NumericRange(5, 1)
    1..5
    >>> NumericRange(1, 5) == NumericRange(5, 1)
    True
    >>> NumericRange(1, 3) < NumericRange(1, 5) < NumericRange(2, 2)
    True
    >>> sorted([NumericRange(2, 2), NumericRange(1, 5), NumericRange(1, 3)])
    [1..3, 1..5, 2..2]
    >>> NumericRange(1, 2) < 3
    Traceback (most recent call last):
    ...
    TypeError: '<' not supported between instances of 'NumericRange' and 'int'
    """

    def __init__(self, low: Numeric, high: Numeric):
        """Creates a NumericRange.

        Args:
            low: the lowest point in the range (inclusive).
            high: the highest point in the range (inclusive).

        .. warning::

            If low > high this code swaps the parameters and keeps the range
            rather than raising.

        """
        if low > high:
            # Normalize a reversed range instead of rejecting it.
            low, high = high, low
        self.low: Numeric = low
        self.high: Numeric = high
        # Largest ``high`` in the subtree of the AugmentedIntervalTree node
        # that holds this range; the tree keeps it up to date.  For a range
        # that is not in a tree it is simply this range's own ``high``.
        self.highest_in_subtree: Numeric = high

    def __lt__(self, other: object) -> bool:
        """
        Returns:
            True if this range sorts before other, i.e. it has a lower
            ``low``, or the same ``low`` and a lower ``high``; else False.
            NotImplemented when other is not a NumericRange, so that
            comparing against an unrelated type raises TypeError.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        if self.low != other.low:
            return self.low < other.low
        # Same lower bound: break the tie on the upper bound.
        return self.high < other.high

    @overrides
    def __eq__(self, other: object) -> bool:
        """
        Returns:
            True if this is the same range as other (both endpoints equal),
            else False.  Non-NumericRange objects are never equal.
        """
        if not isinstance(other, NumericRange):
            return False
        return self.low == other.low and self.high == other.high

    def overlaps_with(self, other: NumericRange) -> bool:
        """
        Returns:
            True if this NumericRange overlaps with other, else False.
            Endpoints are inclusive, so ranges that merely touch
            (e.g. 1..5 and 5..9) count as overlapping.
        """
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"{self.low}..{self.high}"
74
75
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of NumericRanges augmented for overlap queries.

    Every stored range's ``highest_in_subtree`` is kept equal to the largest
    ``high`` of any range in the subtree rooted at the node holding it.  The
    overlap searches use that value to skip subtrees that cannot contain an
    overlapping range, so ``_on_insert`` / ``_on_delete`` must keep it exact
    after every change to the tree.
    """

    @staticmethod
    def _assert_value_must_be_range(value: Any) -> None:
        """Reject values that are not NumericRanges.

        Raises:
            TypeError: if value is not a NumericRange.
        """
        if not isinstance(value, NumericRange):
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )

    @staticmethod
    def _recompute_highest(node: bst.Node) -> None:
        """Set node's highest_in_subtree from its own ``high`` and its
        children's highest_in_subtree (which must already be correct)."""
        candidates = [node.value.high]
        if node.left:
            candidates.append(node.left.value.highest_in_subtree)
        if node.right:
            candidates.append(node.right.value.highest_in_subtree)
        node.value.highest_in_subtree = max(candidates)

    @overrides
    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Validate the newly added range and raise its ancestors' maxima.

        NOTE(review): validation happens inside this hook; if bst links the
        node before invoking it (verify in bst), a rejected value is left in
        the tree.
        """
        AugmentedIntervalTree._assert_value_must_be_range(new.value)
        # A range object may carry a stale maximum from an earlier stint in
        # a tree (e.g. deleted and re-inserted), so derive it afresh.
        self._recompute_highest(new)
        highest = new.value.highest_in_subtree
        for ancestor in self.parent_path(new):
            assert ancestor
            if highest > ancestor.value.highest_in_subtree:
                ancestor.value.highest_in_subtree = highest

    @overrides
    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Repair highest_in_subtree after ``deleted`` is unlinked below parent.

        Removing a range can lower the maximum of every subtree that held
        it: ``parent`` *and all of parent's ancestors*.  Leaving those stale
        makes find_one_overlap descend into a subtree that no longer reaches
        the query and miss overlaps elsewhere, so every node on the path
        from the root to ``parent`` is recomputed, children before parents.

        NOTE(review): if bst relocates a range into an interior node while
        deleting (presumably the two-children case - verify in bst), that
        node is assumed to be an ancestor of ``parent`` and is repaired by
        the same walk.
        """
        if parent is None:
            return
        # Don't depend on the order parent_path returns nodes in: gather the
        # path as a set of node ids and walk it from the root instead.
        on_path = {
            id(node) for node in self.parent_path(parent) if node is not None
        }
        on_path.add(id(parent))
        self._recompute_path(self.root, on_path)

    def _recompute_path(self, node: Optional[bst.Node], on_path: set[int]) -> None:
        """Post-order recompute of highest_in_subtree that descends only
        through nodes whose id() is in on_path (cost: path length)."""
        if node is None or id(node) not in on_path:
            return
        self._recompute_path(node.left, on_path)
        self._recompute_path(node.right, on_path)
        self._recompute_highest(node)

    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
        """Identify and return one overlapping range from the tree.

        Args:
            to_find: the interval with which to find an overlap.

        Returns:
            An overlapping range from the tree or None if no such
            ranges exist in the tree at present.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> tree.find_one_overlap(NumericRange(6, 7))
        1..30

        Subtree maxima must stay exact after a deletion, otherwise the
        search can commit to the wrong subtree:

        >>> del tree[NumericRange(1, 30)]
        >>> tree.find_one_overlap(NumericRange(29, 29))
        25..30
        >>> tree.find_one_overlap(NumericRange(100, 200)) is None
        True

        """
        return self._find_one_overlap(self.root, to_find)

    def _find_one_overlap(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Optional[NumericRange]:
        """Walk down from root looking for any range overlapping x.

        Ranges are ordered by ``low`` first.  If the left subtree's maximum
        reaches x.low, then either that subtree holds an overlap or nothing
        to the right can overlap (everything there starts even later), so
        only one branch ever needs to be explored.
        """
        node = root
        while node is not None:
            if node.value.overlaps_with(x):
                return node.value
            if node.left and node.left.value.highest_in_subtree >= x.low:
                node = node.left
            else:
                node = node.right
        return None

    def find_all_overlaps(
        self, to_find: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yields ranges previously added to the tree that overlap with
        the to_find argument.

        Args:
            to_find: the interval with which to find all overlaps.

        Yields:
            A (potentially empty) sequence of all ranges in the tree
            that overlap with the argument, in pre-order.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        1..30
        16..28
        21..27

        >>> del tree[NumericRange(1, 30)]
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        16..28
        21..27

        """
        if self.root is None:
            return
        yield from self._find_all_overlaps(self.root, to_find)

    def _find_all_overlaps(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Pre-order search yielding every range under root that overlaps x."""
        if root is None:
            return
        if root.value.overlaps_with(x):
            yield root.value
        # Left subtree: only worth visiting if something in it reaches x.low.
        if root.left and root.left.value.highest_in_subtree >= x.low:
            yield from self._find_all_overlaps(root.left, x)
        # Right subtree: ranges are ordered by low first, so everything there
        # starts at or after root; if root already starts past x.high, nothing
        # on the right can overlap either.
        if (
            root.right
            and root.value.low <= x.high
            and root.right.value.highest_in_subtree >= x.low
        ):
            yield from self._find_all_overlaps(root.right, x)
208             if root.right.value.highest_in_subtree >= x.low:
209                 yield from self._find_all_overlaps(root.right, x)
210
211
if __name__ == "__main__":
    # Running this module directly executes its embedded doctests.
    from doctest import testmod

    testmod()