From 60cf08018e8ddfee419138b2a39e4e4ef98c743a Mon Sep 17 00:00:00 2001 From: Scott Gasch Date: Sun, 7 May 2023 19:43:26 -0700 Subject: [PATCH] Use Protocol to implement the interface typevar here instead. --- src/pyutils/collectionz/bst.py | 39 +++++------ src/pyutils/collectionz/interval_tree.py | 83 +++++++++++++++--------- 2 files changed, 69 insertions(+), 53 deletions(-) diff --git a/src/pyutils/collectionz/bst.py b/src/pyutils/collectionz/bst.py index cefbf59..246b605 100644 --- a/src/pyutils/collectionz/bst.py +++ b/src/pyutils/collectionz/bst.py @@ -4,11 +4,11 @@ """A binary search tree implementation.""" -from abc import ABCMeta, abstractmethod -from typing import Any, Generator, List, Optional, TypeVar +from abc import abstractmethod +from typing import Any, Generator, List, Optional, Protocol -class Comparable(metaclass=ABCMeta): +class Comparable(Protocol): @abstractmethod def __lt__(self, other: Any) -> bool: pass @@ -22,11 +22,8 @@ class Comparable(metaclass=ABCMeta): pass -ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable) - - -class Node(object): - def __init__(self, value: ComparableNodeValue) -> None: +class Node: + def __init__(self, value: Comparable) -> None: """A BST node. Note that value can be anything as long as it is comparable with other instances of itself. Check out :meth:`functools.total_ordering` @@ -38,7 +35,7 @@ class Node(object): """ self.left: Optional[Node] = None self.right: Optional[Node] = None - self.value: ComparableNodeValue = value + self.value: Comparable = value class BinarySearchTree(object): @@ -59,7 +56,7 @@ class BinarySearchTree(object): """This is called immediately _after_ a new node is inserted.""" pass - def insert(self, value: ComparableNodeValue) -> None: + def insert(self, value: Comparable) -> None: """ Insert something into the tree. @@ -84,7 +81,7 @@ class BinarySearchTree(object): else: self._insert(value, self.root) - def _insert(self, value: ComparableNodeValue, node: Node): + def _insert(self, value: Comparable, node: Node): """Insertion helper""" if value < node.value: if node.left is not None: @@ -101,7 +98,7 @@ class BinarySearchTree(object): self.count += 1 self._on_insert(node, node.right) - def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]: + def __getitem__(self, value: Comparable) -> Optional[Node]: """ Find an item in the tree and return its Node. Returns None if the item is not in the tree. @@ -122,7 +119,7 @@ class BinarySearchTree(object): return self._find_exact(value, self.root) return None - def _find_exact(self, target: ComparableNodeValue, node: Node) -> Optional[Node]: + def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]: """Recursively traverse the tree looking for a node with the target value. Return that node if it exists, otherwise return None.""" @@ -136,7 +133,7 @@ class BinarySearchTree(object): return None def _find_lowest_node_less_than_or_equal_to( - self, target: ComparableNodeValue, node: Optional[Node] + self, target: Comparable, node: Optional[Node] ) -> Optional[Node]: """Find helper that returns the lowest node that is less than or equal to the target value. Returns None if target is @@ -196,7 +193,7 @@ class BinarySearchTree(object): return self._find_lowest_node_less_than_or_equal_to(target, node.left) def _find_lowest_node_greater_than_or_equal_to( - self, target: ComparableNodeValue, node: Optional[Node] + self, target: Comparable, node: Optional[Node] ) -> Optional[Node]: """Find helper that returns the lowest node that is greater than or equal to the target value. Returns None if target is @@ -329,7 +326,7 @@ class BinarySearchTree(object): """ return self._parent_path(self.root, node) - def __delitem__(self, value: ComparableNodeValue) -> bool: + def __delitem__(self, value: Comparable) -> bool: """ Delete an item from the tree and preserve the BST property. @@ -419,9 +416,7 @@ class BinarySearchTree(object): """This is called just after deleted was deleted from the tree""" pass - def _delete( - self, value: ComparableNodeValue, parent: Optional[Node], node: Node - ) -> bool: + def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool: """Delete helper""" if node.value == value: @@ -501,7 +496,7 @@ class BinarySearchTree(object): """ return self.count - def __contains__(self, value: ComparableNodeValue) -> bool: + def __contains__(self, value: Comparable) -> bool: """ Returns: True if the item is in the tree; False otherwise. @@ -748,9 +743,7 @@ class BinarySearchTree(object): node = ancestor return None - def get_nodes_in_range_inclusive( - self, lower: ComparableNodeValue, upper: ComparableNodeValue - ): + def get_nodes_in_range_inclusive(self, lower: Comparable, upper: Comparable): """ >>> t = BinarySearchTree() >>> t.insert(50) diff --git a/src/pyutils/collectionz/interval_tree.py b/src/pyutils/collectionz/interval_tree.py index 0a88a3d..9542e21 100644 --- a/src/pyutils/collectionz/interval_tree.py +++ b/src/pyutils/collectionz/interval_tree.py @@ -18,7 +18,7 @@ from pyutils.typez.simple import Numeric @total_ordering -class NumericRange(object): +class NumericRange(bst.Comparable): """Essentially a tuple of numbers denoting a range with some added helper methods on it.""" @@ -36,13 +36,12 @@ class NumericRange(object): """ if low > high: - temp: Numeric = low - low = high - high = temp + low, high = high, low self.low: Numeric = low self.high: Numeric = high self.highest_in_subtree: Numeric = high + @overrides def __lt__(self, other: NumericRange) -> bool: """ Returns: @@ -63,6 +62,12 @@ class NumericRange(object): return False return self.low == other.low and self.high == other.high + @overrides + def __le__(self, other: object) -> bool: + if not isinstance(other, NumericRange): + return False + return self < other or self == other + def overlaps_with(self, other: NumericRange) -> bool: """ Returns: @@ -71,35 +76,48 @@ class NumericRange(object): return self.low <= other.high and self.high >= other.low def __repr__(self) -> str: - return f"{self.low}..{self.high}" + return f"[{self.low}..{self.high}]" class AugmentedIntervalTree(bst.BinarySearchTree): @staticmethod - def _assert_value_must_be_range(value: Any) -> None: + def _assert_value_must_be_range(value: Any) -> NumericRange: if not isinstance(value, NumericRange): raise Exception( "AugmentedIntervalTree expects to use NumericRanges, see bst for a " + "general purpose tree usable for other types." ) + return value @overrides def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None: - AugmentedIntervalTree._assert_value_must_be_range(new.value) + nv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(new.value) for ancestor in self.parent_path(new): assert ancestor - if new.value.high > ancestor.value.highest_in_subtree: - ancestor.value.highest_in_subtree = new.value.high + av: NumericRange = AugmentedIntervalTree._assert_value_must_be_range( + ancestor.value + ) + if nv.high > av.highest_in_subtree: + av.highest_in_subtree = nv.high @overrides def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None: if parent: - new_highest_candidates = [parent.value.high] + pv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range( + parent.value + ) + new_highest_candidates = [pv.high] if parent.left: - new_highest_candidates.append(parent.left.value.highest_in_subtree) + lv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range( + parent.left.value + ) + new_highest_candidates.append(lv.highest_in_subtree) if parent.right: - new_highest_candidates.append(parent.right.value.highest_in_subtree) - parent.value.highest_in_subtree = max(new_highest_candidates) + rv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range( + parent.right.value + ) + new_highest_candidates.append(rv.highest_in_subtree) + pv.highest_in_subtree = max(new_highest_candidates) def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]: """Identify and return one overlapping node from the tree. @@ -124,7 +142,7 @@ class AugmentedIntervalTree(bst.BinarySearchTree): >>> tree.insert(NumericRange(16, 28)) >>> tree.insert(NumericRange(21, 27)) >>> tree.find_one_overlap(NumericRange(6, 7)) - 1..30 + [1..30] """ return self._find_one_overlap(self.root, to_find) @@ -135,11 +153,13 @@ class AugmentedIntervalTree(bst.BinarySearchTree): if root is None: return None - if root.value.overlaps_with(x): - return root.value + rv = AugmentedIntervalTree._assert_value_must_be_range(root.value) + if rv.overlaps_with(x): + return rv if root.left: - if root.left.value.highest_in_subtree >= x.low: + lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value) + if lv.highest_in_subtree >= x.low: return self._find_one_overlap(root.left, x) if root.right: @@ -173,19 +193,19 @@ class AugmentedIntervalTree(bst.BinarySearchTree): >>> tree.insert(NumericRange(21, 27)) >>> for x in tree.find_all_overlaps(NumericRange(19, 21)): ... print(x) - 20..24 - 18..22 - 1..30 - 16..28 - 21..27 + [20..24] + [18..22] + [1..30] + [16..28] + [21..27] >>> del tree[NumericRange(1, 30)] >>> for x in tree.find_all_overlaps(NumericRange(19, 21)): ... print(x) - 20..24 - 18..22 - 16..28 - 21..27 + [20..24] + [18..22] + [16..28] + [21..27] """ if self.root is None: @@ -198,15 +218,18 @@ class AugmentedIntervalTree(bst.BinarySearchTree): if root is None: return None - if root.value.overlaps_with(x): - yield root.value + rv = AugmentedIntervalTree._assert_value_must_be_range(root.value) + if rv.overlaps_with(x): + yield rv if root.left: - if root.left.value.highest_in_subtree >= x.low: + lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value) + if lv.highest_in_subtree >= x.low: yield from self._find_all_overlaps(root.left, x) if root.right: - if root.right.value.highest_in_subtree >= x.low: + rv = AugmentedIntervalTree._assert_value_must_be_range(root.right.value) + if rv.highest_in_subtree >= x.low: yield from self._find_all_overlaps(root.right, x) -- 2.47.1