X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=src%2Fpyutils%2Fcollectionz%2Fbst.py;h=74c328da05f9729034362a74b91ff35d5f14d68e;hb=HEAD;hp=f492dfdab7aa8cea4c8a08a1592bbd93e05aadce;hpb=77513ea630d72318684cf1d0a9198a22f4b547a7;p=pyutils.git diff --git a/src/pyutils/collectionz/bst.py b/src/pyutils/collectionz/bst.py index f492dfd..74c328d 100644 --- a/src/pyutils/collectionz/bst.py +++ b/src/pyutils/collectionz/bst.py @@ -4,22 +4,25 @@ """A binary search tree implementation.""" -from typing import Any, Generator, List, Optional +from typing import Generator, List, Optional +from pyutils.typez.typing import Comparable -class Node(object): - def __init__(self, value: Any) -> None: - """ - A BST node. Note that value can be anything as long as it - is comparable. Check out :meth:`functools.total_ordering` - (https://docs.python.org/3/library/functools.html#functools.total_ordering) + +class Node: + def __init__(self, value: Comparable) -> None: + """A BST node. Just a left and right reference along with a + value. Note that value can be anything as long as it + is :class:`Comparable` with other instances of itself. Args: - value: a reference to the value of the node. + value: a reference to the value of the node. Must be + :class:`Comparable` to other values. + """ self.left: Optional[Node] = None self.right: Optional[Node] = None - self.value = value + self.value: Comparable = value class BinarySearchTree(object): @@ -40,9 +43,9 @@ class BinarySearchTree(object): """This is called immediately _after_ a new node is inserted.""" pass - def insert(self, value: Any) -> None: + def insert(self, value: Comparable) -> None: """ - Insert something into the tree. + Insert something into the tree in :math:`O(log_2 n)` time. Args: value: the value to be inserted. @@ -65,7 +68,7 @@ class BinarySearchTree(object): else: self._insert(value, self.root) - def _insert(self, value: Any, node: Node): + def _insert(self, value: Comparable, node: Node): """Insertion helper""" if value < node.value: if node.left is not None: @@ -82,10 +85,11 @@ class BinarySearchTree(object): self.count += 1 self._on_insert(node, node.right) - def __getitem__(self, value: Any) -> Optional[Node]: + def __getitem__(self, value: Comparable) -> Optional[Node]: """ - Find an item in the tree and return its Node. Returns - None if the item is not in the tree. + Find an item in the tree and return its Node in + :math:`O(log_2 n)` time. Returns None if the item is not in + the tree. >>> t = BinarySearchTree() >>> t[99] @@ -100,19 +104,144 @@ class BinarySearchTree(object): """ if self.root is not None: - return self._find(value, self.root) + return self._find_exact(value, self.root) return None - def _find(self, value: Any, node: Node) -> Optional[Node]: - """Find helper""" - if value == node.value: + def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]: + """Recursively traverse the tree looking for a node with the + target value. Return that node if it exists, otherwise return + None.""" + + if target == node.value: return node - elif value < node.value and node.left is not None: - return self._find(value, node.left) - elif value > node.value and node.right is not None: - return self._find(value, node.right) + elif target < node.value and node.left is not None: + return self._find_exact(target, node.left) + elif target > node.value and node.right is not None: + return self._find_exact(target, node.right) return None + def _find_lowest_node_less_than_or_equal_to( + self, target: Comparable, node: Optional[Node] + ) -> Optional[Node]: + """Find helper that returns the lowest node that is less + than or equal to the target value. Returns None if target is + lower than the lowest node in the tree. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(85) + >>> t + 50 + ├──25 + │ └──22 + │ └──13 + └──75 + ├──66 + └──85 + + >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value + 25 + >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value + 50 + >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value + 85 + >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value + 22 + >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value + 13 + >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value + 66 + >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value + 75 + >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None + True + + """ + + if not node: + return None + + if target == node.value: + return node + + elif target > node.value: + if below := self._find_lowest_node_less_than_or_equal_to( + target, node.right + ): + return below + else: + return node + + else: + return self._find_lowest_node_less_than_or_equal_to(target, node.left) + + def _find_lowest_node_greater_than_or_equal_to( + self, target: Comparable, node: Optional[Node] + ) -> Optional[Node]: + """Find helper that returns the lowest node that is greater + than or equal to the target value. Returns None if target is + higher than the greatest node in the tree. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(85) + >>> t + 50 + ├──25 + │ └──22 + │ └──13 + └──75 + ├──66 + └──85 + + >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value + 50 + >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value + 66 + >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value + 13 + >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value + 25 + >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value + 22 + >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value + 75 + >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value + 85 + >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None + True + + """ + + if not node: + return None + + if target == node.value: + return node + + elif target > node.value: + return self._find_lowest_node_greater_than_or_equal_to(target, node.right) + + # If target < this node's value, either this node is the + # answer or the answer is in this node's left subtree. + else: + if below := self._find_lowest_node_greater_than_or_equal_to( + target, node.left + ): + return below + else: + return node + def _parent_path( self, current: Optional[Node], target: Node ) -> List[Optional[Node]]: @@ -131,14 +260,14 @@ class BinarySearchTree(object): return ret def parent_path(self, node: Node) -> List[Optional[Node]]: - """Get a node's parent path. + """Get a node's parent path in :math:`O(log_2 n)` time. Args: - node: the node to check + node: the node whose parent path should be returned. Returns: a list of nodes representing the path from - the tree's root to the node. + the tree's root to the given node. .. note:: @@ -185,9 +314,10 @@ class BinarySearchTree(object): """ return self._parent_path(self.root, node) - def __delitem__(self, value: Any) -> bool: + def __delitem__(self, value: Comparable) -> bool: """ - Delete an item from the tree and preserve the BST property. + Delete an item from the tree and preserve the BST property in + :math:`O(log_2 n) time`. Args: value: the value of the node to be deleted. @@ -258,6 +388,9 @@ class BinarySearchTree(object): └──85 └──66 + >>> t.__delitem__(85) + True + >>> t.__delitem__(99) False @@ -275,7 +408,7 @@ class BinarySearchTree(object): """This is called just after deleted was deleted from the tree""" pass - def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool: + def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool: """Delete helper""" if node.value == value: @@ -314,14 +447,18 @@ class BinarySearchTree(object): self._on_delete(parent, node) return True - # Node has both a left and right. + # Node has both a left and right; get the successor node + # to this one and put it here (deleting the successor's + # old node). Because these operations are happening only + # in the subtree underneath of node, I'm still calling + # this delete an O(log_2 n) operation in the docs. else: assert node.left is not None and node.right is not None - descendent = node.right - while descendent.left is not None: - descendent = descendent.left - node.value = descendent.value + successor = self.get_next_node(node) + assert successor is not None + node.value = successor.value return self._delete(node.value, node, node.right) + elif value < node.value and node.left is not None: return self._delete(value, node, node.left) elif value > node.value and node.right is not None: @@ -331,7 +468,7 @@ class BinarySearchTree(object): def __len__(self): """ Returns: - The count of items in the tree. + The count of items in the tree in :math:`O(1)` time. >>> t = BinarySearchTree() >>> len(t) @@ -355,7 +492,7 @@ class BinarySearchTree(object): """ return self.count - def __contains__(self, value: Any) -> bool: + def __contains__(self, value: Comparable) -> bool: """ Returns: True if the item is in the tree; False otherwise. @@ -485,7 +622,7 @@ class BinarySearchTree(object): def iterate_leaves(self): """ Returns: - A Gemerator that yielde only the leaf nodes in the + A Generator that yields only the leaf nodes in the tree. >>> t = BinarySearchTree() @@ -545,13 +682,14 @@ class BinarySearchTree(object): if self.root is not None: yield from self._iterate_by_depth(self.root, depth) - def get_next_node(self, node: Node) -> Node: + def get_next_node(self, node: Node) -> Optional[Node]: """ Args: node: the node whose next greater successor is desired Returns: Given a tree node, returns the next greater node in the tree. + If the given node is the greatest node in the tree, returns None. >>> t = BinarySearchTree() >>> t.insert(50) @@ -578,6 +716,10 @@ class BinarySearchTree(object): >>> t.get_next_node(n).value 66 + >>> n = t[75] + >>> t.get_next_node(n) is None + True + """ if node.right is not None: x = node.right @@ -595,7 +737,51 @@ class BinarySearchTree(object): if node != ancestor.right: return ancestor node = ancestor - raise Exception() + return None + + def get_nodes_in_range_inclusive( + self, lower: Comparable, upper: Comparable + ) -> Generator[Node, None, None]: + """ + Args: + lower: the lower bound of the desired range. + upper: the upper bound of the desired range. + + Returns: + Generates a sequence of nodes in the desired range. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(23) + >>> t + 50 + ├──25 + │ └──22 + │ ├──13 + │ └──23 + └──75 + └──66 + + >>> for node in t.get_nodes_in_range_inclusive(21, 74): + ... print(node.value) + 22 + 23 + 25 + 50 + 66 + """ + node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to( + lower, self.root + ) + while node: + if lower <= node.value <= upper: + yield node + node = self.get_next_node(node) def _depth(self, node: Node, sofar: int) -> int: depth_left = sofar + 1 @@ -610,7 +796,7 @@ class BinarySearchTree(object): """ Returns: The max height (depth) of the tree in plies (edge distance - from root). + from root) in :math:`O(log_2 n)` time. >>> t = BinarySearchTree() >>> t.depth()