X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=collect%2Fbst.py;h=d39419494d3f482712f17e13a5ff6ce1e7c2ebcf;hb=532df2c5b57c7517dfb3dddd8c1358fbadf8baf3;hp=94570f49be8490b4656d2b4ea12185a44c636212;hpb=986d5f7ada15e56019518db43d07b76f94468e1a;p=python_utils.git diff --git a/collect/bst.py b/collect/bst.py index 94570f4..d394194 100644 --- a/collect/bst.py +++ b/collect/bst.py @@ -1,16 +1,24 @@ #!/usr/bin/env python3 -from typing import Any, Optional +# © Copyright 2021-2022, Scott Gasch + +"""Binary search tree.""" + +from typing import Any, Generator, List, Optional class Node(object): def __init__(self, value: Any) -> None: - self.left = None - self.right = None + """ + Note: value can be anything as long as it is comparable. + Check out @functools.total_ordering. + """ + self.left: Optional[Node] = None + self.right: Optional[Node] = None self.value = value -class BinaryTree(object): +class BinarySearchTree(object): def __init__(self): self.root = None self.count = 0 @@ -23,7 +31,7 @@ class BinaryTree(object): """ Insert something into the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(10) >>> t.insert(20) >>> t.insert(5) @@ -60,7 +68,7 @@ class BinaryTree(object): Find an item in the tree and return its Node. Returns None if the item is not in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t[99] >>> t.insert(10) @@ -80,28 +88,77 @@ class BinaryTree(object): """Find helper""" if value == node.value: return node - elif (value < node.value and node.left is not None): + elif value < node.value and node.left is not None: return self._find(value, node.left) - else: - assert value > node.value - if node.right is not None: - return self._find(value, node.right) + elif value > node.value and node.right is not None: + return self._find(value, node.right) return None + def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]: + if current is None: + return [None] + ret: List[Optional[Node]] = [current] + if target.value == current.value: + return ret + elif target.value < current.value: + ret.extend(self._parent_path(current.left, target)) + return ret + else: + assert target.value > current.value + ret.extend(self._parent_path(current.right, target)) + return ret + + def parent_path(self, node: Node) -> List[Optional[Node]]: + """Return a list of nodes representing the path from + the tree's root to the node argument. If the node does + not exist in the tree for some reason, the last element + on the path will be None but the path will indicate the + ancestor path of that node were it inserted. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(12) + >>> t.insert(33) + >>> t.insert(4) + >>> t.insert(88) + >>> t + 50 + ├──25 + │ ├──12 + │ │ └──4 + │ └──33 + └──75 + └──88 + + >>> n = t[4] + >>> for x in t.parent_path(n): + ... print(x.value) + 50 + 25 + 12 + 4 + + >>> del t[4] + >>> for x in t.parent_path(n): + ... if x is not None: + ... print(x.value) + ... else: + ... print(x) + 50 + 25 + 12 + None + + """ + return self._parent_path(self.root, node) + def __delitem__(self, value: Any) -> bool: """ Delete an item from the tree and preserve the BST property. - 50 - / \ - 25 75 - / / \ - 22 66 85 - / - 13 - - - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -109,6 +166,14 @@ class BinaryTree(object): >>> t.insert(22) >>> t.insert(13) >>> t.insert(85) + >>> t + 50 + ├──25 + │ └──22 + │ └──13 + └──75 + ├──66 + └──85 >>> for value in t.iterate_inorder(): ... print(value) @@ -120,8 +185,8 @@ class BinaryTree(object): 75 85 - >>> t.__delitem__(22) - True + >>> del t[22] # Note: bool result is discarded + >>> for value in t.iterate_inorder(): ... print(value) 13 @@ -149,6 +214,11 @@ class BinaryTree(object): 50 66 85 + >>> t + 50 + ├──25 + └──85 + └──66 >>> t.__delitem__(99) False @@ -216,7 +286,7 @@ class BinaryTree(object): """ Returns the count of items in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> len(t) 0 >>> t.insert(50) @@ -270,7 +340,7 @@ class BinaryTree(object): """ Yield the tree's items in a preorder traversal sequence. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -295,18 +365,28 @@ class BinaryTree(object): """ Yield the tree's items in a preorder traversal sequence. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) >>> t.insert(66) >>> t.insert(22) >>> t.insert(13) + >>> t.insert(24) + >>> t + 50 + ├──25 + │ └──22 + │ ├──13 + │ └──24 + └──75 + └──66 >>> for value in t.iterate_inorder(): ... print(value) 13 22 + 24 25 50 66 @@ -320,7 +400,7 @@ class BinaryTree(object): """ Yield the tree's items in a preorder traversal sequence. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -353,7 +433,7 @@ class BinaryTree(object): """ Iterate only the leaf nodes in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -380,11 +460,11 @@ class BinaryTree(object): if node.right is not None: yield from self._iterate_by_depth(node.right, depth - 1) - def iterate_nodes_by_depth(self, depth: int): + def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]: """ Iterate only the leaf nodes in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -405,6 +485,54 @@ class BinaryTree(object): if self.root is not None: yield from self._iterate_by_depth(self.root, depth) + def get_next_node(self, node: Node) -> Node: + """ + Given a tree node, get the next greater node in the tree. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(23) + >>> t + 50 + ├──25 + │ └──22 + │ ├──13 + │ └──23 + └──75 + └──66 + + >>> n = t[23] + >>> t.get_next_node(n).value + 25 + + >>> n = t[50] + >>> t.get_next_node(n).value + 66 + + """ + if node.right is not None: + x = node.right + while x.left is not None: + x = x.left + return x + + path = self.parent_path(node) + assert path[-1] is not None + assert path[-1] == node + path = path[:-1] + path.reverse() + for ancestor in path: + assert ancestor is not None + if node != ancestor.right: + return ancestor + node = ancestor + raise Exception() + def _depth(self, node: Node, sofar: int) -> int: depth_left = sofar + 1 depth_right = sofar + 1 @@ -419,7 +547,7 @@ class BinaryTree(object): Returns the max height (depth) of the tree in plies (edge distance from root). - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.depth() 0 @@ -448,7 +576,68 @@ class BinaryTree(object): def height(self): return self.depth() + def repr_traverse( + self, + padding: str, + pointer: str, + node: Optional[Node], + has_right_sibling: bool, + ) -> str: + if node is not None: + viz = f'\n{padding}{pointer}{node.value}' + if has_right_sibling: + padding += "│ " + else: + padding += ' ' + + pointer_right = "└──" + if node.right is not None: + pointer_left = "├──" + else: + pointer_left = "└──" + + viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None) + viz += self.repr_traverse(padding, pointer_right, node.right, False) + return viz + return "" + + def __repr__(self): + """ + Draw the tree in ASCII. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(25) + >>> t.insert(75) + >>> t.insert(12) + >>> t.insert(33) + >>> t.insert(88) + >>> t.insert(55) + >>> t + 50 + ├──25 + │ ├──12 + │ └──33 + └──75 + ├──55 + └──88 + """ + if self.root is None: + return "" + + ret = f'{self.root.value}' + pointer_right = "└──" + if self.root.right is None: + pointer_left = "└──" + else: + pointer_left = "├──" + + ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None) + ret += self.repr_traverse('', pointer_right, self.root.right, False) + return ret + if __name__ == '__main__': import doctest + doctest.testmod()