Adds IntervalTree.
[pyutils.git] / src / pyutils / collectionz / interval_tree.py
1 #!/usr/bin/env python3
2
3 """This is an augmented interval tree for storing ranges and identifying overlaps as
4 described by: https://en.wikipedia.org/wiki/Interval_tree.
5 """
6
7 from __future__ import annotations
8
9 from functools import total_ordering
10 from typing import Any, Optional, Union
11
12 from overrides import overrides
13
14 from pyutils.collectionz import bst
15
# Endpoint type accepted by NumericRange: a plain int or float.
Numeric = Union[int, float]


@total_ordering
class NumericRange(object):
    """A closed numeric range ``low..high`` (both endpoints included).

    The endpoints are normalized at construction time so that
    ``low <= high`` no matter which order they were passed in.

    Two ranges are equal when both endpoints match, and ranges are ordered
    by ``(low, high)``.  Equality, ordering and hashing all ignore
    ``highest_in_subtree``.

    Because the hash is derived from ``low`` and ``high``, do not mutate
    those attributes while the range is stored in a set or used as a
    dict key.

    >>> NumericRange(7, 3)
    3..7
    >>> NumericRange(1, 5).overlaps_with(NumericRange(5, 9))
    True
    >>> NumericRange(1, 4) < NumericRange(1, 6)
    True
    >>> len({NumericRange(1, 5), NumericRange(5, 1)})
    1
    """

    def __init__(self, low: Numeric, high: Numeric):
        """Create the range, swapping the endpoints if given reversed.

        Args:
            low: one endpoint of the range.
            high: the other endpoint of the range.
        """
        if low > high:
            low, high = high, low
        self.low: Numeric = low
        self.high: Numeric = high
        # Bookkeeping slot used by AugmentedIntervalTree: the largest `high`
        # among the ranges stored at or below the node holding this range.
        # A range that is not (yet) in a tree only accounts for itself.
        self.highest_in_subtree: Numeric = high

    def __lt__(self, other: NumericRange) -> bool:
        """Order by ``low`` first and break ties on ``high``.

        Breaking ties on ``high`` keeps the ordering consistent with
        ``__eq__``, which ``functools.total_ordering`` needs in order to
        derive ``<=``, ``>`` and ``>=`` correctly.
        """
        return (self.low, self.high) < (other.low, other.high)

    @overrides
    def __eq__(self, other: object) -> bool:
        """Return True iff ``other`` is a NumericRange with the same endpoints."""
        if not isinstance(other, NumericRange):
            return False
        return self.low == other.low and self.high == other.high

    def __hash__(self) -> int:
        """Hash on the endpoints so that equal ranges hash equally.

        Defining ``__eq__`` without ``__hash__`` would make instances
        unhashable.
        """
        return hash((self.low, self.high))

    def overlaps_with(self, other: NumericRange) -> bool:
        """Return True iff this range and ``other`` share at least one point.

        Both ranges are closed, so ranges that merely touch at an endpoint
        count as overlapping.
        """
        return self.low <= other.high and self.high >= other.low

    def __repr__(self) -> str:
        return f"{self.low}..{self.high}"
