#!/usr/bin/env python3
+# © Copyright 2021-2022, Scott Gasch
+
"""This is an augmented interval tree for storing ranges and identifying overlaps as
described by: https://en.wikipedia.org/wiki/Interval_tree.
"""
from __future__ import annotations
from functools import total_ordering
-from typing import Any, Optional, Union
+from typing import Any, Generator, Optional
from overrides import overrides
from pyutils.collectionz import bst
-
-Numeric = Union[int, float]
+from pyutils.types.simple import Numeric
@total_ordering
class NumericRange(object):
+ """Essentially a tuple of numbers denoting a range with some added
+ helper methods on it."""
+
def __init__(self, low: Numeric, high: Numeric):
+ """Creates a NumericRange.
+
+ Args:
+ low: the lowest point in the range (inclusive).
+ high: the highest point in the range (inclusive).
+
+ .. warning::
+
+ If low > high this code swaps the parameters and keeps the range
+ rather than raising.
+
+ """
if low > high:
temp: Numeric = low
low = high
self.highest_in_subtree: Numeric = high
def __lt__(self, other: NumericRange) -> bool:
- return self.low < other.low
+ """
+ Returns:
+ True is this range is less than (lower low) other, else False.
+ """
+ if self.low != other.low:
+ return self.low < other.low
+ else:
+ return self.high < other.high
@overrides
def __eq__(self, other: object) -> bool:
+ """
+ Returns:
+ True if this is the same range as other, else False.
+ """
if not isinstance(other, NumericRange):
return False
return self.low == other.low and self.high == other.high
def overlaps_with(self, other: NumericRange) -> bool:
+ """
+ Returns:
+ True if this NumericRange overlaps with other, else False.
+ """
return self.low <= other.high and self.high >= other.low
def __repr__(self) -> str:
class AugmentedIntervalTree(bst.BinarySearchTree):
- def __init__(self):
- super().__init__()
-
@staticmethod
- def assert_value_must_be_range(value: Any) -> None:
+ def _assert_value_must_be_range(value: Any) -> None:
if not isinstance(value, NumericRange):
raise Exception(
"AugmentedIntervalTree expects to use NumericRanges, see bst for a "
@overrides
def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
- AugmentedIntervalTree.assert_value_must_be_range(new.value)
+ AugmentedIntervalTree._assert_value_must_be_range(new.value)
for ancestor in self.parent_path(new):
assert ancestor
if new.value.high > ancestor.value.highest_in_subtree:
new_highest_candidates.append(parent.right.value.highest_in_subtree)
parent.value.highest_in_subtree = max(new_highest_candidates)
- def find_one_overlap(self, x: NumericRange):
+ def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
"""Identify and return one overlapping node from the tree.
+ Args:
+ to_find: the interval with which to find an overlap.
+
+ Returns:
+ An overlapping range from the tree or None if no such
+ ranges exist in the tree at present.
+
>>> tree = AugmentedIntervalTree()
>>> tree.insert(NumericRange(20, 24))
>>> tree.insert(NumericRange(18, 22))
>>> tree.insert(NumericRange(21, 27))
>>> tree.find_one_overlap(NumericRange(6, 7))
1..30
+
"""
- return self._find_one_overlap(self.root, x)
+ return self._find_one_overlap(self.root, to_find)
- def _find_one_overlap(self, root: bst.Node, x: NumericRange):
+ def _find_one_overlap(
+ self, root: bst.Node, x: NumericRange
+ ) -> Optional[NumericRange]:
if root is None:
- return
+ return None
if root.value.overlaps_with(x):
return root.value
if root.right:
return self._find_one_overlap(root.right, x)
+ return None
+
+ def find_all_overlaps(
+ self, to_find: NumericRange
+ ) -> Generator[NumericRange, None, None]:
+ """Yields ranges previously added to the tree that overlaps with
+ to_find argument.
+
+ Args:
+ to_find: the interval with which to find all overlaps.
- def find_all_overlaps(self, x: NumericRange):
- """Yields ranges previously added to the tree that x overlaps with.
+ Returns:
+ A (potentially empty) sequence of all ranges in the tree
+ that overlap with the argument.
>>> tree = AugmentedIntervalTree()
>>> tree.insert(NumericRange(20, 24))
18..22
16..28
21..27
+
"""
if self.root is None:
return
- yield from self._find_all_overlaps(self.root, x)
+ yield from self._find_all_overlaps(self.root, to_find)
- def _find_all_overlaps(self, root: bst.Node, x: NumericRange):
+ def _find_all_overlaps(
+ self, root: bst.Node, x: NumericRange
+ ) -> Generator[NumericRange, None, None]:
if root is None:
- return
+ return None
if root.value.overlaps_with(x):
yield root.value