X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=collect%2Fbst.py;fp=collect%2Fbst.py;h=9d6525946e8131728896d86f3400c38c5ba528e7;hb=6ba90a1f30f1c0cf4df12fcd0c62181f29bc3668;hp=72a3b7738b981878b9b08eaba67bca2b33314f4f;hpb=31c81f6539969a5eba864d3305f9fb7bf716a367;p=python_utils.git diff --git a/collect/bst.py b/collect/bst.py index 72a3b77..9d65259 100644 --- a/collect/bst.py +++ b/collect/bst.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from typing import Any, Optional, List +from typing import Any, Generator, List, Optional class Node(object): @@ -9,8 +9,8 @@ class Node(object): Note: value can be anything as long as it is comparable. Check out @functools.total_ordering. """ - self.left = None - self.right = None + self.left: Optional[Node] = None + self.right: Optional[Node] = None self.value = value @@ -84,16 +84,18 @@ class BinarySearchTree(object): """Find helper""" if value == node.value: return node - elif (value < node.value and node.left is not None): + elif value < node.value and node.left is not None: return self._find(value, node.left) - elif (value > node.value and node.right is not None): + elif value > node.value and node.right is not None: return self._find(value, node.right) return None - def _parent_path(self, current: Node, target: Node): + def _parent_path( + self, current: Optional[Node], target: Node + ) -> List[Optional[Node]]: if current is None: return [None] - ret = [current] + ret: List[Optional[Node]] = [current] if target.value == current.value: return ret elif target.value < current.value: @@ -104,7 +106,7 @@ class BinarySearchTree(object): ret.extend(self._parent_path(current.right, target)) return ret - def parent_path(self, node: Node) -> Optional[List[Node]]: + def parent_path(self, node: Node) -> List[Optional[Node]]: """Return a list of nodes representing the path from the tree's root to the node argument. If the node does not exist in the tree for some reason, the last element @@ -456,7 +458,7 @@ class BinarySearchTree(object): if node.right is not None: yield from self._iterate_by_depth(node.right, depth - 1) - def iterate_nodes_by_depth(self, depth: int): + def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]: """ Iterate only the leaf nodes in the tree. @@ -518,13 +520,16 @@ class BinarySearchTree(object): return x path = self.parent_path(node) + assert path[-1] assert path[-1] == node path = path[:-1] path.reverse() for ancestor in path: + assert ancestor if node != ancestor.right: return ancestor node = ancestor + raise Exception() def _depth(self, node: Node, sofar: int) -> int: depth_left = sofar + 1 @@ -569,7 +574,9 @@ class BinarySearchTree(object): def height(self): return self.depth() - def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str: + def repr_traverse( + self, padding: str, pointer: str, node: Optional[Node], has_right_sibling: bool + ) -> str: if node is not None: viz = f'\n{padding}{pointer}{node.value}' if has_right_sibling: @@ -583,7 +590,9 @@ class BinarySearchTree(object): else: pointer_left = "└──" - viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None) + viz += self.repr_traverse( + padding, pointer_left, node.left, node.right is not None + ) viz += self.repr_traverse(padding, pointer_right, node.right, False) return viz return "" @@ -619,11 +628,14 @@ class BinarySearchTree(object): else: pointer_left = "├──" - ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None) + ret += self.repr_traverse( + '', pointer_left, self.root.left, self.root.left is not None + ) ret += self.repr_traverse('', pointer_right, self.root.right, False) return ret if __name__ == '__main__': import doctest + doctest.testmod()