From: Scott Gasch Date: Sat, 6 May 2023 21:14:32 +0000 (-0700) Subject: Improve type hints in bst.py. X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=6ddfe8bdb500437c0f0a025b6f7d389d94e395ed;p=pyutils.git Improve type hints in bst.py. --- diff --git a/src/pyutils/collectionz/bst.py b/src/pyutils/collectionz/bst.py index 5fd36e0..9538708 100644 --- a/src/pyutils/collectionz/bst.py +++ b/src/pyutils/collectionz/bst.py @@ -4,22 +4,41 @@ """A binary search tree implementation.""" -from typing import Any, Generator, List, Optional +from abc import ABCMeta, abstractmethod +from typing import Any, Generator, List, Optional, TypeVar + + +class Comparable(metaclass=ABCMeta): + @abstractmethod + def __lt__(self, other: Any) -> bool: + pass + + @abstractmethod + def __le__(self, other: Any) -> bool: + pass + + @abstractmethod + def __eq__(self, other: Any) -> bool: + pass + + +ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable) class Node(object): - def __init__(self, value: Any) -> None: - """ - A BST node. Note that value can be anything as long as it - is comparable. Check out :meth:`functools.total_ordering` + def __init__(self, value: ComparableNodeValue) -> None: + """A BST node. Note that value can be anything as long as it + is comparable with other instances of itself. Check out + :meth:`functools.total_ordering` (https://docs.python.org/3/library/functools.html#functools.total_ordering) Args: value: a reference to the value of the node. + """ self.left: Optional[Node] = None self.right: Optional[Node] = None - self.value = value + self.value: ComparableNodeValue = value class BinarySearchTree(object): @@ -40,7 +59,7 @@ class BinarySearchTree(object): """This is called immediately _after_ a new node is inserted.""" pass - def insert(self, value: Any) -> None: + def insert(self, value: ComparableNodeValue) -> None: """ Insert something into the tree. @@ -65,7 +84,7 @@ class BinarySearchTree(object): else: self._insert(value, self.root) - def _insert(self, value: Any, node: Node): + def _insert(self, value: ComparableNodeValue, node: Node): """Insertion helper""" if value < node.value: if node.left is not None: @@ -82,7 +101,7 @@ class BinarySearchTree(object): self.count += 1 self._on_insert(node, node.right) - def __getitem__(self, value: Any) -> Optional[Node]: + def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]: """ Find an item in the tree and return its Node. Returns None if the item is not in the tree. @@ -103,7 +122,7 @@ class BinarySearchTree(object): return self._find(value, self.root) return None - def _find(self, value: Any, node: Node) -> Optional[Node]: + def _find(self, value: ComparableNodeValue, node: Node) -> Optional[Node]: """Find helper""" if value == node.value: return node @@ -114,7 +133,7 @@ class BinarySearchTree(object): return None def _find_lowest_value_greater_than_or_equal_to( - self, target: Any, node: Optional[Node] + self, target: ComparableNodeValue, node: Optional[Node] ) -> Optional[Node]: """Find helper that returns the lowest node that is greater than or equal to the target value. Returns None if target is @@ -244,7 +263,7 @@ class BinarySearchTree(object): """ return self._parent_path(self.root, node) - def __delitem__(self, value: Any) -> bool: + def __delitem__(self, value: ComparableNodeValue) -> bool: """ Delete an item from the tree and preserve the BST property. @@ -334,7 +353,9 @@ class BinarySearchTree(object): """This is called just after deleted was deleted from the tree""" pass - def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool: + def _delete( + self, value: ComparableNodeValue, parent: Optional[Node], node: Node + ) -> bool: """Delete helper""" if node.value == value: @@ -414,7 +435,7 @@ class BinarySearchTree(object): """ return self.count - def __contains__(self, value: Any) -> bool: + def __contains__(self, value: ComparableNodeValue) -> bool: """ Returns: True if the item is in the tree; False otherwise. @@ -661,7 +682,9 @@ class BinarySearchTree(object): node = ancestor return None - def get_nodes_in_range_inclusive(self, lower: Any, upper: Any): + def get_nodes_in_range_inclusive( + self, lower: ComparableNodeValue, upper: ComparableNodeValue + ): """ >>> t = BinarySearchTree() >>> t.insert(50)