X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=collect%2Fbst.py;h=72a3b7738b981878b9b08eaba67bca2b33314f4f;hb=822454f580c1ff9eb207b8da46cdfae24e30cde1;hp=d3231eecaa22060d0fc47aca073e9e4f98a09310;hpb=791755ededfc0de84a81b4c7313830efbd27bf8c;p=python_utils.git diff --git a/collect/bst.py b/collect/bst.py index d3231ee..72a3b77 100644 --- a/collect/bst.py +++ b/collect/bst.py @@ -1,16 +1,20 @@ #!/usr/bin/env python3 -from typing import Any, List, Optional +from typing import Any, Optional, List class Node(object): def __init__(self, value: Any) -> None: + """ + Note: value can be anything as long as it is comparable. + Check out @functools.total_ordering. + """ self.left = None self.right = None self.value = value -class BinaryTree(object): +class BinarySearchTree(object): def __init__(self): self.root = None self.count = 0 @@ -23,7 +27,7 @@ class BinaryTree(object): """ Insert something into the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(10) >>> t.insert(20) >>> t.insert(5) @@ -60,7 +64,7 @@ class BinaryTree(object): Find an item in the tree and return its Node. Returns None if the item is not in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t[99] >>> t.insert(10) @@ -82,26 +86,75 @@ class BinaryTree(object): return node elif (value < node.value and node.left is not None): return self._find(value, node.left) - else: - assert value > node.value - if node.right is not None: - return self._find(value, node.right) + elif (value > node.value and node.right is not None): + return self._find(value, node.right) return None + def _parent_path(self, current: Node, target: Node): + if current is None: + return [None] + ret = [current] + if target.value == current.value: + return ret + elif target.value < current.value: + ret.extend(self._parent_path(current.left, target)) + return ret + else: + assert target.value > current.value + ret.extend(self._parent_path(current.right, target)) + return ret + + def parent_path(self, node: Node) -> Optional[List[Node]]: + """Return a list of nodes representing the path from + the tree's root to the node argument. If the node does + not exist in the tree for some reason, the last element + on the path will be None but the path will indicate the + ancestor path of that node were it inserted. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(12) + >>> t.insert(33) + >>> t.insert(4) + >>> t.insert(88) + >>> t + 50 + ├──25 + │ ├──12 + │ │ └──4 + │ └──33 + └──75 + └──88 + + >>> n = t[4] + >>> for x in t.parent_path(n): + ... print(x.value) + 50 + 25 + 12 + 4 + + >>> del t[4] + >>> for x in t.parent_path(n): + ... if x is not None: + ... print(x.value) + ... else: + ... print(x) + 50 + 25 + 12 + None + + """ + return self._parent_path(self.root, node) + def __delitem__(self, value: Any) -> bool: """ Delete an item from the tree and preserve the BST property. - 50 - / \ - 25 75 - / / \ - 22 66 85 - / - 13 - - - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -109,6 +162,14 @@ class BinaryTree(object): >>> t.insert(22) >>> t.insert(13) >>> t.insert(85) + >>> t + 50 + ├──25 + │ └──22 + │ └──13 + └──75 + ├──66 + └──85 >>> for value in t.iterate_inorder(): ... print(value) @@ -149,6 +210,11 @@ class BinaryTree(object): 50 66 85 + >>> t + 50 + ├──25 + └──85 + └──66 >>> t.__delitem__(99) False @@ -216,7 +282,7 @@ class BinaryTree(object): """ Returns the count of items in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> len(t) 0 >>> t.insert(50) @@ -270,7 +336,7 @@ class BinaryTree(object): """ Yield the tree's items in a preorder traversal sequence. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -295,18 +361,28 @@ class BinaryTree(object): """ Yield the tree's items in a preorder traversal sequence. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) >>> t.insert(66) >>> t.insert(22) >>> t.insert(13) + >>> t.insert(24) + >>> t + 50 + ├──25 + │ └──22 + │ ├──13 + │ └──24 + └──75 + └──66 >>> for value in t.iterate_inorder(): ... print(value) 13 22 + 24 25 50 66 @@ -320,7 +396,7 @@ class BinaryTree(object): """ Yield the tree's items in a preorder traversal sequence. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -353,7 +429,7 @@ class BinaryTree(object): """ Iterate only the leaf nodes in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -384,7 +460,7 @@ class BinaryTree(object): """ Iterate only the leaf nodes in the tree. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(75) >>> t.insert(25) @@ -405,6 +481,51 @@ class BinaryTree(object): if self.root is not None: yield from self._iterate_by_depth(self.root, depth) + def get_next_node(self, node: Node) -> Node: + """ + Given a tree node, get the next greater node in the tree. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(23) + >>> t + 50 + ├──25 + │ └──22 + │ ├──13 + │ └──23 + └──75 + └──66 + + >>> n = t[23] + >>> t.get_next_node(n).value + 25 + + >>> n = t[50] + >>> t.get_next_node(n).value + 66 + + """ + if node.right is not None: + x = node.right + while x.left is not None: + x = x.left + return x + + path = self.parent_path(node) + assert path[-1] == node + path = path[:-1] + path.reverse() + for ancestor in path: + if node != ancestor.right: + return ancestor + node = ancestor + def _depth(self, node: Node, sofar: int) -> int: depth_left = sofar + 1 depth_right = sofar + 1 @@ -419,7 +540,7 @@ class BinaryTree(object): Returns the max height (depth) of the tree in plies (edge distance from root). - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.depth() 0 @@ -450,7 +571,7 @@ class BinaryTree(object): def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str: if node is not None: - self.viz += f'\n{padding}{pointer}{node.value}' + viz = f'\n{padding}{pointer}{node.value}' if has_right_sibling: padding += "│ " else: @@ -462,14 +583,16 @@ class BinaryTree(object): else: pointer_left = "└──" - self.repr_traverse(padding, pointer_left, node.left, node.right is not None) - self.repr_traverse(padding, pointer_right, node.right, False) + viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None) + viz += self.repr_traverse(padding, pointer_right, node.right, False) + return viz + return "" def __repr__(self): """ Draw the tree in ASCII. - >>> t = BinaryTree() + >>> t = BinarySearchTree() >>> t.insert(50) >>> t.insert(25) >>> t.insert(75) @@ -489,16 +612,16 @@ class BinaryTree(object): if self.root is None: return "" - self.viz = f'{self.root.value}' + ret = f'{self.root.value}' pointer_right = "└──" if self.root.right is None: pointer_left = "└──" else: pointer_left = "├──" - self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None) - self.repr_traverse('', pointer_right, self.root.right, False) - return self.viz + ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None) + ret += self.repr_traverse('', pointer_right, self.root.right, False) + return ret if __name__ == '__main__':