#!/usr/bin/env python3
+# © Copyright 2021-2023, Scott Gasch
+
"""This is an augmented interval tree for storing ranges and identifying overlaps as
described by: https://en.wikipedia.org/wiki/Interval_tree.
"""
from __future__ import annotations
from functools import total_ordering
-from typing import Any, Generator, Optional, Union
+from typing import Any, Generator, Optional
from overrides import overrides
from pyutils.collectionz import bst
-
-Numeric = Union[int, float]
+from pyutils.typez.typing import Numeric
@total_ordering
-class NumericRange(object):
+class NumericRange(bst.Comparable):
"""Essentially a tuple of numbers denoting a range with some added
helper methods on it."""
"""
if low > high:
- temp: Numeric = low
- low = high
- high = temp
+ low, high = high, low
self.low: Numeric = low
self.high: Numeric = high
self.highest_in_subtree: Numeric = high
+ @overrides
def __lt__(self, other: NumericRange) -> bool:
"""
Returns:
return False
return self.low == other.low and self.high == other.high
+ @overrides
+ def __le__(self, other: object) -> bool:
+ if not isinstance(other, NumericRange):
+ return False
+ return self < other or self == other
+
def overlaps_with(self, other: NumericRange) -> bool:
"""
Returns:
return self.low <= other.high and self.high >= other.low
def __repr__(self) -> str:
- return f"{self.low}..{self.high}"
+ return f"[{self.low}..{self.high}]"
class AugmentedIntervalTree(bst.BinarySearchTree):
@staticmethod
- def _assert_value_must_be_range(value: Any) -> None:
+ def _assert_value_must_be_range(value: Any) -> NumericRange:
if not isinstance(value, NumericRange):
- raise Exception(
+ raise TypeError(
"AugmentedIntervalTree expects to use NumericRanges, see bst for a "
+ "general purpose tree usable for other types."
)
+ return value
@overrides
def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
- AugmentedIntervalTree._assert_value_must_be_range(new.value)
+ nv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(new.value)
for ancestor in self.parent_path(new):
assert ancestor
- if new.value.high > ancestor.value.highest_in_subtree:
- ancestor.value.highest_in_subtree = new.value.high
+ av: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ ancestor.value
+ )
+ if nv.high > av.highest_in_subtree:
+ av.highest_in_subtree = nv.high
@overrides
def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
if parent:
- new_highest_candidates = [parent.value.high]
+ pv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ parent.value
+ )
+ new_highest_candidates = [pv.high]
if parent.left:
- new_highest_candidates.append(parent.left.value.highest_in_subtree)
+ lv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ parent.left.value
+ )
+ new_highest_candidates.append(lv.highest_in_subtree)
if parent.right:
- new_highest_candidates.append(parent.right.value.highest_in_subtree)
- parent.value.highest_in_subtree = max(new_highest_candidates)
+ rv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ parent.right.value
+ )
+ new_highest_candidates.append(rv.highest_in_subtree)
+ pv.highest_in_subtree = max(new_highest_candidates)
def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
"""Identify and return one overlapping node from the tree.
>>> tree.insert(NumericRange(16, 28))
>>> tree.insert(NumericRange(21, 27))
>>> tree.find_one_overlap(NumericRange(6, 7))
- 1..30
+ [1..30]
"""
return self._find_one_overlap(self.root, to_find)
if root is None:
return None
- if root.value.overlaps_with(x):
- return root.value
+ rv = AugmentedIntervalTree._assert_value_must_be_range(root.value)
+ if rv.overlaps_with(x):
+ return rv
if root.left:
- if root.left.value.highest_in_subtree >= x.low:
+ lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value)
+ if lv.highest_in_subtree >= x.low:
return self._find_one_overlap(root.left, x)
if root.right:
>>> tree.insert(NumericRange(21, 27))
>>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
... print(x)
- 20..24
- 18..22
- 1..30
- 16..28
- 21..27
+ [20..24]
+ [18..22]
+ [1..30]
+ [16..28]
+ [21..27]
>>> del tree[NumericRange(1, 30)]
>>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
... print(x)
- 20..24
- 18..22
- 16..28
- 21..27
+ [20..24]
+ [18..22]
+ [16..28]
+ [21..27]
"""
if self.root is None:
if root is None:
return None
- if root.value.overlaps_with(x):
- yield root.value
+ rv = AugmentedIntervalTree._assert_value_must_be_range(root.value)
+ if rv.overlaps_with(x):
+ yield rv
if root.left:
- if root.left.value.highest_in_subtree >= x.low:
+ lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value)
+ if lv.highest_in_subtree >= x.low:
yield from self._find_all_overlaps(root.left, x)
if root.right:
- if root.right.value.highest_in_subtree >= x.low:
+ rv = AugmentedIntervalTree._assert_value_must_be_range(root.right.value)
+ if rv.highest_in_subtree >= x.low:
yield from self._find_all_overlaps(root.right, x)