Adds a __repr__ to graph.
[pyutils.git] / src / pyutils / collectionz / bst.py
index f492dfdab7aa8cea4c8a08a1592bbd93e05aadce..74c328da05f9729034362a74b91ff35d5f14d68e 100644 (file)
@@ -4,22 +4,25 @@
 
 """A binary search tree implementation."""
 
-from typing import Any, Generator, List, Optional
+from typing import Generator, List, Optional
 
+from pyutils.typez.typing import Comparable
 
-class Node(object):
-    def __init__(self, value: Any) -> None:
-        """
-        A BST node.  Note that value can be anything as long as it
-        is comparable.  Check out :meth:`functools.total_ordering`
-        (https://docs.python.org/3/library/functools.html#functools.total_ordering)
+
+class Node:
+    def __init__(self, value: Comparable) -> None:
+        """A BST node.  Just a left and right reference along with a
+        value.  Note that value can be anything as long as it
+        is :class:`Comparable` with other instances of itself.
 
         Args:
-            value: a reference to the value of the node.
+            value: a reference to the value of the node.  Must be
+                :class:`Comparable` to other values.
+
         """
         self.left: Optional[Node] = None
         self.right: Optional[Node] = None
-        self.value = value
+        self.value: Comparable = value
 
 
 class BinarySearchTree(object):
@@ -40,9 +43,9 @@ class BinarySearchTree(object):
         """This is called immediately _after_ a new node is inserted."""
         pass
 
-    def insert(self, value: Any) -> None:
+    def insert(self, value: Comparable) -> None:
         """
-        Insert something into the tree.
+        Insert something into the tree in :math:`O(log_2 n)` time.
 
         Args:
             value: the value to be inserted.
@@ -65,7 +68,7 @@ class BinarySearchTree(object):
         else:
             self._insert(value, self.root)
 
-    def _insert(self, value: Any, node: Node):
+    def _insert(self, value: Comparable, node: Node):
         """Insertion helper"""
         if value < node.value:
             if node.left is not None:
@@ -82,10 +85,11 @@ class BinarySearchTree(object):
                 self.count += 1
                 self._on_insert(node, node.right)
 
-    def __getitem__(self, value: Any) -> Optional[Node]:
+    def __getitem__(self, value: Comparable) -> Optional[Node]:
         """
-        Find an item in the tree and return its Node.  Returns
-        None if the item is not in the tree.
+        Find an item in the tree and return its Node in
+        :math:`O(log_2 n)` time.  Returns None if the item is not in
+        the tree.
 
         >>> t = BinarySearchTree()
         >>> t[99]
@@ -100,19 +104,144 @@ class BinarySearchTree(object):
 
         """
         if self.root is not None:
-            return self._find(value, self.root)
+            return self._find_exact(value, self.root)
         return None
 
-    def _find(self, value: Any, node: Node) -> Optional[Node]:
-        """Find helper"""
-        if value == node.value:
+    def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
+        """Recursively traverse the tree looking for a node with the
+        target value.  Return that node if it exists, otherwise return
+        None."""
+
+        if target == node.value:
             return node
-        elif value < node.value and node.left is not None:
-            return self._find(value, node.left)
-        elif value > node.value and node.right is not None:
-            return self._find(value, node.right)
+        elif target < node.value and node.left is not None:
+            return self._find_exact(target, node.left)
+        elif target > node.value and node.right is not None:
+            return self._find_exact(target, node.right)
         return None
 
+    def _find_lowest_node_less_than_or_equal_to(
+        self, target: Comparable, node: Optional[Node]
+    ) -> Optional[Node]:
+        """Find helper that returns the lowest node that is less
+        than or equal to the target value.  Returns None if target is
+        lower than the lowest node in the tree.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+        >>> t.insert(85)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     └──13
+        └──75
+           ├──66
+           └──85
+
+        >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
+        25
+        >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
+        50
+        >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
+        85
+        >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
+        22
+        >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
+        13
+        >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
+        66
+        >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
+        75
+        >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
+        True
+
+        """
+
+        if not node:
+            return None
+
+        if target == node.value:
+            return node
+
+        elif target > node.value:
+            if below := self._find_lowest_node_less_than_or_equal_to(
+                target, node.right
+            ):
+                return below
+            else:
+                return node
+
+        else:
+            return self._find_lowest_node_less_than_or_equal_to(target, node.left)
+
+    def _find_lowest_node_greater_than_or_equal_to(
+        self, target: Comparable, node: Optional[Node]
+    ) -> Optional[Node]:
+        """Find helper that returns the lowest node that is greater
+        than or equal to the target value.  Returns None if target is
+        higher than the greatest node in the tree.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+        >>> t.insert(85)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     └──13
+        └──75
+           ├──66
+           └──85
+
+        >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
+        50
+        >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
+        66
+        >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
+        13
+        >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
+        25
+        >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
+        22
+        >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
+        75
+        >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
+        85
+        >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
+        True
+
+        """
+
+        if not node:
+            return None
+
+        if target == node.value:
+            return node
+
+        elif target > node.value:
+            return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
+
+        # If target < this node's value, either this node is the
+        # answer or the answer is in this node's left subtree.
+        else:
+            if below := self._find_lowest_node_greater_than_or_equal_to(
+                target, node.left
+            ):
+                return below
+            else:
+                return node
+
     def _parent_path(
         self, current: Optional[Node], target: Node
     ) -> List[Optional[Node]]:
