return self.root
+ def _on_insert(self, parent: Optional[Node], new: Node) -> None:
+ """This is called immediately _after_ a new node is inserted."""
+ pass
+
def insert(self, value: Any) -> None:
"""
Insert something into the tree.
if self.root is None:
self.root = Node(value)
self.count = 1
+ self._on_insert(None, self.root)
else:
self._insert(value, self.root)
else:
node.left = Node(value)
self.count += 1
+ self._on_insert(node, node.left)
else:
if node.right is not None:
self._insert(value, node.right)
else:
node.right = Node(value)
self.count += 1
+ self._on_insert(node, node.right)
def __getitem__(self, value: Any) -> Optional[Node]:
"""
return ret
return False
+ def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
+ """This is called just before deleted is deleted --
+ i.e. before the tree changes."""
+ pass
+
def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
"""Delete helper"""
if node.value == value:
+
# Deleting a leaf node
if node.left is None and node.right is None:
+ self._on_delete(parent, node)
if parent is not None:
if parent.left == node:
parent.left = None
# Node only has a right.
elif node.left is None:
assert node.right is not None
+ self._on_delete(parent, node)
if parent is not None:
if parent.left == node:
parent.left = node.right
# Node only has a left.
elif node.right is None:
assert node.left is not None
+ self._on_delete(parent, node)
if parent is not None:
if parent.left == node:
parent.left = node.left
has_right_sibling: bool,
) -> str:
if node is not None:
- viz = f'\n{padding}{pointer}{node.value}'
+ viz = f"\n{padding}{pointer}{node.value}"
if has_right_sibling:
padding += "│ "
else:
- padding += ' '
+ padding += " "
pointer_right = "└──"
if node.right is not None:
if self.root is None:
return ""
- ret = f'{self.root.value}'
+ ret = f"{self.root.value}"
pointer_right = "└──"
if self.root.right is None:
pointer_left = "└──"
pointer_left = "├──"
ret += self.repr_traverse(
- '', pointer_left, self.root.left, self.root.left is not None
+ "", pointer_left, self.root.left, self.root.left is not None
)
- ret += self.repr_traverse('', pointer_right, self.root.right, False)
+ ret += self.repr_traverse("", pointer_right, self.root.right, False)
return ret
-if __name__ == '__main__':
+if __name__ == "__main__":
import doctest
doctest.testmod()
--- /dev/null
+#!/usr/bin/env python3
+
+"""This is an augmented interval tree for storing ranges and identifying overlaps as
+described by: https://en.wikipedia.org/wiki/Interval_tree.
+"""
+
+from __future__ import annotations
+
+from functools import total_ordering
+from typing import Any, Optional, Union
+
+from overrides import overrides
+
+from pyutils.collectionz import bst
+
+Numeric = Union[int, float]
+
+
+@total_ordering
+class NumericRange(object):
+ def __init__(self, low: Numeric, high: Numeric):
+ if low > high:
+ temp: Numeric = low
+ low = high
+ high = temp
+ self.low: Numeric = low
+ self.high: Numeric = high
+ self.highest_in_subtree: Numeric = high
+
+ def __lt__(self, other: NumericRange) -> bool:
+ return self.low < other.low
+
+ @overrides
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, NumericRange):
+ return False
+ return self.low == other.low and self.high == other.high
+
+ def overlaps_with(self, other: NumericRange) -> bool:
+ return self.low <= other.high and self.high >= other.low
+
+ def __repr__(self) -> str:
+ return f"{self.low}..{self.high}"
+
+
+class AugmentedIntervalTree(bst.BinarySearchTree):
+ def __init__(self):
+ super().__init__()
+
+ @staticmethod
+ def assert_value_must_be_range(value: Any) -> None:
+ if not isinstance(value, NumericRange):
+ raise Exception(
+ "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
+ + "general purpose tree usable for other types."
+ )
+
+ @overrides
+ def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
+ AugmentedIntervalTree.assert_value_must_be_range(new.value)
+ for ancestor in self.parent_path(new):
+ assert ancestor
+ if new.value.high > ancestor.value.highest_in_subtree:
+ ancestor.value.highest_in_subtree = new.value.high
+
+ @overrides
+ def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
+ if parent:
+ new_highest_candidates = []
+ if deleted.left:
+ new_highest_candidates.append(deleted.left.value.highest_in_subtree)
+ if deleted.right:
+ new_highest_candidates.append(deleted.right.value.highest_in_subtree)
+ if len(new_highest_candidates):
+ parent.value.highest_in_subtree = max(
+ parent.value.high, max(new_highest_candidates)
+ )
+ else:
+ parent.value.highest_in_subtree = parent.value.high
+
+ def find_overlaps(self, x: NumericRange):
+ """Yields ranges previously added to the tree that x overlaps with.
+
+ >>> tree = AugmentedIntervalTree()
+ >>> tree.insert(NumericRange(20, 24))
+ >>> tree.insert(NumericRange(18, 22))
+ >>> tree.insert(NumericRange(14, 16))
+ >>> tree.insert(NumericRange(1, 30))
+ >>> tree.insert(NumericRange(25, 30))
+ >>> tree.insert(NumericRange(29, 33))
+ >>> tree.insert(NumericRange(5, 12))
+ >>> tree.insert(NumericRange(1, 6))
+ >>> tree.insert(NumericRange(13, 18))
+ >>> tree.insert(NumericRange(16, 28))
+ >>> tree.insert(NumericRange(21, 27))
+ >>> for x in tree.find_overlaps(NumericRange(19, 21)):
+ ... print(x)
+ 20..24
+ 18..22
+ 1..30
+ 16..28
+ 21..27
+ """
+ if self.root is None:
+ return
+ yield from self._find_overlaps(self.root, x)
+
+ def _find_overlaps(self, root: bst.Node, x: NumericRange):
+ """It's known that two intervals A and B overlap only
+ when both A.low <= B.high and A.high >= B.low. When
+ searching the trees for nodes overlapping with a given
+ interval, we can immediately skip:
+
+ * all nodes to the right of nodes whose low value is past
+ the end of the given interval and
+ * all nodes that have their maximum high value below the
+ start of the given interval.
+ """
+ if root is None:
+ return
+
+ if root.value.overlaps_with(x):
+ yield root.value
+
+ if root.left:
+ if root.left.value.highest_in_subtree >= x.low:
+ yield from self._find_overlaps(root.left, x)
+
+ if root.value.low <= x.high:
+ if root.right:
+ yield from self._find_overlaps(root.right, x)
+
+
+if __name__ == "__main__":
+ import doctest
+
+ doctest.testmod()