Improve type hints in bst.py.
authorScott Gasch <[email protected]>
Sat, 6 May 2023 21:14:32 +0000 (14:14 -0700)
committerScott Gasch <[email protected]>
Sat, 6 May 2023 21:14:32 +0000 (14:14 -0700)
src/pyutils/collectionz/bst.py

index 5fd36e0f8172761095606a7c93f34312a81389f1..9538708728af4968dbab9cdf79d89a1c27f96af8 100644 (file)
@@ -4,22 +4,41 @@
 
 """A binary search tree implementation."""
 
 
 """A binary search tree implementation."""
 
-from typing import Any, Generator, List, Optional
+from abc import ABCMeta, abstractmethod
+from typing import Any, Generator, List, Optional, TypeVar
+
+
+class Comparable(metaclass=ABCMeta):
+    @abstractmethod
+    def __lt__(self, other: Any) -> bool:
+        pass
+
+    @abstractmethod
+    def __le__(self, other: Any) -> bool:
+        pass
+
+    @abstractmethod
+    def __eq__(self, other: Any) -> bool:
+        pass
+
+
+ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable)
 
 
 class Node(object):
 
 
 class Node(object):
-    def __init__(self, value: Any) -> None:
-        """
-        A BST node.  Note that value can be anything as long as it
-        is comparable.  Check out :meth:`functools.total_ordering`
+    def __init__(self, value: ComparableNodeValue) -> None:
+        """A BST node.  Note that value can be anything as long as it
+        is comparable with other instances of itself.  Check out
+        :meth:`functools.total_ordering`
         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
 
         Args:
             value: a reference to the value of the node.
         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
 
         Args:
             value: a reference to the value of the node.
+
         """
         self.left: Optional[Node] = None
         self.right: Optional[Node] = None
         """
         self.left: Optional[Node] = None
         self.right: Optional[Node] = None
-        self.value = value
+        self.value: ComparableNodeValue = value
 
 
 class BinarySearchTree(object):
 
 
 class BinarySearchTree(object):
@@ -40,7 +59,7 @@ class BinarySearchTree(object):
         """This is called immediately _after_ a new node is inserted."""
         pass
 
         """This is called immediately _after_ a new node is inserted."""
         pass
 
-    def insert(self, value: Any) -> None:
+    def insert(self, value: ComparableNodeValue) -> None:
         """
         Insert something into the tree.
 
         """
         Insert something into the tree.
 
@@ -65,7 +84,7 @@ class BinarySearchTree(object):
         else:
             self._insert(value, self.root)
 
         else:
             self._insert(value, self.root)
 
-    def _insert(self, value: Any, node: Node):
+    def _insert(self, value: ComparableNodeValue, node: Node):
         """Insertion helper"""
         if value < node.value:
             if node.left is not None:
         """Insertion helper"""
         if value < node.value:
             if node.left is not None:
@@ -82,7 +101,7 @@ class BinarySearchTree(object):
                 self.count += 1
                 self._on_insert(node, node.right)
 
                 self.count += 1
                 self._on_insert(node, node.right)
 
-    def __getitem__(self, value: Any) -> Optional[Node]:
+    def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
         """
         Find an item in the tree and return its Node.  Returns
         None if the item is not in the tree.
         """
         Find an item in the tree and return its Node.  Returns
         None if the item is not in the tree.
@@ -103,7 +122,7 @@ class BinarySearchTree(object):
             return self._find(value, self.root)
         return None
 
             return self._find(value, self.root)
         return None
 
-    def _find(self, value: Any, node: Node) -> Optional[Node]:
+    def _find(self, value: ComparableNodeValue, node: Node) -> Optional[Node]:
         """Find helper"""
         if value == node.value:
             return node
         """Find helper"""
         if value == node.value:
             return node
@@ -114,7 +133,7 @@ class BinarySearchTree(object):
         return None
 
     def _find_lowest_value_greater_than_or_equal_to(
         return None
 
     def _find_lowest_value_greater_than_or_equal_to(
-        self, target: Any, node: Optional[Node]
+        self, target: ComparableNodeValue, node: Optional[Node]
     ) -> Optional[Node]:
         """Find helper that returns the lowest node that is greater
         than or equal to the target value.  Returns None if target is
     ) -> Optional[Node]:
         """Find helper that returns the lowest node that is greater
         than or equal to the target value.  Returns None if target is
@@ -244,7 +263,7 @@ class BinarySearchTree(object):
         """
         return self._parent_path(self.root, node)
 
         """
         return self._parent_path(self.root, node)
 
-    def __delitem__(self, value: Any) -> bool:
+    def __delitem__(self, value: ComparableNodeValue) -> bool:
         """
         Delete an item from the tree and preserve the BST property.
 
         """
         Delete an item from the tree and preserve the BST property.
 
@@ -334,7 +353,9 @@ class BinarySearchTree(object):
         """This is called just after deleted was deleted from the tree"""
         pass
 
         """This is called just after deleted was deleted from the tree"""
         pass
 
-    def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
+    def _delete(
+        self, value: ComparableNodeValue, parent: Optional[Node], node: Node
+    ) -> bool:
         """Delete helper"""
         if node.value == value:
 
         """Delete helper"""
         if node.value == value:
 
@@ -414,7 +435,7 @@ class BinarySearchTree(object):
         """
         return self.count
 
         """
         return self.count
 
-    def __contains__(self, value: Any) -> bool:
+    def __contains__(self, value: ComparableNodeValue) -> bool:
         """
         Returns:
             True if the item is in the tree; False otherwise.
         """
         Returns:
             True if the item is in the tree; False otherwise.
@@ -661,7 +682,9 @@ class BinarySearchTree(object):
             node = ancestor
         return None
 
             node = ancestor
         return None
 
-    def get_nodes_in_range_inclusive(self, lower: Any, upper: Any):
+    def get_nodes_in_range_inclusive(
+        self, lower: ComparableNodeValue, upper: ComparableNodeValue
+    ):
         """
         >>> t = BinarySearchTree()
         >>> t.insert(50)
         """
         >>> t = BinarySearchTree()
         >>> t.insert(50)