Merge simple and typing.
[pyutils.git] / src / pyutils / collectionz / interval_tree.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """This is an augmented interval tree for storing ranges and identifying overlaps as
6 described by: https://en.wikipedia.org/wiki/Interval_tree.
7 """
8
9 from __future__ import annotations
10
11 from functools import total_ordering
12 from typing import Any, Generator, Optional
13
14 from overrides import overrides
15
16 from pyutils.collectionz import bst
17 from pyutils.typez.typing import Numeric
18
19
@total_ordering
class NumericRange(bst.Comparable):
    """Essentially a tuple of numbers denoting a closed range
    ``[low..high]`` with some added helper methods on it.

    Ranges are ordered first by ``low`` and then, for equal ``low`` values,
    by ``high``.  Comparing against an object that is not a NumericRange
    returns ``NotImplemented`` so Python applies its default handling
    (typically: ``==`` is False and ordering operators raise TypeError).
    """

    # Defining __eq__ without __hash__ already leaves instances unhashable;
    # spell it out because low/high/highest_in_subtree are mutable attributes,
    # so a value-based hash would not be stable.
    __hash__ = None  # type: ignore[assignment]

    def __init__(self, low: Numeric, high: Numeric):
        """Creates a NumericRange.

        Args:
            low: the lowest point in the range (inclusive).
            high: the highest point in the range (inclusive).

        .. warning::

            If low > high this code swaps the parameters and keeps the range
            rather than raising.

        """
        if low > high:
            low, high = high, low
        self.low: Numeric = low
        self.high: Numeric = high
        # Augmentation slot used by AugmentedIntervalTree: the largest
        # ``high`` in the subtree rooted at the node holding this range.  A
        # lone range's "subtree" is just itself, hence ``high``.
        self.highest_in_subtree: Numeric = high

    @overrides
    def __lt__(self, other: object) -> bool:
        """
        Returns:
            True if this range sorts before other (a lower ``low``, or the
            same ``low`` and a lower ``high``), else False.  NotImplemented
            if other is not a NumericRange.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        if self.low != other.low:
            return self.low < other.low
        return self.high < other.high

    @overrides
    def __eq__(self, other: object) -> bool:
        """
        Returns:
            True if other is a NumericRange with the same endpoints, else
            False.  NotImplemented if other is not a NumericRange.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        return self.low == other.low and self.high == other.high

    @overrides
    def __le__(self, other: object) -> bool:
        """
        Returns:
            True if this range sorts before, or is equal to, other; else
            False.  NotImplemented if other is not a NumericRange.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        return self < other or self == other

    def overlaps_with(self, other: NumericRange) -> bool:
        """
        Returns:
            True if this NumericRange and other share at least one point
            (both are closed, so touching endpoints count), else False.
        """
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"[{self.low}..{self.high}]"
80
81
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of :class:`NumericRange` values augmented to
    answer overlap queries (see https://en.wikipedia.org/wiki/Interval_tree).

    Each stored range's ``highest_in_subtree`` is maintained by the
    ``_on_insert`` / ``_on_delete`` hooks as the largest ``high`` in the
    subtree rooted at that range's node.  The overlap searches use it to
    skip subtrees.  An inflated value is not harmless: ``find_one_overlap``
    commits to a left subtree whenever its bound reaches the query, so the
    hooks recompute bounds exactly rather than merely keeping upper bounds.
    """

    @staticmethod
    def _assert_value_must_be_range(value: Any) -> NumericRange:
        """Return value, typed as a NumericRange.

        Raises:
            TypeError: if value is not a NumericRange.
        """
        if not isinstance(value, NumericRange):
            # TypeError names the problem precisely and, being an Exception
            # subclass, is still caught by broad ``except Exception`` handlers.
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )
        return value

    def _recompute_highest_in_subtree(self, node: bst.Node) -> Numeric:
        """Reset node's ``highest_in_subtree`` from its own ``high`` and its
        children's (assumed up-to-date) bounds; return the new bound."""
        value = AugmentedIntervalTree._assert_value_must_be_range(node.value)
        highest = value.high
        for child in (node.left, node.right):
            if child:
                child_value = AugmentedIntervalTree._assert_value_must_be_range(
                    child.value
                )
                if child_value.highest_in_subtree > highest:
                    highest = child_value.highest_in_subtree
        value.highest_in_subtree = highest
        return highest

    @overrides
    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Overrides the bst.BinarySearchTree insertion hook.

        Derives ``new``'s own bound from its children, then raises the bound
        of every node returned by ``self.parent_path(new)`` that is below it.

        Args:
            parent: unused here.
            new: the node holding the newly inserted range.
        """
        # The range object may still carry a bound from an earlier placement
        # (e.g. it was deleted and re-inserted); an inflated bound would
        # mislead find_one_overlap, so derive it afresh from new's children.
        new_highest = self._recompute_highest_in_subtree(new)
        for ancestor in self.parent_path(new):
            assert ancestor
            av = AugmentedIntervalTree._assert_value_must_be_range(ancestor.value)
            if new_highest > av.highest_in_subtree:
                av.highest_in_subtree = new_highest

    @overrides
    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Overrides the bst.BinarySearchTree deletion hook.

        Recomputes ``highest_in_subtree`` for ``parent`` and then for every
        node on its parent path, child-before-parent, so no ancestor keeps a
        bound that was contributed by the removed range.

        Args:
            parent: presumably the node whose children changed; None means
                there is nothing to fix up.
            deleted: unused here.
        """
        if not parent:
            return
        self._recompute_highest_in_subtree(parent)

        # Bounds must be recomputed bottom-up.  The ordering of parent_path
        # is not fixed by this class, so orient it by checking which end is
        # the root.  Recomputing any node from exact children is safe.
        path = [node for node in self.parent_path(parent) if node]
        if path and path[0] is self.root:
            path.reverse()
        for node in path:
            if node is not parent:
                self._recompute_highest_in_subtree(node)

    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
        """Identify and return one overlapping node from the tree.

        Args:
            to_find: the interval with which to find an overlap.

        Returns:
            An overlapping range from the tree or None if no such
            ranges exist in the tree at present.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> tree.find_one_overlap(NumericRange(6, 7))
        [1..30]

        """
        return self._find_one_overlap(self.root, to_find)

    def _find_one_overlap(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Optional[NumericRange]:
        """Return one range in the subtree at ``root`` overlapping x, or None.

        Written as a loop so a deep tree (e.g. built from sorted inserts, if
        the base tree does not rebalance) cannot exhaust the recursion limit.
        """
        node = root
        while node is not None:
            value = AugmentedIntervalTree._assert_value_must_be_range(node.value)
            if value.overlaps_with(x):
                return value
            if node.left:
                lv = AugmentedIntervalTree._assert_value_must_be_range(
                    node.left.value
                )
                if lv.highest_in_subtree >= x.low:
                    # Some left range ends at/after x.low.  If none of them
                    # overlaps x, that range starts after x.high, and so does
                    # everything to the right (ordered by low), so the left
                    # subtree is the only place worth searching.
                    node = node.left
                    continue
            if value.low > x.high:
                # Right-subtree ranges start at or after value.low, i.e.
                # after x ends, so none of them can overlap.
                return None
            node = node.right
        return None

    def find_all_overlaps(
        self, to_find: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yields ranges previously added to the tree that overlap with
        the to_find argument.

        Args:
            to_find: the interval with which to find all overlaps.

        Yields:
            A (potentially empty) sequence of all ranges in the tree
            that overlap with the argument, in pre-order (node, then its
            left subtree, then its right subtree).

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        [20..24]
        [18..22]
        [1..30]
        [16..28]
        [21..27]

        >>> del tree[NumericRange(1, 30)]
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        [20..24]
        [18..22]
        [16..28]
        [21..27]

        """
        yield from self._find_all_overlaps(self.root, to_find)

    def _find_all_overlaps(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yield every range in the subtree at ``root`` that overlaps x.

        Uses an explicit stack instead of nested generators so deep trees
        cannot exhaust the recursion limit.
        """
        if root is None:
            return
        as_range = AugmentedIntervalTree._assert_value_must_be_range
        stack = [root]
        while stack:
            node = stack.pop()
            value = as_range(node.value)
            if value.overlaps_with(x):
                yield value

            # Push right before left so the whole left subtree is emitted
            # before the right one: the same order as a recursive pre-order
            # walk.
            #
            # Right-subtree ranges start at or after value.low, so when
            # value.low > x.high none of them can overlap x.
            if node.right and value.low <= x.high:
                right_value = as_range(node.right.value)
                if right_value.highest_in_subtree >= x.low:
                    stack.append(node.right)
            if node.left:
                left_value = as_range(node.left.value)
                if left_value.highest_in_subtree >= x.low:
                    stack.append(node.left)
234
235
if __name__ == "__main__":
    # When executed directly as a script, run the doctests embedded in this
    # module's docstrings.
    from doctest import testmod

    testmod()