From: Scott Gasch Date: Tue, 13 Dec 2022 22:36:23 +0000 (-0800) Subject: Adds IntervalTree. X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=f564702340f7528e2ad186e6a20033636d6afaef;p=pyutils.git Adds IntervalTree. --- diff --git a/src/pyutils/collectionz/bst.py b/src/pyutils/collectionz/bst.py index 4c0bacd..1efed52 100644 --- a/src/pyutils/collectionz/bst.py +++ b/src/pyutils/collectionz/bst.py @@ -36,6 +36,10 @@ class BinarySearchTree(object): return self.root + def _on_insert(self, parent: Optional[Node], new: Node) -> None: + """This is called immediately _after_ a new node is inserted.""" + pass + def insert(self, value: Any) -> None: """ Insert something into the tree. @@ -57,6 +61,7 @@ class BinarySearchTree(object): if self.root is None: self.root = Node(value) self.count = 1 + self._on_insert(None, self.root) else: self._insert(value, self.root) @@ -68,12 +73,14 @@ class BinarySearchTree(object): else: node.left = Node(value) self.count += 1 + self._on_insert(node, node.left) else: if node.right is not None: self._insert(value, node.right) else: node.right = Node(value) self.count += 1 + self._on_insert(node, node.right) def __getitem__(self, value: Any) -> Optional[Node]: """ @@ -264,11 +271,18 @@ class BinarySearchTree(object): return ret return False + def _on_delete(self, parent: Optional[Node], deleted: Node) -> None: + """This is called just before deleted is deleted -- + i.e. before the tree changes.""" + pass + def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool: """Delete helper""" if node.value == value: + # Deleting a leaf node if node.left is None and node.right is None: + self._on_delete(parent, node) if parent is not None: if parent.left == node: parent.left = None @@ -280,6 +294,7 @@ class BinarySearchTree(object): # Node only has a right. elif node.left is None: assert node.right is not None + self._on_delete(parent, node) if parent is not None: if parent.left == node: parent.left = node.right @@ -291,6 +306,7 @@ class BinarySearchTree(object): # Node only has a left. elif node.right is None: assert node.left is not None + self._on_delete(parent, node) if parent is not None: if parent.left == node: parent.left = node.left @@ -635,11 +651,11 @@ class BinarySearchTree(object): has_right_sibling: bool, ) -> str: if node is not None: - viz = f'\n{padding}{pointer}{node.value}' + viz = f"\n{padding}{pointer}{node.value}" if has_right_sibling: padding += "│ " else: - padding += ' ' + padding += " " pointer_right = "└──" if node.right is not None: @@ -679,7 +695,7 @@ class BinarySearchTree(object): if self.root is None: return "" - ret = f'{self.root.value}' + ret = f"{self.root.value}" pointer_right = "└──" if self.root.right is None: pointer_left = "└──" @@ -687,13 +703,13 @@ class BinarySearchTree(object): pointer_left = "├──" ret += self.repr_traverse( - '', pointer_left, self.root.left, self.root.left is not None + "", pointer_left, self.root.left, self.root.left is not None ) - ret += self.repr_traverse('', pointer_right, self.root.right, False) + ret += self.repr_traverse("", pointer_right, self.root.right, False) return ret -if __name__ == '__main__': +if __name__ == "__main__": import doctest doctest.testmod() diff --git a/src/pyutils/collectionz/interval_tree.py b/src/pyutils/collectionz/interval_tree.py new file mode 100644 index 0000000..733aea0 --- /dev/null +++ b/src/pyutils/collectionz/interval_tree.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 + +"""This is an augmented interval tree for storing ranges and identifying overlaps as +described by: https://en.wikipedia.org/wiki/Interval_tree. +""" + +from __future__ import annotations + +from functools import total_ordering +from typing import Any, Optional, Union + +from overrides import overrides + +from pyutils.collectionz import bst + +Numeric = Union[int, float] + + +@total_ordering +class NumericRange(object): + def __init__(self, low: Numeric, high: Numeric): + if low > high: + temp: Numeric = low + low = high + high = temp + self.low: Numeric = low + self.high: Numeric = high + self.highest_in_subtree: Numeric = high + + def __lt__(self, other: NumericRange) -> bool: + return self.low < other.low + + @overrides + def __eq__(self, other: object) -> bool: + if not isinstance(other, NumericRange): + return False + return self.low == other.low and self.high == other.high + + def overlaps_with(self, other: NumericRange) -> bool: + return self.low <= other.high and self.high >= other.low + + def __repr__(self) -> str: + return f"{self.low}..{self.high}" + + +class AugmentedIntervalTree(bst.BinarySearchTree): + def __init__(self): + super().__init__() + + @staticmethod + def assert_value_must_be_range(value: Any) -> None: + if not isinstance(value, NumericRange): + raise Exception( + "AugmentedIntervalTree expects to use NumericRanges, see bst for a " + + "general purpose tree usable for other types." + ) + + @overrides + def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None: + AugmentedIntervalTree.assert_value_must_be_range(new.value) + for ancestor in self.parent_path(new): + assert ancestor + if new.value.high > ancestor.value.highest_in_subtree: + ancestor.value.highest_in_subtree = new.value.high + + @overrides + def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None: + if parent: + new_highest_candidates = [] + if deleted.left: + new_highest_candidates.append(deleted.left.value.highest_in_subtree) + if deleted.right: + new_highest_candidates.append(deleted.right.value.highest_in_subtree) + if len(new_highest_candidates): + parent.value.highest_in_subtree = max( + parent.value.high, max(new_highest_candidates) + ) + else: + parent.value.highest_in_subtree = parent.value.high + + def find_overlaps(self, x: NumericRange): + """Yields ranges previously added to the tree that x overlaps with. + + >>> tree = AugmentedIntervalTree() + >>> tree.insert(NumericRange(20, 24)) + >>> tree.insert(NumericRange(18, 22)) + >>> tree.insert(NumericRange(14, 16)) + >>> tree.insert(NumericRange(1, 30)) + >>> tree.insert(NumericRange(25, 30)) + >>> tree.insert(NumericRange(29, 33)) + >>> tree.insert(NumericRange(5, 12)) + >>> tree.insert(NumericRange(1, 6)) + >>> tree.insert(NumericRange(13, 18)) + >>> tree.insert(NumericRange(16, 28)) + >>> tree.insert(NumericRange(21, 27)) + >>> for x in tree.find_overlaps(NumericRange(19, 21)): + ... print(x) + 20..24 + 18..22 + 1..30 + 16..28 + 21..27 + """ + if self.root is None: + return + yield from self._find_overlaps(self.root, x) + + def _find_overlaps(self, root: bst.Node, x: NumericRange): + """It's known that two intervals A and B overlap only + when both A.low <= B.high and A.high >= B.low. When + searching the trees for nodes overlapping with a given + interval, we can immediately skip: + + * all nodes to the right of nodes whose low value is past + the end of the given interval and + * all nodes that have their maximum high value below the + start of the given interval. + """ + if root is None: + return + + if root.value.overlaps_with(x): + yield root.value + + if root.left: + if root.left.value.highest_in_subtree >= x.low: + yield from self._find_overlaps(root.left, x) + + if root.value.low <= x.high: + if root.right: + yield from self._find_overlaps(root.right, x) + + +if __name__ == "__main__": + import doctest + + doctest.testmod()