44
45
class AugmentedIntervalTree(bst.BinarySearchTree):
    """A binary search tree of NumericRange values that can enumerate the
    stored ranges overlapping a query range.

    Every stored range carries a ``highest_in_subtree`` field.  This class
    maintains it from the insert/delete hooks below and uses it in
    :meth:`find_overlaps` to skip subtrees that cannot contain an overlap.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def assert_value_must_be_range(value: Any) -> None:
        """Reject anything that is not a NumericRange.

        Args:
            value: the candidate value about to be stored in the tree.

        Raises:
            TypeError: if ``value`` is not a NumericRange.  (TypeError is
                an Exception subclass, so ``except Exception`` handlers
                keep working.)
        """
        if not isinstance(value, NumericRange):
            raise TypeError(
                "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                + "general purpose tree usable for other types."
            )

    @overrides
    def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
        """Hook run for a newly inserted node; propagates its ``high`` upward.

        Args:
            parent: the new node's parent (unused here).
            new: the node that was just inserted.

        Raises:
            TypeError: if the inserted value is not a NumericRange.
        """
        AugmentedIntervalTree.assert_value_must_be_range(new.value)
        # Every node that parent_path() reports for `new` has `new` in its
        # subtree, so its recorded maximum may need to be raised.
        for ancestor in self.parent_path(new):
            assert ancestor
            if new.value.high > ancestor.value.highest_in_subtree:
                ancestor.value.highest_in_subtree = new.value.high

    @overrides
    def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
        """Hook run for a deleted node; refreshes the parent's subtree maximum.

        The parent's ``highest_in_subtree`` is recomputed from its own
        ``high``, from the children it still has, and from the deleted
        node's children.  Taking the parent's surviving children into
        account matters: otherwise the value could drop below a range that
        is still stored under the parent, and searches would wrongly prune
        that subtree.

        Args:
            parent: parent of the deleted node, or None when there is none.
            deleted: the node being removed.
        """
        if parent is None:
            return

        highest = parent.value.high
        # NOTE(review): whether bst invokes this hook before or after it
        # unlinks `deleted` is not visible from this file, so both cases are
        # handled: `deleted` itself is skipped if it is still attached, and
        # its children are counted in case they have not been re-attached.
        for node in (parent.left, parent.right, deleted.left, deleted.right):
            if node is not None and node is not deleted:
                highest = max(highest, node.value.highest_in_subtree)
        parent.value.highest_in_subtree = highest
        # NOTE(review): only the immediate parent is refreshed.  Nodes further
        # up may keep a value that is larger than necessary; that only costs
        # some pruning in _find_overlaps, it cannot hide an overlap.

    def find_overlaps(self, x: NumericRange):
        """Yields ranges previously added to the tree that x overlaps with.

        >>> tree = AugmentedIntervalTree()
        >>> tree.insert(NumericRange(20, 24))
        >>> tree.insert(NumericRange(18, 22))
        >>> tree.insert(NumericRange(14, 16))
        >>> tree.insert(NumericRange(1, 30))
        >>> tree.insert(NumericRange(25, 30))
        >>> tree.insert(NumericRange(29, 33))
        >>> tree.insert(NumericRange(5, 12))
        >>> tree.insert(NumericRange(1, 6))
        >>> tree.insert(NumericRange(13, 18))
        >>> tree.insert(NumericRange(16, 28))
        >>> tree.insert(NumericRange(21, 27))
        >>> for x in tree.find_overlaps(NumericRange(19, 21)):
        ...     print(x)
        20..24
        18..22
        1..30
        16..28
        21..27
        """
        if self.root is None:
            return
        yield from self._find_overlaps(self.root, x)

    def _find_overlaps(self, root: bst.Node, x: NumericRange):
        """Yield the ranges stored at or below ``root`` that overlap ``x``.

        Two intervals A and B overlap only when both A.low <= B.high and
        A.high >= B.low.  While searching we can therefore skip:

            * the left subtree when the largest ``high`` recorded for it is
              below the start of ``x``, and
            * the right subtree when this node's ``low`` is already past
              the end of ``x`` (everything to the right starts no earlier).

        Matches are yielded in pre-order: the node itself, then its left
        subtree, then its right subtree.
        """
        if root is None:
            return

        if root.value.overlaps_with(x):
            yield root.value

        if root.left and root.left.value.highest_in_subtree >= x.low:
            yield from self._find_overlaps(root.left, x)

        if root.right and root.value.low <= x.high:
            yield from self._find_overlaps(root.right, x)
132
133
if __name__ == "__main__":
    # Executed as a script: run the doctests embedded in this module.
    from doctest import testmod

    testmod()