Use Protocol to implement the interface typevar here instead.
authorScott Gasch <[email protected]>
Mon, 8 May 2023 02:43:26 +0000 (19:43 -0700)
committerScott Gasch <[email protected]>
Mon, 8 May 2023 02:43:26 +0000 (19:43 -0700)
src/pyutils/collectionz/bst.py
src/pyutils/collectionz/interval_tree.py

index cefbf59ba12c4050a0c63d6601ba941e265a8543..246b605c102b123a55b69a80c5f1a6685cb53271 100644 (file)
@@ -4,11 +4,11 @@
 
 """A binary search tree implementation."""
 
-from abc import ABCMeta, abstractmethod
-from typing import Any, Generator, List, Optional, TypeVar
+from abc import abstractmethod
+from typing import Any, Generator, List, Optional, Protocol
 
 
-class Comparable(metaclass=ABCMeta):
+class Comparable(Protocol):
     @abstractmethod
     def __lt__(self, other: Any) -> bool:
         pass
@@ -22,11 +22,8 @@ class Comparable(metaclass=ABCMeta):
         pass
 
 
-ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable)
-
-
-class Node(object):
-    def __init__(self, value: ComparableNodeValue) -> None:
+class Node:
+    def __init__(self, value: Comparable) -> None:
         """A BST node.  Note that value can be anything as long as it
         is comparable with other instances of itself.  Check out
         :meth:`functools.total_ordering`
@@ -38,7 +35,7 @@ class Node(object):
         """
         self.left: Optional[Node] = None
         self.right: Optional[Node] = None
-        self.value: ComparableNodeValue = value
+        self.value: Comparable = value
 
 
 class BinarySearchTree(object):
@@ -59,7 +56,7 @@ class BinarySearchTree(object):
         """This is called immediately _after_ a new node is inserted."""
         pass
 
-    def insert(self, value: ComparableNodeValue) -> None:
+    def insert(self, value: Comparable) -> None:
         """
         Insert something into the tree.
 
@@ -84,7 +81,7 @@ class BinarySearchTree(object):
         else:
             self._insert(value, self.root)
 
-    def _insert(self, value: ComparableNodeValue, node: Node):
+    def _insert(self, value: Comparable, node: Node):
         """Insertion helper"""
         if value < node.value:
             if node.left is not None:
@@ -101,7 +98,7 @@ class BinarySearchTree(object):
                 self.count += 1
                 self._on_insert(node, node.right)
 
-    def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
+    def __getitem__(self, value: Comparable) -> Optional[Node]:
         """
         Find an item in the tree and return its Node.  Returns
         None if the item is not in the tree.
@@ -122,7 +119,7 @@ class BinarySearchTree(object):
             return self._find_exact(value, self.root)
         return None
 
-    def _find_exact(self, target: ComparableNodeValue, node: Node) -> Optional[Node]:
+    def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
         """Recursively traverse the tree looking for a node with the
         target value.  Return that node if it exists, otherwise return
         None."""
@@ -136,7 +133,7 @@ class BinarySearchTree(object):
         return None
 
     def _find_lowest_node_less_than_or_equal_to(
-        self, target: ComparableNodeValue, node: Optional[Node]
+        self, target: Comparable, node: Optional[Node]
     ) -> Optional[Node]:
         """Find helper that returns the lowest node that is less
         than or equal to the target value.  Returns None if target is
@@ -196,7 +193,7 @@ class BinarySearchTree(object):
             return self._find_lowest_node_less_than_or_equal_to(target, node.left)
 
     def _find_lowest_node_greater_than_or_equal_to(
-        self, target: ComparableNodeValue, node: Optional[Node]
+        self, target: Comparable, node: Optional[Node]
     ) -> Optional[Node]:
         """Find helper that returns the lowest node that is greater
         than or equal to the target value.  Returns None if target is
@@ -329,7 +326,7 @@ class BinarySearchTree(object):
         """
         return self._parent_path(self.root, node)
 
-    def __delitem__(self, value: ComparableNodeValue) -> bool:
+    def __delitem__(self, value: Comparable) -> bool:
         """
         Delete an item from the tree and preserve the BST property.
 
@@ -419,9 +416,7 @@ class BinarySearchTree(object):
         """This is called just after deleted was deleted from the tree"""
         pass
 
-    def _delete(
-        self, value: ComparableNodeValue, parent: Optional[Node], node: Node
-    ) -> bool:
+    def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
         """Delete helper"""
         if node.value == value:
 
@@ -501,7 +496,7 @@ class BinarySearchTree(object):
         """
         return self.count
 
-    def __contains__(self, value: ComparableNodeValue) -> bool:
+    def __contains__(self, value: Comparable) -> bool:
         """
         Returns:
             True if the item is in the tree; False otherwise.
@@ -748,9 +743,7 @@ class BinarySearchTree(object):
             node = ancestor
         return None
 
-    def get_nodes_in_range_inclusive(
-        self, lower: ComparableNodeValue, upper: ComparableNodeValue
-    ):
+    def get_nodes_in_range_inclusive(self, lower: Comparable, upper: Comparable):
         """
         >>> t = BinarySearchTree()
         >>> t.insert(50)
index 0a88a3d844b2ba71b84f0e2a244cd56caa17cd9b..9542e2161bbfa6405866c0e76270d71130513a1c 100644 (file)
@@ -18,7 +18,7 @@ from pyutils.typez.simple import Numeric
 
 
 @total_ordering
-class NumericRange(object):
+class NumericRange(bst.Comparable):
     """Essentially a tuple of numbers denoting a range with some added
     helper methods on it."""
 
@@ -36,13 +36,12 @@ class NumericRange(object):
 
         """
         if low > high:
