X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=src%2Fpyutils%2Fcollectionz%2Finterval_tree.py;h=c4e4e9ae75beaa9fd85260675c7b1001df6c753e;hb=HEAD;hp=a8278a2dc8ea835a501951e3abddb9727d405930;hpb=5fd1f68a58ef3bbf37528b9397a6950d7af661ab;p=pyutils.git diff --git a/src/pyutils/collectionz/interval_tree.py b/src/pyutils/collectionz/interval_tree.py index a8278a2..c4e4e9a 100644 --- a/src/pyutils/collectionz/interval_tree.py +++ b/src/pyutils/collectionz/interval_tree.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# © Copyright 2021-2023, Scott Gasch + """This is an augmented interval tree for storing ranges and identifying overlaps as described by: https://en.wikipedia.org/wiki/Interval_tree. """ @@ -7,17 +9,16 @@ described by: https://en.wikipedia.org/wiki/Interval_tree. from __future__ import annotations from functools import total_ordering -from typing import Any, Generator, Optional, Union +from typing import Any, Generator, Optional from overrides import overrides from pyutils.collectionz import bst - -Numeric = Union[int, float] +from pyutils.typez.typing import Numeric @total_ordering -class NumericRange(object): +class NumericRange(bst.Comparable): """Essentially a tuple of numbers denoting a range with some added helper methods on it.""" @@ -35,13 +36,12 @@ class NumericRange(object): """ if low > high: - temp: Numeric = low - low = high - high = temp + low, high = high, low self.low: Numeric = low self.high: Numeric = high self.highest_in_subtree: Numeric = high + @overrides def __lt__(self, other: NumericRange) -> bool: """ Returns: @@ -62,6 +62,12 @@ class NumericRange(object): return False return self.low == other.low and self.high == other.high + @overrides + def __le__(self, other: object) -> bool: + if not isinstance(other, NumericRange): + return False + return self < other or self == other + def overlaps_with(self, other: NumericRange) -> bool: """ Returns: @@ -70,35 +76,48 @@ class NumericRange(object): return self.low <= other.high and self.high >= other.low def __repr__(self) -> str: - return f"{self.low}..{self.high}" + return f"[{self.low}..{self.high}]" class AugmentedIntervalTree(bst.BinarySearchTree): @staticmethod - def _assert_value_must_be_range(value: Any) -> None: + def _assert_value_must_be_range(value: Any) -> NumericRange: if not isinstance(value, NumericRange): - raise Exception( + raise TypeError( "AugmentedIntervalTree expects to use NumericRanges, see bst for a " + "general purpose tree usable for other types." ) + return value @overrides def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None: - AugmentedIntervalTree._assert_value_must_be_range(new.value) + nv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(new.value) for ancestor in self.parent_path(new): assert ancestor - if new.value.high > ancestor.value.highest_in_subtree: - ancestor.value.highest_in_subtree = new.value.high + av: NumericRange = AugmentedIntervalTree._assert_value_must_be_range( + ancestor.value + ) + if nv.high > av.highest_in_subtree: + av.highest_in_subtree = nv.high @overrides def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None: if parent: - new_highest_candidates = [parent.value.high] + pv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range( + parent.value + ) + new_highest_candidates = [pv.high] if parent.left: - new_highest_candidates.append(parent.left.value.highest_in_subtree) + lv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range( + parent.left.value + ) + new_highest_candidates.append(lv.highest_in_subtree) if parent.right: - new_highest_candidates.append(parent.right.value.highest_in_subtree) - parent.value.highest_in_subtree = max(new_highest_candidates) + rv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range( + parent.right.value + ) + new_highest_candidates.append(rv.highest_in_subtree) + pv.highest_in_subtree = max(new_highest_candidates) def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]: """Identify and return one overlapping node from the tree. @@ -123,7 +142,7 @@ class AugmentedIntervalTree(bst.BinarySearchTree): >>> tree.insert(NumericRange(16, 28)) >>> tree.insert(NumericRange(21, 27)) >>> tree.find_one_overlap(NumericRange(6, 7)) - 1..30 + [1..30] """ return self._find_one_overlap(self.root, to_find) @@ -134,11 +153,13 @@ class AugmentedIntervalTree(bst.BinarySearchTree): if root is None: return None - if root.value.overlaps_with(x): - return root.value + rv = AugmentedIntervalTree._assert_value_must_be_range(root.value) + if rv.overlaps_with(x): + return rv if root.left: - if root.left.value.highest_in_subtree >= x.low: + lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value) + if lv.highest_in_subtree >= x.low: return self._find_one_overlap(root.left, x) if root.right: @@ -172,19 +193,19 @@ class AugmentedIntervalTree(bst.BinarySearchTree): >>> tree.insert(NumericRange(21, 27)) >>> for x in tree.find_all_overlaps(NumericRange(19, 21)): ... print(x) - 20..24 - 18..22 - 1..30 - 16..28 - 21..27 + [20..24] + [18..22] + [1..30] + [16..28] + [21..27] >>> del tree[NumericRange(1, 30)] >>> for x in tree.find_all_overlaps(NumericRange(19, 21)): ... print(x) - 20..24 - 18..22 - 16..28 - 21..27 + [20..24] + [18..22] + [16..28] + [21..27] """ if self.root is None: @@ -197,15 +218,18 @@ class AugmentedIntervalTree(bst.BinarySearchTree): if root is None: return None - if root.value.overlaps_with(x): - yield root.value + rv = AugmentedIntervalTree._assert_value_must_be_range(root.value) + if rv.overlaps_with(x): + yield rv if root.left: - if root.left.value.highest_in_subtree >= x.low: + lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value) + if lv.highest_in_subtree >= x.low: yield from self._find_all_overlaps(root.left, x) if root.right: - if root.right.value.highest_in_subtree >= x.low: + rv = AugmentedIntervalTree._assert_value_must_be_range(root.right.value) + if rv.highest_in_subtree >= x.low: yield from self._find_all_overlaps(root.right, x)