More docs.
[pyutils.git] / src / pyutils / collectionz / interval_tree.py
1 #!/usr/bin/env python3
2
3 """This is an augmented interval tree for storing ranges and identifying overlaps as
4 described by: https://en.wikipedia.org/wiki/Interval_tree.
5 """
6
7 from __future__ import annotations
8
9 from functools import total_ordering
10 from typing import Any, Generator, Optional, Union
11
12 from overrides import overrides
13
14 from pyutils.collectionz import bst
15
# A plain number usable as a range endpoint.
Numeric = Union[int, float]


@total_ordering
class NumericRange(object):
    """Essentially a tuple of numbers denoting a range with some added
    helper methods on it.

    Ranges sort by their ``low`` endpoint, with ``high`` breaking ties, so
    that ordering agrees with ``==``: two ranges are equal exactly when
    neither is less than the other.

    >>> NumericRange(5, 1)
    1..5
    >>> a, b = NumericRange(1, 3), NumericRange(1, 5)
    >>> (a < b, b < a, a > b, b > a)
    (True, False, False, True)
    >>> NumericRange(1, 5).overlaps_with(NumericRange(5, 9))
    True
    """

    def __init__(self, low: Numeric, high: Numeric):
        """Creates a NumericRange.

        Args:
            low: the lowest point in the range (inclusive).
            high: the highest point in the range (inclusive).

        .. warning::

            If low > high this code swaps the parameters and keeps the range
            rather than raising.

        """
        if low > high:
            low, high = high, low
        self.low: Numeric = low
        self.high: Numeric = high
        # Largest ``high`` in the subtree rooted at this range while it is
        # stored in an AugmentedIntervalTree; initially just this range's own.
        self.highest_in_subtree: Numeric = high

    def __lt__(self, other: object) -> bool:
        """Order by ``low`` first, then by ``high``.

        Comparing ``low`` alone left ranges that share a ``low`` but differ
        in ``high`` neither equal nor less than each other.  The ``>``
        derived by ``total_ordering`` then reported each as greater than the
        other.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        return (self.low, self.high) < (other.low, other.high)

    @overrides
    def __eq__(self, other: object) -> bool:
        """Two ranges are equal when both of their endpoints match."""
        if not isinstance(other, NumericRange):
            # Let Python try the reflected comparison / identity fallback.
            return NotImplemented
        return self.low == other.low and self.high == other.high

    def overlaps_with(self, other: NumericRange) -> bool:
        """Return True if this range and ``other`` share at least one point.

        Endpoints are inclusive, so ranges that merely touch overlap.
        """
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"{self.low}..{self.high}"
59
60
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of :class:`NumericRange` values that answers
    overlap queries.

    Each stored range caches, in ``highest_in_subtree``, the largest
    ``high`` endpoint found in the subtree rooted at its node.  The
    ``_on_insert`` / ``_on_delete`` hooks keep that cache current.  The
    overlap searches use it to skip subtrees that cannot contain a match.
    """

    @staticmethod
    def _assert_value_must_be_range(value: Any) -> None:
        """Raise unless ``value`` is a NumericRange.

        Raises:
            TypeError: ``value`` is not a NumericRange.  TypeError derives
                from Exception, so callers catching Exception still work.
        """
        if not isinstance(value, NumericRange):
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )

    @staticmethod
    def _recompute_highest_in_subtree(node: bst.Node) -> None:
        """Set ``node``'s cached subtree maximum from its own ``high`` and
        its children's cached maxima."""
        candidates = [node.value.high]
        if node.left:
            candidates.append(node.left.value.highest_in_subtree)
        if node.right:
            candidates.append(node.right.value.highest_in_subtree)
        node.value.highest_in_subtree = max(candidates)

    @overrides
    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Insert hook: validate the value and raise cached maxima above it.

        Raises:
            TypeError: the inserted value is not a NumericRange.  NOTE(review):
                this hook receives an already-created node, so the value may
                already be linked into the tree when this fires -- confirm
                in bst.
        """
        AugmentedIntervalTree._assert_value_must_be_range(new.value)
        # The range object may have been in a tree before (e.g. deleted and
        # re-inserted), so its cached maximum could be stale.  Rebuild it
        # from this node's actual children.
        AugmentedIntervalTree._recompute_highest_in_subtree(new)
        for ancestor in self.parent_path(new):
            assert ancestor
            if new.value.high > ancestor.value.highest_in_subtree:
                ancestor.value.highest_in_subtree = new.value.high

    @overrides
    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Delete hook: refresh cached maxima above the removed node.

        Removing a node can lower the cached maximum of ``parent`` and of
        *every* node above it, not just ``parent``.  A stale, too-high value
        on a left child makes find_one_overlap descend left and miss an
        overlap on the right.  So refresh the whole path from ``parent`` up
        to the root, children before their parents.
        """
        if not parent:
            return
        ancestors = [
            node
            for node in self.parent_path(parent)
            if node is not None and node is not parent
        ]
        # parent_path's ordering isn't visible from here; normalize it to
        # bottom-up (nearest ancestor first) so each node is recomputed only
        # after its child on the path.
        if ancestors and ancestors[0] is self.root:
            ancestors.reverse()
        AugmentedIntervalTree._recompute_highest_in_subtree(parent)
        for ancestor in ancestors:
            AugmentedIntervalTree._recompute_highest_in_subtree(ancestor)

    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
        """Identify and return one overlapping node from the tree.

        Args:
            to_find: the interval with which to find an overlap.

        Returns:
            An overlapping range from the tree or None if no such
            ranges exist in the tree at present.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> tree.find_one_overlap(NumericRange(6, 7))
        1..30

        Deleting a deep range refreshes the cached maxima all the way up,
        so an overlap on the right side is still found:

        >>> tree = AugmentedIntervalTree()
        >>> for low, high in [(10, 10), (5, 6), (3, 4), (1, 50), (15, 25)]:
        ...     tree.insert(NumericRange(low, high))
        >>> del tree[NumericRange(1, 50)]
        >>> tree.find_one_overlap(NumericRange(20, 20))
        15..25

        """
        return self._find_one_overlap(self.root, to_find)

    def _find_one_overlap(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Optional[NumericRange]:
        """Return one range overlapping ``x`` from the subtree at ``root``,
        or None.

        Walks a single root-to-leaf path iteratively, so a degenerate
        (list-shaped) tree deeper than the recursion limit can be searched.
        """
        node = root
        while node is not None:
            if node.value.overlaps_with(x):
                return node.value
            # Standard interval-tree step: if the left subtree reaches x.low,
            # either it holds an overlap or, since ranges are ordered by low,
            # nothing to the right can; so only one side is ever searched.
            if node.left and node.left.value.highest_in_subtree >= x.low:
                node = node.left
            else:
                node = node.right
        return None

    def find_all_overlaps(
        self, to_find: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yields ranges previously added to the tree that overlaps with
        to_find argument.

        Args:
            to_find: the interval with which to find all overlaps.

        Yields:
            Each range in the tree that overlaps with the argument (possibly
            none), in pre-order: a node, then its left subtree, then its
            right subtree.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        1..30
        16..28
        21..27

        >>> del tree[NumericRange(1, 30)]
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        16..28
        21..27

        """
        if self.root is None:
            return
        yield from self._find_all_overlaps(self.root, to_find)

    def _find_all_overlaps(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yield every range in the subtree at ``root`` that overlaps ``x``,
        in pre-order.

        Uses an explicit stack instead of recursion so deep, unbalanced
        trees don't hit the recursion limit.
        """
        pending = [] if root is None else [root]
        while pending:
            node = pending.pop()
            if node.value.overlaps_with(x):
                yield node.value
            # Push right before left so the left subtree is drained first,
            # preserving pre-order.  A subtree is only worth visiting when
            # its largest high endpoint reaches x.low.  The right subtree can
            # also be skipped when this node starts after x ends: ranges are
            # ordered by low, so everything to the right starts no earlier.
            if (
                node.right
                and node.right.value.highest_in_subtree >= x.low
                and node.value.low <= x.high
            ):
                pending.append(node.right)
            if node.left and node.left.value.highest_in_subtree >= x.low:
                pending.append(node.left)
195
196
if __name__ == "__main__":
    # Executed as a script: run the examples embedded in this module's
    # docstrings as doctests.
    from doctest import testmod

    testmod()