X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=collect%2Fbst.py;h=d39419494d3f482712f17e13a5ff6ce1e7c2ebcf;hb=532df2c5b57c7517dfb3dddd8c1358fbadf8baf3;hp=b4d25b34a627797660362a16d6430fdcd6d7eceb;hpb=b0bde5bef4a19382136112196b238088641738d5;p=python_utils.git diff --git a/collect/bst.py b/collect/bst.py index b4d25b3..d394194 100644 --- a/collect/bst.py +++ b/collect/bst.py @@ -1,12 +1,20 @@ #!/usr/bin/env python3 -from typing import Any, Optional, List +# © Copyright 2021-2022, Scott Gasch + +"""Binary search tree.""" + +from typing import Any, Generator, List, Optional class Node(object): def __init__(self, value: Any) -> None: - self.left = None - self.right = None + """ + Note: value can be anything as long as it is comparable. + Check out @functools.total_ordering. + """ + self.left: Optional[Node] = None + self.right: Optional[Node] = None self.value = value @@ -80,18 +88,16 @@ class BinarySearchTree(object): """Find helper""" if value == node.value: return node - elif (value < node.value and node.left is not None): + elif value < node.value and node.left is not None: return self._find(value, node.left) - else: - assert value > node.value - if node.right is not None: - return self._find(value, node.right) + elif value > node.value and node.right is not None: + return self._find(value, node.right) return None - def _parent_path(self, current: Node, target: Node): + def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]: if current is None: return [None] - ret = [current] + ret: List[Optional[Node]] = [current] if target.value == current.value: return ret elif target.value < current.value: @@ -102,9 +108,12 @@ class BinarySearchTree(object): ret.extend(self._parent_path(current.right, target)) return ret - def parent_path(self, node: Node) -> Optional[List[Node]]: + def parent_path(self, node: Node) -> List[Optional[Node]]: """Return a list of nodes representing the path from - the tree's root to the node argument. + the tree's root to the node argument. If the node does + not exist in the tree for some reason, the last element + on the path will be None but the path will indicate the + ancestor path of that node were it inserted. >>> t = BinarySearchTree() >>> t.insert(50) @@ -131,6 +140,17 @@ class BinarySearchTree(object): 12 4 + >>> del t[4] + >>> for x in t.parent_path(n): + ... if x is not None: + ... print(x.value) + ... else: + ... print(x) + 50 + 25 + 12 + None + """ return self._parent_path(self.root, node) @@ -138,15 +158,6 @@ class BinarySearchTree(object): """ Delete an item from the tree and preserve the BST property. - 50 - / \ - 25 75 - / / \ - 22 66 85 - / - 13 - - >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) @@ -155,6 +166,14 @@ class BinarySearchTree(object): >>> t.insert(22) >>> t.insert(13) >>> t.insert(85) + >>> t + 50 + ├──25 + │ └──22 + │ └──13 + └──75 + ├──66 + └──85 >>> for value in t.iterate_inorder(): ... print(value) @@ -195,6 +214,11 @@ class BinarySearchTree(object): 50 66 85 + >>> t + 50 + ├──25 + └──85 + └──66 >>> t.__delitem__(99) False @@ -436,7 +460,7 @@ class BinarySearchTree(object): if node.right is not None: yield from self._iterate_by_depth(node.right, depth - 1) - def iterate_nodes_by_depth(self, depth: int): + def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]: """ Iterate only the leaf nodes in the tree. @@ -498,13 +522,16 @@ class BinarySearchTree(object): return x path = self.parent_path(node) + assert path[-1] is not None assert path[-1] == node path = path[:-1] path.reverse() for ancestor in path: + assert ancestor is not None if node != ancestor.right: return ancestor node = ancestor + raise Exception() def _depth(self, node: Node, sofar: int) -> int: depth_left = sofar + 1 @@ -549,9 +576,15 @@ class BinarySearchTree(object): def height(self): return self.depth() - def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str: + def repr_traverse( + self, + padding: str, + pointer: str, + node: Optional[Node], + has_right_sibling: bool, + ) -> str: if node is not None: - self.viz += f'\n{padding}{pointer}{node.value}' + viz = f'\n{padding}{pointer}{node.value}' if has_right_sibling: padding += "│ " else: @@ -563,8 +596,10 @@ class BinarySearchTree(object): else: pointer_left = "└──" - self.repr_traverse(padding, pointer_left, node.left, node.right is not None) - self.repr_traverse(padding, pointer_right, node.right, False) + viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None) + viz += self.repr_traverse(padding, pointer_right, node.right, False) + return viz + return "" def __repr__(self): """ @@ -590,18 +625,19 @@ class BinarySearchTree(object): if self.root is None: return "" - self.viz = f'{self.root.value}' + ret = f'{self.root.value}' pointer_right = "└──" if self.root.right is None: pointer_left = "└──" else: pointer_left = "├──" - self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None) - self.repr_traverse('', pointer_right, self.root.right, False) - return self.viz + ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None) + ret += self.repr_traverse('', pointer_right, self.root.right, False) + return ret if __name__ == '__main__': import doctest + doctest.testmod()