#!/usr/bin/env python3
-from typing import Any, Optional, List
+# © Copyright 2021-2022, Scott Gasch
+
+"""Binary search tree."""
+
+from typing import Any, Generator, List, Optional
class Node(object):
Note: value can be anything as long as it is comparable.
Check out @functools.total_ordering.
"""
- self.left = None
- self.right = None
+ self.left: Optional[Node] = None
+ self.right: Optional[Node] = None
self.value = value
"""Find helper"""
if value == node.value:
return node
- elif (value < node.value and node.left is not None):
+ elif value < node.value and node.left is not None:
return self._find(value, node.left)
- elif (value > node.value and node.right is not None):
+ elif value > node.value and node.right is not None:
return self._find(value, node.right)
return None
- def _parent_path(self, current: Node, target: Node):
+ def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]:
if current is None:
return [None]
- ret = [current]
+ ret: List[Optional[Node]] = [current]
if target.value == current.value:
return ret
elif target.value < current.value:
ret.extend(self._parent_path(current.right, target))
return ret
- def parent_path(self, node: Node) -> Optional[List[Node]]:
+ def parent_path(self, node: Node) -> List[Optional[Node]]:
"""Return a list of nodes representing the path from
the tree's root to the node argument. If the node does
not exist in the tree for some reason, the last element
if node.right is not None:
yield from self._iterate_by_depth(node.right, depth - 1)
- def iterate_nodes_by_depth(self, depth: int):
+ def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
"""
Iterate only the leaf nodes in the tree.
return x
path = self.parent_path(node)
+ assert path[-1] is not None
assert path[-1] == node
path = path[:-1]
path.reverse()
for ancestor in path:
+ assert ancestor is not None
if node != ancestor.right:
return ancestor
node = ancestor
+ raise Exception()
def _depth(self, node: Node, sofar: int) -> int:
depth_left = sofar + 1
def height(self):
return self.depth()
- def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
+ def repr_traverse(
+ self,
+ padding: str,
+ pointer: str,
+ node: Optional[Node],
+ has_right_sibling: bool,
+ ) -> str:
if node is not None:
viz = f'\n{padding}{pointer}{node.value}'
if has_right_sibling:
if __name__ == '__main__':
import doctest
+
doctest.testmod()