Adds Graph.
[pyutils.git] / src / pyutils / collectionz / interval_tree.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is an augmented interval tree for storing ranges and identifying overlaps as
6 described by: https://en.wikipedia.org/wiki/Interval_tree.
7 """
8
9 from __future__ import annotations
10
11 from functools import total_ordering
12 from typing import Any, Generator, Optional
13
14 from overrides import overrides
15
16 from pyutils.collectionz import bst
17 from pyutils.types.simple import Numeric
18
19
@total_ordering
class NumericRange:
    """Essentially a tuple of numbers denoting a closed range ``[low, high]``
    with some added helper methods on it.

    Ranges compare by ``low`` first and ``high`` second; ``__eq__`` and
    ``__lt__`` are written out below and the remaining rich comparisons are
    supplied by :func:`functools.total_ordering`.
    """

    def __init__(self, low: Numeric, high: Numeric):
        """Creates a NumericRange.

        Args:
            low: the lowest point in the range (inclusive).
            high: the highest point in the range (inclusive).

        .. warning::

            If low > high this code swaps the parameters and keeps the range
            rather than raising.

        """
        if low > high:
            low, high = high, low
        self.low: Numeric = low
        self.high: Numeric = high
        # Bookkeeping slot updated by AugmentedIntervalTree while this range
        # is stored in a tree; a standalone range simply reports its own high.
        self.highest_in_subtree: Numeric = high

    def __lt__(self, other: NumericRange) -> bool:
        """
        Returns:
            True if this range is less than other (lower low or, when the
            lows are equal, lower high), else False.
        """
        if self.low != other.low:
            return self.low < other.low
        return self.high < other.high

    @overrides
    def __eq__(self, other: object) -> bool:
        """
        Returns:
            True if this is the same range as other (same low and same
            high), else False.  For an operand that is not a NumericRange
            this returns NotImplemented so that Python can try the other
            operand's comparison before concluding the two are unequal.
        """
        if not isinstance(other, NumericRange):
            return NotImplemented
        return self.low == other.low and self.high == other.high

    def __hash__(self) -> int:
        """
        Returns:
            A hash built from (low, high) so that ranges which compare equal
            also hash equal.  Defining __eq__ alone would leave instances
            unhashable.  The bounds are ordinary mutable attributes: do not
            change them while the range is a member of a set or a dict key.
        """
        return hash((self.low, self.high))

    def overlaps_with(self, other: NumericRange) -> bool:
        """
        Returns:
            True if this NumericRange overlaps with other, else False.
            Both ends are inclusive, so ranges that merely touch at an
            endpoint are considered overlapping.
        """
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"{self.low}..{self.high}"
75
76
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of :class:`NumericRange` values augmented for
    overlap queries.

    Ranges are ordered by ``low`` and then ``high`` (see
    ``NumericRange.__lt__``).  In addition, the range stored at every node
    carries ``highest_in_subtree``: the largest ``high`` of any range in the
    subtree rooted at that node.  The insert / delete hooks below keep that
    value current and the overlap searches use it to skip subtrees that end
    before the query interval begins.
    """

    @staticmethod
    def _assert_value_must_be_range(value: Any) -> None:
        """Rejects any value that is not a NumericRange.

        Raises:
            TypeError: if value is not a NumericRange.  TypeError derives
                from Exception, so handlers catching Exception still apply.
        """
        if not isinstance(value, NumericRange):
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )

    @staticmethod
    def _recompute_highest(node: bst.Node) -> None:
        """Re-derives node's highest_in_subtree from its own range and the
        values currently recorded on its children."""
        highest = node.value.high
        if node.left:
            highest = max(highest, node.left.value.highest_in_subtree)
        if node.right:
            highest = max(highest, node.right.value.highest_in_subtree)
        node.value.highest_in_subtree = highest

    @overrides
    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Insert hook: validates the new value and raises the recorded
        subtree maximum on every node along the new node's parent path whose
        current maximum is below the new range's high."""
        AugmentedIntervalTree._assert_value_must_be_range(new.value)
        for ancestor in self.parent_path(new):
            assert ancestor
            if new.value.high > ancestor.value.highest_in_subtree:
                ancestor.value.highest_in_subtree = new.value.high

    @overrides
    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Delete hook: refreshes highest_in_subtree on the removed node's
        parent and on every node along that parent's path.

        Refreshing only the immediate parent leaves stale (too large) maxima
        further up; _find_one_overlap would then commit to a left subtree
        that cannot contain a match and never examine the right one.
        """
        if not parent:
            return

        # Always fix the immediate parent, even if the path lookup below
        # were to end on a different node holding an equal range.
        AugmentedIntervalTree._recompute_highest(parent)

        lineage = [node for node in self.parent_path(parent) if node]
        # Each node must be recomputed after the child below it, i.e.
        # deepest first.  The ordering of parent_path is not visible from
        # this file, so flip the list only when it starts at the root.
        if lineage and lineage[0] is self.root:
            lineage.reverse()
        for ancestor in lineage:
            AugmentedIntervalTree._recompute_highest(ancestor)

    def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
        """Identify and return one overlapping node from the tree.

        Args:
            to_find: the interval with which to find an overlap.

        Returns:
            An overlapping range from the tree or None if no such
            ranges exist in the tree at present.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> tree.find_one_overlap(NumericRange(6, 7))
        1..30

        Deleting a range must not hide overlaps elsewhere in the tree:

        >>> del tree[NumericRange(1, 30)]
        >>> tree.find_one_overlap(NumericRange(29, 29))
        25..30

        """
        return self._find_one_overlap(self.root, to_find)

    def _find_one_overlap(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Optional[NumericRange]:
        """Walks down from root following a single path and returns the first
        stored range that overlaps x, or None when the walk falls off the
        tree.  Written as a loop so a deep, unbalanced tree cannot exhaust
        the recursion limit."""
        node = root
        while node is not None:
            if node.value.overlaps_with(x):
                return node.value

            # Some range on the left reaches at least x.low: descend left
            # and do not come back for the right subtree.  Otherwise nothing
            # on the left can overlap, so the right subtree is the only
            # remaining candidate.
            if node.left and node.left.value.highest_in_subtree >= x.low:
                node = node.left
            else:
                node = node.right
        return None

    def find_all_overlaps(
        self, to_find: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Yields ranges previously added to the tree that overlaps with
        to_find argument.

        Args:
            to_find: the interval with which to find all overlaps.

        Returns:
            A (potentially empty) sequence of all ranges in the tree
            that overlap with the argument.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        1..30
        16..28
        21..27

        >>> del tree[NumericRange(1, 30)]
        >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        16..28
        21..27

        """
        if self.root is None:
            return
        yield from self._find_all_overlaps(self.root, to_find)

    def _find_all_overlaps(
        self, root: Optional[bst.Node], x: NumericRange
    ) -> Generator[NumericRange, None, None]:
        """Pre-order walk (node, left, right) yielding every range in the
        subtree under root that overlaps x, skipping subtrees that cannot
        contain a match."""
        if root is None:
            return

        if root.value.overlaps_with(x):
            yield root.value

        # Left subtree: only worth visiting if something in it ends at or
        # after the start of x.
        if root.left and root.left.value.highest_in_subtree >= x.low:
            yield from self._find_all_overlaps(root.left, x)

        # Right subtree: ranges there sort after this node's range, so none
        # of them starts before root.value.low.  If this node already starts
        # past the end of x, nothing to the right can overlap.
        if (
            root.right
            and root.value.low <= x.high
            and root.right.value.highest_in_subtree >= x.low
        ):
            yield from self._find_all_overlaps(root.right, x)
211
212
if __name__ == "__main__":
    # Executed directly as a script: run the doctests embedded in this
    # module's docstrings.
    from doctest import testmod

    testmod()