Update docs again.
[pyutils.git] / src / pyutils / collectionz / interval_tree.py
1 #!/usr/bin/env python3
2
3 """This is an augmented interval tree for storing ranges and identifying overlaps as
4 described by: https://en.wikipedia.org/wiki/Interval_tree.
5 """
6
7 from __future__ import annotations
8
9 from functools import total_ordering
10 from typing import Any, Generator, Optional, Union
11
12 from overrides import overrides
13
14 from pyutils.collectionz import bst
15
# Endpoint type accepted by NumericRange: either an int or a float.
Numeric = Union[int, float]
17
18
@total_ordering
class NumericRange:
    """A closed numeric interval ``[low, high]`` with some helper methods.

    The endpoints may be passed in either order; they are normalized so that
    ``low <= high``.  Instances also carry ``highest_in_subtree``, bookkeeping
    that AugmentedIntervalTree writes while the range is stored in a tree.
    Because that value lives on the range object itself, one instance should
    not be stored in more than one tree at a time.

    >>> NumericRange(5, 1)
    1..5
    >>> NumericRange(1, 3) < NumericRange(1, 5) < NumericRange(2, 2)
    True
    """

    def __init__(self, low: Numeric, high: Numeric):
        # Accept the endpoints in either order.
        if low > high:
            low, high = high, low
        self.low: Numeric = low
        self.high: Numeric = high
        # Largest ``high`` within the tree subtree rooted at the node holding
        # this range.  A range on its own is its own subtree.
        self.highest_in_subtree: Numeric = high

    def __lt__(self, other: object) -> bool:
        """Order ranges by ``low``, breaking ties on ``high``.

        The tie-break keeps ordering consistent with ``__eq__``.  Without it
        (ordering by ``low`` alone) the comparisons synthesized by
        ``total_ordering`` contradicted each other: two ranges sharing a
        ``low`` were each reported as ``>`` the other.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        return (self.low, self.high) < (other.low, other.high)

    def __eq__(self, other: object) -> bool:
        """Two ranges are equal when both endpoints match; non-ranges never are."""
        if not isinstance(other, NumericRange):
            return False
        return self.low == other.low and self.high == other.high

    def overlaps_with(self, other: NumericRange) -> bool:
        """Return True if the two closed ranges share at least one point.

        Ranges that merely touch at an endpoint count as overlapping.
        """
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"{self.low}..{self.high}"
47
48
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of NumericRange values that answers "which stored
    ranges overlap this one?" queries.

    On top of the ordinary BST ordering of the stored ranges (by ``low``, see
    ``NumericRange.__lt__``), every stored range's ``highest_in_subtree``
    attribute caches the largest ``high`` found in the subtree rooted at its
    node.  The ``_on_insert`` / ``_on_delete`` overrides keep that cache
    current, and the overlap searches use it to skip subtrees whose ranges
    all end before the query begins.

    Because the cache is written onto the NumericRange objects themselves, a
    given range instance should only be stored in one tree at a time.
    """

    @staticmethod
    def _assert_value_must_be_range(value: Any) -> None:
        """Raise ``TypeError`` (a subclass of ``Exception``) unless ``value``
        is a NumericRange."""
        if not isinstance(value, NumericRange):
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )

    @staticmethod
    def _refresh_highest_in_subtree(node: bst.Node) -> None:
        """Recompute ``node``'s cached subtree maximum from its own ``high``
        and its children's cached maxima (which must already be current)."""
        highest = node.value.high
        for child in (node.left, node.right):
            if child is not None and child.value.highest_in_subtree > highest:
                highest = child.value.highest_in_subtree
        node.value.highest_in_subtree = highest

    @overrides
    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Fold a newly added node's maximum into its ancestors' caches.

        Raises:
            TypeError: if the added value is not a NumericRange.
        """
        AugmentedIntervalTree._assert_value_must_be_range(new.value)
        # The range object may carry a stale cache from an earlier stay in a
        # tree, so recompute it for the node's current position first.
        self._refresh_highest_in_subtree(new)
        highest = new.value.highest_in_subtree
        for ancestor in self.parent_path(new):
            assert ancestor
            if highest > ancestor.value.highest_in_subtree:
                ancestor.value.highest_in_subtree = highest

    @overrides
    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Repair cached subtree maxima after a node has been removed.

        Removing a range can lower the maximum of *every* subtree that
        contained it, so ``parent`` and all of its ancestors are refreshed,
        not just ``parent``.  A stale, too-large maximum is not harmless:
        ``_find_one_overlap`` trusts a left child's maximum and commits to
        that side, so it could miss an overlap that exists on the right.

        Args:
            parent: presumably the parent of the node bst unlinked (None if
                it had no parent) -- confirm against bst.
            deleted: the removed node (not needed here).
        """
        if parent is None:
            return
        ancestors = [
            node
            for node in self.parent_path(parent)
            if node is not None and node is not parent
        ]
        # Each node must be recomputed after its children.  The ordering of
        # parent_path() is defined in bst, so normalize it to deepest-first
        # here rather than assume it.
        if ancestors and ancestors[0] is self.root:
            ancestors.reverse()
        for node in (parent, *ancestors):
            self._refresh_highest_in_subtree(node)

    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
        """Identify and return one stored range that overlaps ``to_find``.

        Args:
            to_find: the interval with which to find an overlap.

        Returns:
            An overlapping range from the tree or None if no such
            ranges exist in the tree at present.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> tree.find_one_overlap(NumericRange(6, 7))
        1..30

        """
        return self._find_one_overlap(self.root, to_find)

    def _find_one_overlap(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Optional[NumericRange]:
        """Return some range in the subtree at ``root`` that overlaps ``x``.

        Follows a single downward path (the classic interval-tree search),
        iteratively so tree depth is not bounded by the recursion limit.
        Returns None if no range in the subtree overlaps ``x``.
        """
        node = root
        while node is not None:
            if node.value.overlaps_with(x):
                return node.value
            left = node.left
            if left is not None and left.value.highest_in_subtree >= x.low:
                # Some range on the left ends at/after x.low.  If it does not
                # overlap x it must start after x.high, and every range to the
                # right starts no earlier than it -- so if the left side has
                # no overlap, the right side cannot have one either.
                node = left
            elif node.value.low > x.high:
                # Everything on the right starts at or after this node's low,
                # which is already past the end of x.
                return None
            else:
                node = node.right
        return None

    def find_all_overlaps(
        self, to_find: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yield every range previously added to the tree that overlaps
        ``to_find``.

        Args:
            to_find: the interval with which to find all overlaps.

        Returns:
            A (potentially empty) sequence of all ranges in the tree
            that overlap with the argument.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        1..30
        16..28
        21..27

        >>> del tree[NumericRange(1, 30)]
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        16..28
        21..27

        """
        yield from self._find_all_overlaps(self.root, to_find)

    def _find_all_overlaps(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yield every range in the subtree at ``root`` that overlaps ``x``.

        Visits nodes in pre-order (node, left subtree, right subtree) using an
        explicit stack, so tree depth is not bounded by the recursion limit
        and results are not relayed through a chain of nested generators.  A
        child subtree is skipped when its cached maximum ends before
        ``x.low``.  The right subtree is also skipped when this node starts
        after ``x.high``, since the tree orders ranges by ``low``.
        """
        pending = [root] if root is not None else []
        while pending:
            node = pending.pop()
            if node.value.overlaps_with(x):
                yield node.value
            # Push right before left so the whole left subtree is explored
            # first, reproducing a recursive pre-order walk.
            right = node.right
            if (
                right is not None
                and right.value.highest_in_subtree >= x.low
                and node.value.low <= x.high
            ):
                pending.append(right)
            left = node.left
            if left is not None and left.value.highest_in_subtree >= x.low:
                pending.append(left)
183
184
if __name__ == "__main__":
    # Executed as a script: run the doctests in this module's docstrings.
    from doctest import testmod

    testmod()