X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=collect%2Fbst.py;h=712683eb59ea3c38939ebf4f5a4cb31c191199f8;hb=5fd3697843f4d03e4bb65a0346764805aabc2fde;hp=72a3b7738b981878b9b08eaba67bca2b33314f4f;hpb=fa4298fa508e00759565c246aef423ba28fedf31;p=python_utils.git diff --git a/collect/bst.py b/collect/bst.py index 72a3b77..712683e 100644 --- a/collect/bst.py +++ b/collect/bst.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from typing import Any, Optional, List +from typing import Any, Generator, List, Optional class Node(object): @@ -9,8 +9,8 @@ class Node(object): Note: value can be anything as long as it is comparable. Check out @functools.total_ordering. """ - self.left = None - self.right = None + self.left: Optional[Node] = None + self.right: Optional[Node] = None self.value = value @@ -84,16 +84,16 @@ class BinarySearchTree(object): """Find helper""" if value == node.value: return node - elif (value < node.value and node.left is not None): + elif value < node.value and node.left is not None: return self._find(value, node.left) - elif (value > node.value and node.right is not None): + elif value > node.value and node.right is not None: return self._find(value, node.right) return None - def _parent_path(self, current: Node, target: Node): + def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]: if current is None: return [None] - ret = [current] + ret: List[Optional[Node]] = [current] if target.value == current.value: return ret elif target.value < current.value: @@ -104,7 +104,7 @@ class BinarySearchTree(object): ret.extend(self._parent_path(current.right, target)) return ret - def parent_path(self, node: Node) -> Optional[List[Node]]: + def parent_path(self, node: Node) -> List[Optional[Node]]: """Return a list of nodes representing the path from the tree's root to the node argument. If the node does not exist in the tree for some reason, the last element @@ -456,7 +456,7 @@ class BinarySearchTree(object): if node.right is not None: yield from self._iterate_by_depth(node.right, depth - 1) - def iterate_nodes_by_depth(self, depth: int): + def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]: """ Iterate only the leaf nodes in the tree. @@ -518,13 +518,16 @@ class BinarySearchTree(object): return x path = self.parent_path(node) + assert path[-1] is not None assert path[-1] == node path = path[:-1] path.reverse() for ancestor in path: + assert ancestor is not None if node != ancestor.right: return ancestor node = ancestor + raise Exception() def _depth(self, node: Node, sofar: int) -> int: depth_left = sofar + 1 @@ -569,7 +572,13 @@ class BinarySearchTree(object): def height(self): return self.depth() - def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str: + def repr_traverse( + self, + padding: str, + pointer: str, + node: Optional[Node], + has_right_sibling: bool, + ) -> str: if node is not None: viz = f'\n{padding}{pointer}{node.value}' if has_right_sibling: @@ -626,4 +635,5 @@ class BinarySearchTree(object): if __name__ == '__main__': import doctest + doctest.testmod()