#!/usr/bin/env python3
-from typing import Any, Optional
+from typing import Any, Optional, List
class Node(object):
def __init__(self, value: Any) -> None:
+ """
+ Note: value can be anything as long as it is comparable.
+ Check out @functools.total_ordering.
+ """
self.left = None
self.right = None
self.value = value
-class BinaryTree(object):
+class BinarySearchTree(object):
def __init__(self):
self.root = None
self.count = 0
"""
Insert something into the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(10)
>>> t.insert(20)
>>> t.insert(5)
Find an item in the tree and return its Node. Returns
None if the item is not in the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t[99]
>>> t.insert(10)
return node
elif (value < node.value and node.left is not None):
return self._find(value, node.left)
- else:
- assert value > node.value
- if node.right is not None:
- return self._find(value, node.right)
+ elif (value > node.value and node.right is not None):
+ return self._find(value, node.right)
return None
+ def _parent_path(self, current: Node, target: Node):
+ if current is None:
+ return [None]
+ ret = [current]
+ if target.value == current.value:
+ return ret
+ elif target.value < current.value:
+ ret.extend(self._parent_path(current.left, target))
+ return ret
+ else:
+ assert target.value > current.value
+ ret.extend(self._parent_path(current.right, target))
+ return ret
+
+ def parent_path(self, node: Node) -> Optional[List[Node]]:
+ """Return a list of nodes representing the path from
+ the tree's root to the node argument. If the node does
+ not exist in the tree for some reason, the last element
+ on the path will be None but the path will indicate the
+ ancestor path of that node were it inserted.
+
+ >>> t = BinarySearchTree()
+ >>> t.insert(50)
+ >>> t.insert(75)
+ >>> t.insert(25)
+ >>> t.insert(12)
+ >>> t.insert(33)
+ >>> t.insert(4)
+ >>> t.insert(88)
+ >>> t
+ 50
+ ├──25
+ │ ├──12
+ │ │ └──4
+ │ └──33
+ └──75
+ └──88
+
+ >>> n = t[4]
+ >>> for x in t.parent_path(n):
+ ... print(x.value)
+ 50
+ 25
+ 12
+ 4
+
+ >>> del t[4]
+ >>> for x in t.parent_path(n):
+ ... if x is not None:
+ ... print(x.value)
+ ... else:
+ ... print(x)
+ 50
+ 25
+ 12
+ None
+
+ """
+ return self._parent_path(self.root, node)
+
def __delitem__(self, value: Any) -> bool:
"""
Delete an item from the tree and preserve the BST property.
- 50
- / \
- 25 75
- / / \
- 22 66 85
- /
- 13
-
-
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
>>> t.insert(22)
>>> t.insert(13)
>>> t.insert(85)
+ >>> t
+ 50
+ ├──25
+ │ └──22
+ │ └──13
+ └──75
+ ├──66
+ └──85
>>> for value in t.iterate_inorder():
... print(value)
50
66
85
+ >>> t
+ 50
+ ├──25
+ └──85
+ └──66
>>> t.__delitem__(99)
False
"""
Returns the count of items in the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> len(t)
0
>>> t.insert(50)
"""
Yield the tree's items in a preorder traversal sequence.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
"""
Yield the tree's items in a preorder traversal sequence.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
>>> t.insert(66)
>>> t.insert(22)
>>> t.insert(13)
+ >>> t.insert(24)
+ >>> t
+ 50
+ ├──25
+ │ └──22
+ │ ├──13
+ │ └──24
+ └──75
+ └──66
>>> for value in t.iterate_inorder():
... print(value)
13
22
+ 24
25
50
66
"""
Yield the tree's items in a preorder traversal sequence.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
"""
Iterate only the leaf nodes in the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
"""
Iterate only the leaf nodes in the tree.
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.insert(75)
>>> t.insert(25)
if self.root is not None:
yield from self._iterate_by_depth(self.root, depth)
+ def get_next_node(self, node: Node) -> Node:
+ """
+ Given a tree node, get the next greater node in the tree.
+
+ >>> t = BinarySearchTree()
+ >>> t.insert(50)
+ >>> t.insert(75)
+ >>> t.insert(25)
+ >>> t.insert(66)
+ >>> t.insert(22)
+ >>> t.insert(13)
+ >>> t.insert(23)
+ >>> t
+ 50
+ ├──25
+ │ └──22
+ │ ├──13
+ │ └──23
+ └──75
+ └──66
+
+ >>> n = t[23]
+ >>> t.get_next_node(n).value
+ 25
+
+ >>> n = t[50]
+ >>> t.get_next_node(n).value
+ 66
+
+ """
+ if node.right is not None:
+ x = node.right
+ while x.left is not None:
+ x = x.left
+ return x
+
+ path = self.parent_path(node)
+ assert path[-1] == node
+ path = path[:-1]
+ path.reverse()
+ for ancestor in path:
+ if node != ancestor.right:
+ return ancestor
+ node = ancestor
+
def _depth(self, node: Node, sofar: int) -> int:
depth_left = sofar + 1
depth_right = sofar + 1
Returns the max height (depth) of the tree in plies (edge distance
from root).
- >>> t = BinaryTree()
+ >>> t = BinarySearchTree()
>>> t.depth()
0
def height(self):
return self.depth()
+ def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
+ if node is not None:
+ viz = f'\n{padding}{pointer}{node.value}'
+ if has_right_sibling:
+ padding += "│ "
+ else:
+ padding += ' '
+
+ pointer_right = "└──"
+ if node.right is not None:
+ pointer_left = "├──"
+ else:
+ pointer_left = "└──"
+
+ viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
+ viz += self.repr_traverse(padding, pointer_right, node.right, False)
+ return viz
+ return ""
+
+ def __repr__(self):
+ """
+ Draw the tree in ASCII.
+
+ >>> t = BinarySearchTree()
+ >>> t.insert(50)
+ >>> t.insert(25)
+ >>> t.insert(75)
+ >>> t.insert(12)
+ >>> t.insert(33)
+ >>> t.insert(88)
+ >>> t.insert(55)
+ >>> t
+ 50
+ ├──25
+ │ ├──12
+ │ └──33
+ └──75
+ ├──55
+ └──88
+ """
+ if self.root is None:
+ return ""
+
+ ret = f'{self.root.value}'
+ pointer_right = "└──"
+ if self.root.right is None:
+ pointer_left = "└──"
+ else:
+ pointer_left = "├──"
+
+ ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
+ ret += self.repr_traverse('', pointer_right, self.root.right, False)
+ return ret
+
if __name__ == '__main__':
import doctest