Improve type hints in bst.py.
authorScott Gasch <[email protected]>
Sat, 6 May 2023 21:14:32 +0000 (14:14 -0700)
committerScott Gasch <[email protected]>
Sat, 6 May 2023 21:14:32 +0000 (14:14 -0700)
src/pyutils/collectionz/bst.py

index 5fd36e0f8172761095606a7c93f34312a81389f1..9538708728af4968dbab9cdf79d89a1c27f96af8 100644 (file)
@@ -4,22 +4,41 @@
 
 """A binary search tree implementation."""
 
-from typing import Any, Generator, List, Optional
+from abc import ABCMeta, abstractmethod
+from typing import Any, Generator, List, Optional, TypeVar
+
+
+class Comparable(metaclass=ABCMeta):
+    @abstractmethod
+    def __lt__(self, other: Any) -> bool:
+        pass
+
+    @abstractmethod
+    def __le__(self, other: Any) -> bool:
+        pass
+
+    @abstractmethod
+    def __eq__(self, other: Any) -> bool:
+        pass
+
+
+ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable)
 
 
 class Node(object):
-    def __init__(self, value: Any) -> None:
-        """
-        A BST node.  Note that value can be anything as long as it
-        is comparable.  Check out :meth:`functools.total_ordering`
+    def __init__(self, value: ComparableNodeValue) -> None:
+        """A BST node.  Note that value can be anything as long as it
+        is comparable with other instances of itself.  Check out
+        :meth:`functools.total_ordering`
         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
 
         Args:
             value: a reference to the value of the node.
+
         """
         self.left: Optional[Node] = None
         self.right: Optional[Node] = None
-        self.value = value
+        self.value: ComparableNodeValue = value
 
 
 class BinarySearchTree(object):
@@ -40,7 +59,7 @@ class BinarySearchTree(object):
         """This is called immediately _after_ a new node is inserted."""
         pass
 
-    def insert(self, value: Any) -> None:
+    def insert(self, value: ComparableNodeValue) -> None:
         """
         Insert something into the tree.
 
@@ -65,7 +84,7 @@ class BinarySearchTree(object):
         else:
             self._insert(value, self.root)
 
-    def _insert(self, value: Any, node: Node):
+    def _insert(self, value: ComparableNodeValue, node: Node):
         """Insertion helper"""
         if value < node.value:
             if node.left is not None:
@@ -82,7 +101,7 @@ class BinarySearchTree(object):
                 self.count += 1
                 self._on_insert(node, node.right)
 
-    def __getitem__(self, value: Any) -> Optional[Node]:
+    def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
         """
         Find an item in the tree and return its Node.  Returns
         None if the item is not in the tree.
@@ -103,7 +122,7 @@ class BinarySearchTree(object):
             return self._find(value, self.root)
         return None
 
-    def _find(self, value: Any, node: Node) -> Optional[Node]:
+    def _find(self, value: ComparableNodeValue, node: Node) -> Optional[Node]:
         """Find helper"""
         if value == node.value:
             return node
@@ -114,7 +133,7 @@ class BinarySearchTree(object):
         return None
 
     def _find_lowest_value_greater_than_or_equal_to(
-        self, target: Any, node: Optional[Node]
+        self, target: ComparableNodeValue, node: Optional[Node]
     ) -> Optional[Node]:
         """Find helper that returns the lowest node that is greater
         than or equal to the target value.  Returns None if target is
@@ -244,7 +263,7 @@ class BinarySearchTree(object):
         """
         return self._parent_path(self.root, node)
 
-    def __delitem__(self, value: Any) -> bool:
+    def __delitem__(self, value: ComparableNodeValue) -> bool:
         """
         Delete an item from the tree and preserve the BST property.
 
@@ -334,7 +353,9 @@ class BinarySearchTree(object):
         """This is called just after deleted was deleted from the tree"""
         pass
 
-    def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
+    def _delete(
+        self, value: ComparableNodeValue, parent: Optional[Node], node: Node
+    ) -> bool:
         """Delete helper"""
         if node.value == value:
 
@@ -414,7 +435,7 @@ class BinarySearchTree(object):
         """
         return self.count
 
-    def __contains__(self, value: Any) -> bool:
+    def __contains__(self, value: ComparableNodeValue) -> bool:
         """
         Returns:
             True if the item is in the tree; False otherwise.
@@ -661,7 +682,9 @@ class BinarySearchTree(object):
             node = ancestor
         return None
 
-    def get_nodes_in_range_inclusive(self, lower: Any, upper: Any):
+    def get_nodes_in_range_inclusive(
+        self, lower: ComparableNodeValue, upper: ComparableNodeValue
+    ):
         """
         >>> t = BinarySearchTree()
         >>> t.insert(50)