projects
/
pyutils.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Adds a __repr__ to graph.
[pyutils.git]
/
src
/
pyutils
/
collectionz
/
interval_tree.py
diff --git
a/src/pyutils/collectionz/interval_tree.py
b/src/pyutils/collectionz/interval_tree.py
index a8278a2dc8ea835a501951e3abddb9727d405930..c4e4e9ae75beaa9fd85260675c7b1001df6c753e 100644
(file)
--- a/
src/pyutils/collectionz/interval_tree.py
+++ b/
src/pyutils/collectionz/interval_tree.py
@@
-1,5
+1,7
@@
#!/usr/bin/env python3
#!/usr/bin/env python3
+# © Copyright 2021-2023, Scott Gasch
+
"""This is an augmented interval tree for storing ranges and identifying overlaps as
described by: https://en.wikipedia.org/wiki/Interval_tree.
"""
"""This is an augmented interval tree for storing ranges and identifying overlaps as
described by: https://en.wikipedia.org/wiki/Interval_tree.
"""
@@
-7,17
+9,16
@@
described by: https://en.wikipedia.org/wiki/Interval_tree.
from __future__ import annotations
from functools import total_ordering
from __future__ import annotations
from functools import total_ordering
-from typing import Any, Generator, Optional
, Union
+from typing import Any, Generator, Optional
from overrides import overrides
from pyutils.collectionz import bst
from overrides import overrides
from pyutils.collectionz import bst
-
-Numeric = Union[int, float]
+from pyutils.typez.typing import Numeric
@total_ordering
@total_ordering
-class NumericRange(
object
):
+class NumericRange(
bst.Comparable
):
"""Essentially a tuple of numbers denoting a range with some added
helper methods on it."""
"""Essentially a tuple of numbers denoting a range with some added
helper methods on it."""
@@
-35,13
+36,12
@@
class NumericRange(object):
"""
if low > high:
"""
if low > high:
- temp: Numeric = low
- low = high
- high = temp
+ low, high = high, low
self.low: Numeric = low
self.high: Numeric = high
self.highest_in_subtree: Numeric = high
self.low: Numeric = low
self.high: Numeric = high
self.highest_in_subtree: Numeric = high
+ @overrides
def __lt__(self, other: NumericRange) -> bool:
"""
Returns:
def __lt__(self, other: NumericRange) -> bool:
"""
Returns:
@@
-62,6
+62,12
@@
class NumericRange(object):
return False
return self.low == other.low and self.high == other.high
return False
return self.low == other.low and self.high == other.high
+ @overrides
+ def __le__(self, other: object) -> bool:
+ if not isinstance(other, NumericRange):
+ return False
+ return self < other or self == other
+
def overlaps_with(self, other: NumericRange) -> bool:
"""
Returns:
def overlaps_with(self, other: NumericRange) -> bool:
"""
Returns:
@@
-70,35
+76,48
@@
class NumericRange(object):
return self.low <= other.high and self.high >= other.low
def __repr__(self) -> str:
return self.low <= other.high and self.high >= other.low
def __repr__(self) -> str:
- return f"
{self.low}..{self.high}
"
+ return f"
[{self.low}..{self.high}]
"
class AugmentedIntervalTree(bst.BinarySearchTree):
@staticmethod
class AugmentedIntervalTree(bst.BinarySearchTree):
@staticmethod
- def _assert_value_must_be_range(value: Any) -> N
on
e:
+ def _assert_value_must_be_range(value: Any) -> N
umericRang
e:
if not isinstance(value, NumericRange):
if not isinstance(value, NumericRange):
- raise
Exception
(
+ raise
TypeError
(
"AugmentedIntervalTree expects to use NumericRanges, see bst for a "
+ "general purpose tree usable for other types."
)
"AugmentedIntervalTree expects to use NumericRanges, see bst for a "
+ "general purpose tree usable for other types."
)
+ return value
@overrides
def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
@overrides
def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
- AugmentedIntervalTree._assert_value_must_be_range(new.value)
+
nv: NumericRange =
AugmentedIntervalTree._assert_value_must_be_range(new.value)
for ancestor in self.parent_path(new):
assert ancestor
for ancestor in self.parent_path(new):
assert ancestor
- if new.value.high > ancestor.value.highest_in_subtree:
- ancestor.value.highest_in_subtree = new.value.high
+ av: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ ancestor.value
+ )
+ if nv.high > av.highest_in_subtree:
+ av.highest_in_subtree = nv.high
@overrides
def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
if parent:
@overrides
def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
if parent:
- new_highest_candidates = [parent.value.high]
+ pv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ parent.value
+ )
+ new_highest_candidates = [pv.high]
if parent.left:
if parent.left:
- new_highest_candidates.append(parent.left.value.highest_in_subtree)
+ lv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ parent.left.value
+ )
+ new_highest_candidates.append(lv.highest_in_subtree)
if parent.right:
if parent.right:
- new_highest_candidates.append(parent.right.value.highest_in_subtree)
- parent.value.highest_in_subtree = max(new_highest_candidates)
+ rv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ parent.right.value
+ )
+ new_highest_candidates.append(rv.highest_in_subtree)
+ pv.highest_in_subtree = max(new_highest_candidates)
def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
"""Identify and return one overlapping node from the tree.
def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
"""Identify and return one overlapping node from the tree.
@@
-123,7
+142,7
@@
class AugmentedIntervalTree(bst.BinarySearchTree):
>>> tree.insert(NumericRange(16, 28))
>>> tree.insert(NumericRange(21, 27))
>>> tree.find_one_overlap(NumericRange(6, 7))
>>> tree.insert(NumericRange(16, 28))
>>> tree.insert(NumericRange(21, 27))
>>> tree.find_one_overlap(NumericRange(6, 7))
- 1..30
+ [1..30]
"""
return self._find_one_overlap(self.root, to_find)
"""
return self._find_one_overlap(self.root, to_find)
@@
-134,11
+153,13
@@
class AugmentedIntervalTree(bst.BinarySearchTree):
if root is None:
return None
if root is None:
return None
- if root.value.overlaps_with(x):
- return root.value
+ rv = AugmentedIntervalTree._assert_value_must_be_range(root.value)
+ if rv.overlaps_with(x):
+ return rv
if root.left:
if root.left:
- if root.left.value.highest_in_subtree >= x.low:
+ lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value)
+ if lv.highest_in_subtree >= x.low:
return self._find_one_overlap(root.left, x)
if root.right:
return self._find_one_overlap(root.left, x)
if root.right:
@@
-172,19
+193,19
@@
class AugmentedIntervalTree(bst.BinarySearchTree):
>>> tree.insert(NumericRange(21, 27))
>>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
... print(x)
>>> tree.insert(NumericRange(21, 27))
>>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
... print(x)
- 20..24
- 18..22
- 1..30
- 16..28
- 21..27
+ [20..24]
+ [18..22]
+ [1..30]
+ [16..28]
+ [21..27]
>>> del tree[NumericRange(1, 30)]
>>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
... print(x)
>>> del tree[NumericRange(1, 30)]
>>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
... print(x)
- 20..24
- 18..22
- 16..28
- 21..27
+ [20..24]
+ [18..22]
+ [16..28]
+ [21..27]
"""
if self.root is None:
"""
if self.root is None:
@@
-197,15
+218,18
@@
class AugmentedIntervalTree(bst.BinarySearchTree):
if root is None:
return None
if root is None:
return None
- if root.value.overlaps_with(x):
- yield root.value
+ rv = AugmentedIntervalTree._assert_value_must_be_range(root.value)
+ if rv.overlaps_with(x):
+ yield rv
if root.left:
if root.left:
- if root.left.value.highest_in_subtree >= x.low:
+ lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value)
+ if lv.highest_in_subtree >= x.low:
yield from self._find_all_overlaps(root.left, x)
if root.right:
yield from self._find_all_overlaps(root.left, x)
if root.right:
- if root.right.value.highest_in_subtree >= x.low:
+ rv = AugmentedIntervalTree._assert_value_must_be_range(root.right.value)
+ if rv.highest_in_subtree >= x.low:
yield from self._find_all_overlaps(root.right, x)
yield from self._find_all_overlaps(root.right, x)