X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=collect%2Fbst.py;h=d39419494d3f482712f17e13a5ff6ce1e7c2ebcf;hb=532df2c5b57c7517dfb3dddd8c1358fbadf8baf3;hp=d3231eecaa22060d0fc47aca073e9e4f98a09310;hpb=791755ededfc0de84a81b4c7313830efbd27bf8c;p=python_utils.git diff --git a/collect/bst.py b/collect/bst.py index d3231ee..d394194 100644 --- a/collect/bst.py +++ b/collect/bst.py @@ -1,16 +1,24 @@ #!/usr/bin/env python3 -from typing import Any, List, Optional +# © Copyright 2021-2022, Scott Gasch + +"""Binary search tree.""" + +from typing import Any, Generator, List, Optional class Node(object): def __init__(self, value: Any) -> None: - self.left = None - self.right = None + """ + Note: value can be anything as long as it is comparable. + Check out @functools.total_ordering. + """ + self.left: Optional[Node] = None + self.right: Optional[Node] = None self.value = value -class BinaryTree(object): +class BinarySearchTree(object): def __init__(self): self.root = None self.count = 0 @@ -23,7 +31,7 @@ class BinaryTree(object): """ Insert something into the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(10) >>> t.insert(20) >>> t.insert(5) @@ -60,7 +68,7 @@ class BinaryTree(object): Find an item in the tree and return its Node. Returns None if the item is not in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t[99] >>> t.insert(10) @@ -80,28 +88,77 @@ class BinaryTree(object): """Find helper""" if value == node.value: return node - elif (value < node.value and node.left is not None): + elif value < node.value and node.left is not None: return self._find(value, node.left) - else: - assert value > node.value - if node.right is not None: - return self._find(value, node.right) + elif value > node.value and node.right is not None: + return self._find(value, node.right) return None + def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]: + if current is None: + return [None] + ret: List[Optional[Node]] = [current] + if target.value == current.value: + return ret + elif target.value < current.value: + ret.extend(self._parent_path(current.left, target)) + return ret + else: + assert target.value > current.value + ret.extend(self._parent_path(current.right, target)) + return ret + + def parent_path(self, node: Node) -> List[Optional[Node]]: + """Return a list of nodes representing the path from + the tree's root to the node argument. If the node does + not exist in the tree for some reason, the last element + on the path will be None but the path will indicate the + ancestor path of that node were it inserted. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(12) + >>> t.insert(33) + >>> t.insert(4) + >>> t.insert(88) + >>> t + 50 + ├──25 + │ ├──12 + │ │ └──4 + │ └──33 + └──75 + └──88 + + >>> n = t[4] + >>> for x in t.parent_path(n): + ... print(x.value) + 50 + 25 + 12 + 4 + + >>> del t[4] + >>> for x in t.parent_path(n): + ... if x is not None: + ... print(x.value) + ... else: + ... print(x) + 50 + 25 + 12 + None + + """ + return self._parent_path(self.root, node) + def __delitem__(self, value: Any) -> bool: """ Delete an item from the tree and preserve the BST property. - 50 - / \ - 25 75 - / / \ - 22 66 85 - / - 13 - - - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -109,6 +166,14 @@ class BinaryTree(object): >>> t.insert(22) >>> t.insert(13) >>> t.insert(85) + >>> t + 50 + ├──25 + │ └──22 + │ └──13 + └──75 + ├──66 + └──85 >>> for value in t.iterate_inorder(): ... print(value) @@ -149,6 +214,11 @@ class BinaryTree(object): 50 66 85 + >>> t + 50 + ├──25 + └──85 + └──66 >>> t.__delitem__(99) False @@ -216,7 +286,7 @@ class BinaryTree(object): """ Returns the count of items in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> len(t) 0 >>> t.insert(50) @@ -270,7 +340,7 @@ class BinaryTree(object): """ Yield the tree's items in a preorder traversal sequence. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -295,18 +365,28 @@ class BinaryTree(object): """ Yield the tree's items in a preorder traversal sequence. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) >>> t.insert(66) >>> t.insert(22) >>> t.insert(13) + >>> t.insert(24) + >>> t + 50 + ├──25 + │ └──22 + │ ├──13 + │ └──24 + └──75 + └──66 >>> for value in t.iterate_inorder(): ... print(value) 13 22 + 24 25 50 66 @@ -320,7 +400,7 @@ class BinaryTree(object): """ Yield the tree's items in a preorder traversal sequence. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -353,7 +433,7 @@ class BinaryTree(object): """ Iterate only the leaf nodes in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -380,11 +460,11 @@ class BinaryTree(object): if node.right is not None: yield from self._iterate_by_depth(node.right, depth - 1) - def iterate_nodes_by_depth(self, depth: int): + def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]: """ Iterate only the leaf nodes in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -405,6 +485,54 @@ class BinaryTree(object): if self.root is not None: yield from self._iterate_by_depth(self.root, depth) + def get_next_node(self, node: Node) -> Node: + """ + Given a tree node, get the next greater node in the tree. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(23) + >>> t + 50 + ├──25 + │ └──22 + │ ├──13 + │ └──23 + └──75 + └──66 + + >>> n = t[23] + >>> t.get_next_node(n).value + 25 + + >>> n = t[50] + >>> t.get_next_node(n).value + 66 + + """ + if node.right is not None: + x = node.right + while x.left is not None: + x = x.left + return x + + path = self.parent_path(node) + assert path[-1] is not None + assert path[-1] == node + path = path[:-1] + path.reverse() + for ancestor in path: + assert ancestor is not None + if node != ancestor.right: + return ancestor + node = ancestor + raise Exception() + def _depth(self, node: Node, sofar: int) -> int: depth_left = sofar + 1 depth_right = sofar + 1 @@ -419,7 +547,7 @@ class BinaryTree(object): Returns the max height (depth) of the tree in plies (edge distance from root). - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.depth() 0 @@ -448,9 +576,15 @@ class BinaryTree(object): def height(self): return self.depth() - def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str: + def repr_traverse( + self, + padding: str, + pointer: str, + node: Optional[Node], + has_right_sibling: bool, + ) -> str: if node is not None: - self.viz += f'\n{padding}{pointer}{node.value}' + viz = f'\n{padding}{pointer}{node.value}' if has_right_sibling: padding += "│ " else: @@ -462,14 +596,16 @@ class BinaryTree(object): else: pointer_left = "└──" - self.repr_traverse(padding, pointer_left, node.left, node.right is not None) - self.repr_traverse(padding, pointer_right, node.right, False) + viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None) + viz += self.repr_traverse(padding, pointer_right, node.right, False) + return viz + return "" def __repr__(self): """ Draw the tree in ASCII. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(25) >>> t.insert(75) @@ -489,18 +625,19 @@ class BinaryTree(object): if self.root is None: return "" - self.viz = f'{self.root.value}' + ret = f'{self.root.value}' pointer_right = "└──" if self.root.right is None: pointer_left = "└──" else: pointer_left = "├──" - self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None) - self.repr_traverse('', pointer_right, self.root.right, False) - return self.viz + ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None) + ret += self.repr_traverse('', pointer_right, self.root.right, False) + return ret if __name__ == '__main__': import doctest + doctest.testmod()