"""A binary search tree implementation."""
-from abc import ABCMeta, abstractmethod
-from typing import Any, Generator, List, Optional, TypeVar
+from abc import abstractmethod
+from typing import Any, Generator, List, Optional, Protocol
-class Comparable(metaclass=ABCMeta):
+class Comparable(Protocol):
@abstractmethod
def __lt__(self, other: Any) -> bool:
pass
pass
-ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable)
-
-
-class Node(object):
- def __init__(self, value: ComparableNodeValue) -> None:
+class Node:
+ def __init__(self, value: Comparable) -> None:
"""A BST node. Note that value can be anything as long as it
is comparable with other instances of itself. Check out
:meth:`functools.total_ordering`
"""
self.left: Optional[Node] = None
self.right: Optional[Node] = None
- self.value: ComparableNodeValue = value
+ self.value: Comparable = value
class BinarySearchTree(object):
"""This is called immediately _after_ a new node is inserted."""
pass
- def insert(self, value: ComparableNodeValue) -> None:
+ def insert(self, value: Comparable) -> None:
"""
Insert something into the tree.
else:
self._insert(value, self.root)
- def _insert(self, value: ComparableNodeValue, node: Node):
+ def _insert(self, value: Comparable, node: Node):
"""Insertion helper"""
if value < node.value:
if node.left is not None:
self.count += 1
self._on_insert(node, node.right)
- def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
+ def __getitem__(self, value: Comparable) -> Optional[Node]:
"""
Find an item in the tree and return its Node. Returns
None if the item is not in the tree.
return self._find_exact(value, self.root)
return None
- def _find_exact(self, target: ComparableNodeValue, node: Node) -> Optional[Node]:
+ def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
"""Recursively traverse the tree looking for a node with the
target value. Return that node if it exists, otherwise return
None."""
return None
def _find_lowest_node_less_than_or_equal_to(
- self, target: ComparableNodeValue, node: Optional[Node]
+ self, target: Comparable, node: Optional[Node]
) -> Optional[Node]:
"""Find helper that returns the lowest node that is less
than or equal to the target value. Returns None if target is
return self._find_lowest_node_less_than_or_equal_to(target, node.left)
def _find_lowest_node_greater_than_or_equal_to(
- self, target: ComparableNodeValue, node: Optional[Node]
+ self, target: Comparable, node: Optional[Node]
) -> Optional[Node]:
"""Find helper that returns the lowest node that is greater
than or equal to the target value. Returns None if target is
"""
return self._parent_path(self.root, node)
- def __delitem__(self, value: ComparableNodeValue) -> bool:
+ def __delitem__(self, value: Comparable) -> bool:
"""
Delete an item from the tree and preserve the BST property.
"""This is called just after deleted was deleted from the tree"""
pass
- def _delete(
- self, value: ComparableNodeValue, parent: Optional[Node], node: Node
- ) -> bool:
+ def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
"""Delete helper"""
if node.value == value:
"""
return self.count
- def __contains__(self, value: ComparableNodeValue) -> bool:
+ def __contains__(self, value: Comparable) -> bool:
"""
Returns:
True if the item is in the tree; False otherwise.
node = ancestor
return None
- def get_nodes_in_range_inclusive(
- self, lower: ComparableNodeValue, upper: ComparableNodeValue
- ):
+ def get_nodes_in_range_inclusive(self, lower: Comparable, upper: Comparable):
"""
>>> t = BinarySearchTree()
>>> t.insert(50)
@total_ordering
-class NumericRange(object):
+class NumericRange(bst.Comparable):
"""Essentially a tuple of numbers denoting a range with some added
helper methods on it."""
"""
if low > high:
- temp: Numeric = low
- low = high
- high = temp
+ low, high = high, low
self.low: Numeric = low
self.high: Numeric = high
self.highest_in_subtree: Numeric = high
+ @overrides
def __lt__(self, other: NumericRange) -> bool:
"""
Returns:
return False
return self.low == other.low and self.high == other.high
+ @overrides
+ def __le__(self, other: object) -> bool:
+ if not isinstance(other, NumericRange):
+ return False
+ return self < other or self == other
+
def overlaps_with(self, other: NumericRange) -> bool:
"""
Returns:
return self.low <= other.high and self.high >= other.low
def __repr__(self) -> str:
- return f"{self.low}..{self.high}"
+ return f"[{self.low}..{self.high}]"
class AugmentedIntervalTree(bst.BinarySearchTree):
@staticmethod
- def _assert_value_must_be_range(value: Any) -> None:
+ def _assert_value_must_be_range(value: Any) -> NumericRange:
if not isinstance(value, NumericRange):
raise Exception(
"AugmentedIntervalTree expects to use NumericRanges, see bst for a "
+ "general purpose tree usable for other types."
)
+ return value
@overrides
def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
- AugmentedIntervalTree._assert_value_must_be_range(new.value)
+ nv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(new.value)
for ancestor in self.parent_path(new):
assert ancestor
- if new.value.high > ancestor.value.highest_in_subtree:
- ancestor.value.highest_in_subtree = new.value.high
+ av: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ ancestor.value
+ )
+ if nv.high > av.highest_in_subtree:
+ av.highest_in_subtree = nv.high
@overrides
def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
if parent:
- new_highest_candidates = [parent.value.high]
+ pv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ parent.value
+ )
+ new_highest_candidates = [pv.high]
if parent.left:
- new_highest_candidates.append(parent.left.value.highest_in_subtree)
+ lv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ parent.left.value
+ )
+ new_highest_candidates.append(lv.highest_in_subtree)
if parent.right:
- new_highest_candidates.append(parent.right.value.highest_in_subtree)
- parent.value.highest_in_subtree = max(new_highest_candidates)
+ rv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+ parent.right.value
+ )
+ new_highest_candidates.append(rv.highest_in_subtree)
+ pv.highest_in_subtree = max(new_highest_candidates)
def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
"""Identify and return one overlapping node from the tree.
>>> tree.insert(NumericRange(16, 28))
>>> tree.insert(NumericRange(21, 27))
>>> tree.find_one_overlap(NumericRange(6, 7))
- 1..30
+ [1..30]
"""
return self._find_one_overlap(self.root, to_find)
if root is None:
return None
- if root.value.overlaps_with(x):
- return root.value
+ rv = AugmentedIntervalTree._assert_value_must_be_range(root.value)
+ if rv.overlaps_with(x):
+ return rv
if root.left:
- if root.left.value.highest_in_subtree >= x.low:
+ lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value)
+ if lv.highest_in_subtree >= x.low:
return self._find_one_overlap(root.left, x)
if root.right:
>>> tree.insert(NumericRange(21, 27))
>>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
... print(x)
- 20..24
- 18..22
- 1..30
- 16..28
- 21..27
+ [20..24]
+ [18..22]
+ [1..30]
+ [16..28]
+ [21..27]
>>> del tree[NumericRange(1, 30)]
>>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
... print(x)
- 20..24
- 18..22
- 16..28
- 21..27
+ [20..24]
+ [18..22]
+ [16..28]
+ [21..27]
"""
if self.root is None:
if root is None:
return None
- if root.value.overlaps_with(x):
- yield root.value
+ rv = AugmentedIntervalTree._assert_value_must_be_range(root.value)
+ if rv.overlaps_with(x):
+ yield rv
if root.left:
- if root.left.value.highest_in_subtree >= x.low:
+ lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value)
+ if lv.highest_in_subtree >= x.low:
yield from self._find_all_overlaps(root.left, x)
if root.right:
- if root.right.value.highest_in_subtree >= x.low:
+ rv = AugmentedIntervalTree._assert_value_must_be_range(root.right.value)
+ if rv.highest_in_subtree >= x.low:
yield from self._find_all_overlaps(root.right, x)