X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=src%2Fpyutils%2Fcollectionz%2Fbst.py;h=74c328da05f9729034362a74b91ff35d5f14d68e;hb=HEAD;hp=4c0bacdd051374a3f700ceba33b4beaad143b956;hpb=993b0992473c12294ed659e52b532e1c8cf9cd1e;p=pyutils.git diff --git a/src/pyutils/collectionz/bst.py b/src/pyutils/collectionz/bst.py index 4c0bacd..74c328d 100644 --- a/src/pyutils/collectionz/bst.py +++ b/src/pyutils/collectionz/bst.py @@ -1,25 +1,28 @@ #!/usr/bin/env python3 -# © Copyright 2021-2022, Scott Gasch +# © Copyright 2021-2023, Scott Gasch """A binary search tree implementation.""" -from typing import Any, Generator, List, Optional +from typing import Generator, List, Optional +from pyutils.typez.typing import Comparable -class Node(object): - def __init__(self, value: Any) -> None: - """ - A BST node. Note that value can be anything as long as it - is comparable. Check out :meth:`functools.total_ordering` - (https://docs.python.org/3/library/functools.html#functools.total_ordering) + +class Node: + def __init__(self, value: Comparable) -> None: + """A BST node. Just a left and right reference along with a + value. Note that value can be anything as long as it + is :class:`Comparable` with other instances of itself. Args: - value: a reference to the value of the node. + value: a reference to the value of the node. Must be + :class:`Comparable` to other values. + """ self.left: Optional[Node] = None self.right: Optional[Node] = None - self.value = value + self.value: Comparable = value class BinarySearchTree(object): @@ -36,9 +39,13 @@ class BinarySearchTree(object): return self.root - def insert(self, value: Any) -> None: + def _on_insert(self, parent: Optional[Node], new: Node) -> None: + """This is called immediately _after_ a new node is inserted.""" + pass + + def insert(self, value: Comparable) -> None: """ - Insert something into the tree. + Insert something into the tree in :math:`O(log_2 n)` time. Args: value: the value to be inserted. @@ -57,10 +64,11 @@ class BinarySearchTree(object): if self.root is None: self.root = Node(value) self.count = 1 + self._on_insert(None, self.root) else: self._insert(value, self.root) - def _insert(self, value: Any, node: Node): + def _insert(self, value: Comparable, node: Node): """Insertion helper""" if value < node.value: if node.left is not None: @@ -68,17 +76,20 @@ class BinarySearchTree(object): else: node.left = Node(value) self.count += 1 + self._on_insert(node, node.left) else: if node.right is not None: self._insert(value, node.right) else: node.right = Node(value) self.count += 1 + self._on_insert(node, node.right) - def __getitem__(self, value: Any) -> Optional[Node]: + def __getitem__(self, value: Comparable) -> Optional[Node]: """ - Find an item in the tree and return its Node. Returns - None if the item is not in the tree. + Find an item in the tree and return its Node in + :math:`O(log_2 n)` time. Returns None if the item is not in + the tree. >>> t = BinarySearchTree() >>> t[99] @@ -93,19 +104,144 @@ class BinarySearchTree(object): """ if self.root is not None: - return self._find(value, self.root) + return self._find_exact(value, self.root) return None - def _find(self, value: Any, node: Node) -> Optional[Node]: - """Find helper""" - if value == node.value: + def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]: + """Recursively traverse the tree looking for a node with the + target value. Return that node if it exists, otherwise return + None.""" + + if target == node.value: return node - elif value < node.value and node.left is not None: - return self._find(value, node.left) - elif value > node.value and node.right is not None: - return self._find(value, node.right) + elif target < node.value and node.left is not None: + return self._find_exact(target, node.left) + elif target > node.value and node.right is not None: + return self._find_exact(target, node.right) return None + def _find_lowest_node_less_than_or_equal_to( + self, target: Comparable, node: Optional[Node] + ) -> Optional[Node]: + """Find helper that returns the lowest node that is less + than or equal to the target value. Returns None if target is + lower than the lowest node in the tree. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(85) + >>> t + 50 + ├──25 + │ └──22 + │ └──13 + └──75 + ├──66 + └──85 + + >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value + 25 + >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value + 50 + >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value + 85 + >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value + 22 + >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value + 13 + >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value + 66 + >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value + 75 + >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None + True + + """ + + if not node: + return None + + if target == node.value: + return node + + elif target > node.value: + if below := self._find_lowest_node_less_than_or_equal_to( + target, node.right + ): + return below + else: + return node + + else: + return self._find_lowest_node_less_than_or_equal_to(target, node.left) + + def _find_lowest_node_greater_than_or_equal_to( + self, target: Comparable, node: Optional[Node] + ) -> Optional[Node]: + """Find helper that returns the lowest node that is greater + than or equal to the target value. Returns None if target is + higher than the greatest node in the tree. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(85) + >>> t + 50 + ├──25 + │ └──22 + │ └──13 + └──75 + ├──66 + └──85 + + >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value + 50 + >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value + 66 + >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value + 13 + >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value + 25 + >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value + 22 + >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value + 75 + >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value + 85 + >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None + True + + """ + + if not node: + return None + + if target == node.value: + return node + + elif target > node.value: + return self._find_lowest_node_greater_than_or_equal_to(target, node.right) + + # If target < this node's value, either this node is the + # answer or the answer is in this node's left subtree. + else: + if below := self._find_lowest_node_greater_than_or_equal_to( + target, node.left + ): + return below + else: + return node + def _parent_path( self, current: Optional[Node], target: Node ) -> List[Optional[Node]]: @@ -124,14 +260,14 @@ class BinarySearchTree(object): return ret def parent_path(self, node: Node) -> List[Optional[Node]]: - """Get a node's parent path. + """Get a node's parent path in :math:`O(log_2 n)` time. Args: - node: the node to check + node: the node whose parent path should be returned. Returns: a list of nodes representing the path from - the tree's root to the node. + the tree's root to the given node. .. note:: @@ -178,9 +314,10 @@ class BinarySearchTree(object): """ return self._parent_path(self.root, node) - def __delitem__(self, value: Any) -> bool: + def __delitem__(self, value: Comparable) -> bool: """ - Delete an item from the tree and preserve the BST property. + Delete an item from the tree and preserve the BST property in + :math:`O(log_2 n) time`. Args: value: the value of the node to be deleted. @@ -251,6 +388,9 @@ class BinarySearchTree(object): └──85 └──66 + >>> t.__delitem__(85) + True + >>> t.__delitem__(99) False @@ -264,9 +404,14 @@ class BinarySearchTree(object): return ret return False - def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool: + def _on_delete(self, parent: Optional[Node], deleted: Node) -> None: + """This is called just after deleted was deleted from the tree""" + pass + + def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool: """Delete helper""" if node.value == value: + # Deleting a leaf node if node.left is None and node.right is None: if parent is not None: @@ -275,6 +420,7 @@ class BinarySearchTree(object): else: assert parent.right == node parent.right = None + self._on_delete(parent, node) return True # Node only has a right. @@ -286,6 +432,7 @@ class BinarySearchTree(object): else: assert parent.right == node parent.right = node.right + self._on_delete(parent, node) return True # Node only has a left. @@ -297,16 +444,21 @@ class BinarySearchTree(object): else: assert parent.right == node parent.right = node.left + self._on_delete(parent, node) return True - # Node has both a left and right. + # Node has both a left and right; get the successor node + # to this one and put it here (deleting the successor's + # old node). Because these operations are happening only + # in the subtree underneath of node, I'm still calling + # this delete an O(log_2 n) operation in the docs. else: assert node.left is not None and node.right is not None - descendent = node.right - while descendent.left is not None: - descendent = descendent.left - node.value = descendent.value + successor = self.get_next_node(node) + assert successor is not None + node.value = successor.value return self._delete(node.value, node, node.right) + elif value < node.value and node.left is not None: return self._delete(value, node, node.left) elif value > node.value and node.right is not None: @@ -316,7 +468,7 @@ class BinarySearchTree(object): def __len__(self): """ Returns: - The count of items in the tree. + The count of items in the tree in :math:`O(1)` time. >>> t = BinarySearchTree() >>> len(t) @@ -340,7 +492,7 @@ class BinarySearchTree(object): """ return self.count - def __contains__(self, value: Any) -> bool: + def __contains__(self, value: Comparable) -> bool: """ Returns: True if the item is in the tree; False otherwise. @@ -470,7 +622,7 @@ class BinarySearchTree(object): def iterate_leaves(self): """ Returns: - A Gemerator that yielde only the leaf nodes in the + A Generator that yields only the leaf nodes in the tree. >>> t = BinarySearchTree() @@ -530,13 +682,14 @@ class BinarySearchTree(object): if self.root is not None: yield from self._iterate_by_depth(self.root, depth) - def get_next_node(self, node: Node) -> Node: + def get_next_node(self, node: Node) -> Optional[Node]: """ Args: node: the node whose next greater successor is desired Returns: Given a tree node, returns the next greater node in the tree. + If the given node is the greatest node in the tree, returns None. >>> t = BinarySearchTree() >>> t.insert(50) @@ -563,6 +716,10 @@ class BinarySearchTree(object): >>> t.get_next_node(n).value 66 + >>> n = t[75] + >>> t.get_next_node(n) is None + True + """ if node.right is not None: x = node.right @@ -580,7 +737,51 @@ class BinarySearchTree(object): if node != ancestor.right: return ancestor node = ancestor - raise Exception() + return None + + def get_nodes_in_range_inclusive( + self, lower: Comparable, upper: Comparable + ) -> Generator[Node, None, None]: + """ + Args: + lower: the lower bound of the desired range. + upper: the upper bound of the desired range. + + Returns: + Generates a sequence of nodes in the desired range. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(23) + >>> t + 50 + ├──25 + │ └──22 + │ ├──13 + │ └──23 + └──75 + └──66 + + >>> for node in t.get_nodes_in_range_inclusive(21, 74): + ... print(node.value) + 22 + 23 + 25 + 50 + 66 + """ + node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to( + lower, self.root + ) + while node: + if lower <= node.value <= upper: + yield node + node = self.get_next_node(node) def _depth(self, node: Node, sofar: int) -> int: depth_left = sofar + 1 @@ -595,7 +796,7 @@ class BinarySearchTree(object): """ Returns: The max height (depth) of the tree in plies (edge distance - from root). + from root) in :math:`O(log_2 n)` time. >>> t = BinarySearchTree() >>> t.depth() @@ -635,11 +836,11 @@ class BinarySearchTree(object): has_right_sibling: bool, ) -> str: if node is not None: - viz = f'\n{padding}{pointer}{node.value}' + viz = f"\n{padding}{pointer}{node.value}" if has_right_sibling: padding += "│ " else: - padding += ' ' + padding += " " pointer_right = "└──" if node.right is not None: @@ -679,7 +880,7 @@ class BinarySearchTree(object): if self.root is None: return "" - ret = f'{self.root.value}' + ret = f"{self.root.value}" pointer_right = "└──" if self.root.right is None: pointer_left = "└──" @@ -687,13 +888,13 @@ class BinarySearchTree(object): pointer_left = "├──" ret += self.repr_traverse( - '', pointer_left, self.root.left, self.root.left is not None + "", pointer_left, self.root.left, self.root.left is not None ) - ret += self.repr_traverse('', pointer_right, self.root.right, False) + ret += self.repr_traverse("", pointer_right, self.root.right, False) return ret -if __name__ == '__main__': +if __name__ == "__main__": import doctest doctest.testmod()