-            temp: Numeric = low
-            low = high
-            high = temp
+            low, high = high, low
         self.low: Numeric = low
         self.high: Numeric = high
         self.highest_in_subtree: Numeric = high
 
+    @overrides
     def __lt__(self, other: NumericRange) -> bool:
         """
         Returns:
@@ -63,6 +62,12 @@ class NumericRange(object):
             return False
         return self.low == other.low and self.high == other.high
 
+    @overrides
+    def __le__(self, other: object) -> bool:
+        if not isinstance(other, NumericRange):
+            return False
+        return self < other or self == other
+
     def overlaps_with(self, other: NumericRange) -> bool:
         """
         Returns:
@@ -71,35 +76,48 @@ class NumericRange(object):
         return self.low <= other.high and self.high >= other.low
 
     def __repr__(self) -> str:
-        return f"{self.low}..{self.high}"
+        return f"[{self.low}..{self.high}]"
 
 
 class AugmentedIntervalTree(bst.BinarySearchTree):
     @staticmethod
-    def _assert_value_must_be_range(value: Any) -> None:
+    def _assert_value_must_be_range(value: Any) -> NumericRange:
         if not isinstance(value, NumericRange):
             raise Exception(
                 "AugmentedIntervalTree expects to use NumericRanges, see bst for a "
                 + "general purpose tree usable for other types."
             )
+        return value
 
     @overrides
     def _on_insert(self, parent: Optional[bst.Node], new: bst.Node) -> None:
-        AugmentedIntervalTree._assert_value_must_be_range(new.value)
+        nv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(new.value)
         for ancestor in self.parent_path(new):
             assert ancestor
-            if new.value.high > ancestor.value.highest_in_subtree:
-                ancestor.value.highest_in_subtree = new.value.high
+            av: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+                ancestor.value
+            )
+            if nv.high > av.highest_in_subtree:
+                av.highest_in_subtree = nv.high
 
     @overrides
     def _on_delete(self, parent: Optional[bst.Node], deleted: bst.Node) -> None:
         if parent:
-            new_highest_candidates = [parent.value.high]
+            pv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+                parent.value
+            )
+            new_highest_candidates = [pv.high]
             if parent.left:
-                new_highest_candidates.append(parent.left.value.highest_in_subtree)
+                lv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+                    parent.left.value
+                )
+                new_highest_candidates.append(lv.highest_in_subtree)
             if parent.right:
-                new_highest_candidates.append(parent.right.value.highest_in_subtree)
-            parent.value.highest_in_subtree = max(new_highest_candidates)
+                rv: NumericRange = AugmentedIntervalTree._assert_value_must_be_range(
+                    parent.right.value
+                )
+                new_highest_candidates.append(rv.highest_in_subtree)
+            pv.highest_in_subtree = max(new_highest_candidates)
 
     def find_one_overlap(self, to_find: NumericRange) -> Optional[NumericRange]:
         """Identify and return one overlapping node from the tree.
@@ -124,7 +142,7 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
         >>> tree.insert(NumericRange(16, 28))
         >>> tree.insert(NumericRange(21, 27))
         >>> tree.find_one_overlap(NumericRange(6, 7))
-        1..30
+        [1..30]
 
         """
         return self._find_one_overlap(self.root, to_find)
@@ -135,11 +153,13 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
         if root is None:
             return None
 
-        if root.value.overlaps_with(x):
-            return root.value
+        rv = AugmentedIntervalTree._assert_value_must_be_range(root.value)
+        if rv.overlaps_with(x):
+            return rv
 
         if root.left:
-            if root.left.value.highest_in_subtree >= x.low:
+            lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value)
+            if lv.highest_in_subtree >= x.low:
                 return self._find_one_overlap(root.left, x)
 
         if root.right:
@@ -173,19 +193,19 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
         >>> tree.insert(NumericRange(21, 27))
         >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
         ...     print(x)
-        20..24
-        18..22
-        1..30
-        16..28
-        21..27
+        [20..24]
+        [18..22]
+        [1..30]
+        [16..28]
+        [21..27]
 
         >>> del tree[NumericRange(1, 30)]
         >>> for x in tree.find_all_overlaps(NumericRange(19, 21)):
         ...     print(x)
-        20..24
-        18..22
-        16..28
-        21..27
+        [20..24]
+        [18..22]
+        [16..28]
+        [21..27]
 
         """
         if self.root is None:
@@ -198,15 +218,18 @@ class AugmentedIntervalTree(bst.BinarySearchTree):
         if root is None:
             return None
 
-        if root.value.overlaps_with(x):
-            yield root.value
+        rv = AugmentedIntervalTree._assert_value_must_be_range(root.value)
+        if rv.overlaps_with(x):
+            yield rv
 
         if root.left:
-            if root.left.value.highest_in_subtree >= x.low:
+            lv = AugmentedIntervalTree._assert_value_must_be_range(root.left.value)
+            if lv.highest_in_subtree >= x.low:
                 yield from self._find_all_overlaps(root.left, x)
 
         if root.right:
-            if root.right.value.highest_in_subtree >= x.low:
+            rv = AugmentedIntervalTree._assert_value_must_be_range(root.right.value)
+            if rv.highest_in_subtree >= x.low:
                 yield from self._find_all_overlaps(root.right, x)