#!/usr/bin/env python3
-from typing import Any, Optional
+from typing import Any, Generator, List, Optional
class Node(object):
def __init__(self, value: Any) -> None:
- self.left = None
- self.right = None
+ """
+ Note: value can be anything as long as it is comparable.
+ Check out @functools.total_ordering.
+ """
+ self.left: Optional[Node] = None
+ self.right: Optional[Node] = None
self.value = value
-class BinaryTree(object):
+class BinarySearchTree(object):
def __init__(self):
self.root = None
self.count = 0
"""
Insert something into the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(10)
>>> t.insert(20)
>>> t.insert(5)
Find an item in the tree and return its Node. Returns
None if the item is not in the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t[99]
>>> t.insert(10)
"""Find helper"""
if value == node.value:
return node
- elif (value < node.value and node.left is not None):
+ elif value < node.value and node.left is not None:
return self._find(value, node.left)
- else:
- assert value > node.value
- if node.right is not None:
- return self._find(value, node.right)
+ elif value > node.value and node.right is not None:
+ return self._find(value, node.right)
return None
+ def _parent_path(
+ self, current: Optional[Node], target: Node
+ ) -> List[Optional[Node]]:
+ if current is None:
+ return [None]
+ ret: List[Optional[Node]] = [current]
+ if target.value == current.value:
+ return ret
+ elif target.value < current.value:
+ ret.extend(self._parent_path(current.left, target))
+ return ret
+ else:
+ assert target.value > current.value
+ ret.extend(self._parent_path(current.right, target))
+ return ret
+
+ def parent_path(self, node: Node) -> List[Optional[Node]]:
+ """Return a list of nodes representing the path from
+ the tree's root to the node argument. If the node does
+ not exist in the tree for some reason, the last element
+ on the path will be None but the path will indicate the
+ ancestor path of that node were it inserted.
+
+ >>> t = BinarySearchTree()
+ >>> t.insert(50)
+ >>> t.insert(75)
+ >>> t.insert(25)
+ >>> t.insert(12)
+ >>> t.insert(33)
+ >>> t.insert(4)
+ >>> t.insert(88)
+ >>> t
+ 50
+ ├──25
+ │ ├──12
+ │ │ └──4
+ │ └──33
+ └──75
+ └──88
+
+ >>> n = t[4]
+ >>> for x in t.parent_path(n):
+ ... print(x.value)
+ 50
+ 25
+ 12
+ 4
+
+ >>> del t[4]
+ >>> for x in t.parent_path(n):
+ ... if x is not None:
+ ... print(x.value)
+ ... else:
+ ... print(x)
+ 50
+ 25
+ 12
+ None
+
+ """
+ return self._parent_path(self.root, node)
+
def __delitem__(self, value: Any) -> bool:
"""
Delete an item from the tree and preserve the BST property.
- 50
- / \
- 25 75
- / / \
- 22 66 85
- /
- 13
-
-
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
>>> t.insert(22)
>>> t.insert(13)
>>> t.insert(85)
+ >>> t
+ 50
+ ├──25
+ │ └──22
+ │ └──13
+ └──75
+ ├──66
+ └──85
>>> for value in t.iterate_inorder():
... print(value)
50
66
85
+ >>> t
+ 50
+ ├──25
+ └──85
+ └──66
>>> t.__delitem__(99)
False
"""
Returns the count of items in the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> len(t)
0
>>> t.insert(50)
"""
Yield the tree's items in a preorder traversal sequence.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
"""
Yield the tree's items in a preorder traversal sequence.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
>>> t.insert(66)
>>> t.insert(22)
>>> t.insert(13)
+ >>> t.insert(24)
+ >>> t
+ 50
+ ├──25
+ │ └──22
+ │ ├──13
+ │ └──24
+ └──75
+ └──66
>>> for value in t.iterate_inorder():
... print(value)
13
22
+ 24
25
50
66
"""
Yield the tree's items in a preorder traversal sequence.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
"""
Iterate only the leaf nodes in the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
if node.right is not None:
yield from self._iterate_by_depth(node.right, depth - 1)
- def iterate_nodes_by_depth(self, depth: int):
+ def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
"""
Iterate only the leaf nodes in the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
if self.root is not None:
yield from self._iterate_by_depth(self.root, depth)
+ def get_next_node(self, node: Node) -> Node:
+ """
+ Given a tree node, get the next greater node in the tree.
+
+ >>> t = BinarySearchTree()
+ >>> t.insert(50)
+ >>> t.insert(75)
+ >>> t.insert(25)
+ >>> t.insert(66)
+ >>> t.insert(22)
+ >>> t.insert(13)
+ >>> t.insert(23)
+ >>> t
+ 50
+ ├──25
+ │ └──22
+ │ ├──13
+ │ └──23
+ └──75
+ └──66
+
+ >>> n = t[23]
+ >>> t.get_next_node(n).value
+ 25
+
+ >>> n = t[50]
+ >>> t.get_next_node(n).value
+ 66
+
+ """
+ if node.right is not None:
+ x = node.right
+ while x.left is not None:
+ x = x.left
+ return x
+
+ path = self.parent_path(node)
+ assert path[-1] is not None
+ assert path[-1] == node
+ path = path[:-1]
+ path.reverse()
+ for ancestor in path:
+ assert ancestor is not None
+ if node != ancestor.right:
+ return ancestor
+ node = ancestor
+ raise Exception()
+
def _depth(self, node: Node, sofar: int) -> int:
depth_left = sofar + 1
depth_right = sofar + 1
Returns the max height (depth) of the tree in plies (edge distance
from root).
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.depth()
0
def height(self):
return self.depth()
+ def repr_traverse(
+ self, padding: str, pointer: str, node: Optional[Node], has_right_sibling: bool
+ ) -> str:
+ if node is not None:
+ viz = f'\n{padding}{pointer}{node.value}'
+ if has_right_sibling:
+ padding += "│ "
+ else:
+ padding += ' '
+
+ pointer_right = "└──"
+ if node.right is not None:
+ pointer_left = "├──"
+ else:
+ pointer_left = "└──"
+
+ viz += self.repr_traverse(
+ padding, pointer_left, node.left, node.right is not None
+ )
+ viz += self.repr_traverse(padding, pointer_right, node.right, False)
+ return viz
+ return ""
+
+ def __repr__(self):
+ """
+ Draw the tree in ASCII.
+
+ >>> t = BinarySearchTree()
+ >>> t.insert(50)
+ >>> t.insert(25)
+ >>> t.insert(75)
+ >>> t.insert(12)
+ >>> t.insert(33)
+ >>> t.insert(88)
+ >>> t.insert(55)
+ >>> t
+ 50
+ ├──25
+ │ ├──12
+ │ └──33
+ └──75
+ ├──55
+ └──88
+ """
+ if self.root is None:
+ return ""
+
+ ret = f'{self.root.value}'
+ pointer_right = "└──"
+ if self.root.right is None:
+ pointer_left = "└──"
+ else:
+ pointer_left = "├──"
+
+ ret += self.repr_traverse(
+ '', pointer_left, self.root.left, self.root.left is not None
+ )
+ ret += self.repr_traverse('', pointer_right, self.root.right, False)
+ return ret
+
if __name__ == '__main__':
import doctest
+
doctest.testmod()