@@ -131,14 +260,14 @@ class BinarySearchTree(object):
             return ret
 
     def parent_path(self, node: Node) -> List[Optional[Node]]:
-        """Get a node's parent path.
+        """Get a node's parent path in :math:`O(log_2 n)` time.
 
         Args:
-            node: the node to check
+            node: the node whose parent path should be returned.
 
         Returns:
             a list of nodes representing the path from
-            the tree's root to the node.
+            the tree's root to the given node.
 
         .. note::
 
@@ -185,9 +314,10 @@ class BinarySearchTree(object):
         """
         return self._parent_path(self.root, node)
 
-    def __delitem__(self, value: Any) -> bool:
+    def __delitem__(self, value: Comparable) -> bool:
         """
-        Delete an item from the tree and preserve the BST property.
+        Delete an item from the tree and preserve the BST property in
+        :math:`O(log_2 n) time`.
 
         Args:
             value: the value of the node to be deleted.
@@ -258,6 +388,9 @@ class BinarySearchTree(object):
         └──85
            └──66
 
+        >>> t.__delitem__(85)
+        True
+
         >>> t.__delitem__(99)
         False
 
@@ -275,7 +408,7 @@ class BinarySearchTree(object):
         """This is called just after deleted was deleted from the tree"""
         pass
 
-    def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
+    def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
         """Delete helper"""
         if node.value == value:
 
@@ -314,14 +447,18 @@ class BinarySearchTree(object):
                 self._on_delete(parent, node)
                 return True
 
-            # Node has both a left and right.
+            # Node has both a left and right; get the successor node
+            # to this one and put it here (deleting the successor's
+            # old node).  Because these operations are happening only
+            # in the subtree underneath of node, I'm still calling
+            # this delete an O(log_2 n) operation in the docs.
             else:
                 assert node.left is not None and node.right is not None
-                descendent = node.right
-                while descendent.left is not None:
-                    descendent = descendent.left
-                node.value = descendent.value
+                successor = self.get_next_node(node)
+                assert successor is not None
+                node.value = successor.value
                 return self._delete(node.value, node, node.right)
+
         elif value < node.value and node.left is not None:
             return self._delete(value, node, node.left)
         elif value > node.value and node.right is not None:
@@ -331,7 +468,7 @@ class BinarySearchTree(object):
     def __len__(self):
         """
         Returns:
-            The count of items in the tree.
+            The count of items in the tree in :math:`O(1)` time.
 
         >>> t = BinarySearchTree()
         >>> len(t)
@@ -355,7 +492,7 @@ class BinarySearchTree(object):
         """
         return self.count
 
-    def __contains__(self, value: Any) -> bool:
+    def __contains__(self, value: Comparable) -> bool:
         """
         Returns:
             True if the item is in the tree; False otherwise.
@@ -485,7 +622,7 @@ class BinarySearchTree(object):
     def iterate_leaves(self):
         """
         Returns:
-            A Gemerator that yielde only the leaf nodes in the
+            A Generator that yields only the leaf nodes in the
             tree.
 
         >>> t = BinarySearchTree()
@@ -545,13 +682,14 @@ class BinarySearchTree(object):
         if self.root is not None:
             yield from self._iterate_by_depth(self.root, depth)
 
-    def get_next_node(self, node: Node) -> Node:
+    def get_next_node(self, node: Node) -> Optional[Node]:
         """
         Args:
             node: the node whose next greater successor is desired
 
         Returns:
             Given a tree node, returns the next greater node in the tree.
+            If the given node is the greatest node in the tree, returns None.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -578,6 +716,10 @@ class BinarySearchTree(object):
         >>> t.get_next_node(n).value
         66
 
+        >>> n = t[75]
+        >>> t.get_next_node(n) is None
+        True
+
         """
         if node.right is not None:
             x = node.right
@@ -595,7 +737,51 @@ class BinarySearchTree(object):
             if node != ancestor.right:
                 return ancestor
             node = ancestor
-        raise Exception()
+        return None
+
+    def get_nodes_in_range_inclusive(
+        self, lower: Comparable, upper: Comparable
+    ) -> Generator[Node, None, None]:
+        """
+        Args:
+            lower: the lower bound of the desired range.
+            upper: the upper bound of the desired range.
+
+        Returns:
+            Generates a sequence of nodes in the desired range.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+        >>> t.insert(23)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     ├──13
+        │     └──23
+        └──75
+           └──66
+
+        >>> for node in t.get_nodes_in_range_inclusive(21, 74):
+        ...     print(node.value)
+        22
+        23
+        25
+        50
+        66
+        """
+        node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
+            lower, self.root
+        )
+        while node:
+            if lower <= node.value <= upper:
+                yield node
+            node = self.get_next_node(node)
 
     def _depth(self, node: Node, sofar: int) -> int:
         depth_left = sofar + 1
@@ -610,7 +796,7 @@ class BinarySearchTree(object):
         """
         Returns:
             The max height (depth) of the tree in plies (edge distance
-            from root).
+            from root) in :math:`O(log_2 n)` time.
 
         >>> t = BinarySearchTree()
         >>> t.depth()