3 # © Copyright 2021-2022, Scott Gasch
5 """A binary search tree implementation."""
7 from typing import Any, Generator, List, Optional
class Node:
    """A single node of a :class:`BinarySearchTree`."""

    def __init__(self, value: Any) -> None:
        """
        A BST node.  Note that value can be anything as long as it
        is comparable (supports ``<``, ``>`` and ``==``).  Check out
        :func:`functools.total_ordering`
        (https://docs.python.org/3/library/functools.html#functools.total_ordering)

        Args:
            value: a reference to the value of the node.
        """
        self.left: Optional[Node] = None
        self.right: Optional[Node] = None
        self.value = value


class BinarySearchTree:
    """An unbalanced binary search tree.

    Values less than a node's value are stored in its left subtree and
    values greater than *or equal to* it in its right subtree, so
    duplicates are allowed.  Subclasses can override :meth:`_on_insert`
    and :meth:`_on_delete` to observe structural changes.

    Insert and lookup walk the tree iteratively; traversals, depth and
    deletion recurse, so on a very deep (degenerate) tree those may still
    hit Python's recursion limit.
    """

    def __init__(self) -> None:
        self.root: Optional[Node] = None
        # Number of stored values; kept in sync by insert() and __delitem__().
        self.count = 0
        # Not read or updated by this class; retained so external code that
        # references it keeps working.
        self.traversal_count = 0

    def get_root(self) -> Optional[Node]:
        """
        Returns:
            The root node of the tree, or None if the tree is empty.
        """
        return self.root

    def _on_insert(self, parent: Optional[Node], new: Node) -> None:
        """Hook called immediately _after_ ``new`` is linked into the tree.

        ``parent`` is None when ``new`` became the root.  The default
        implementation does nothing.
        """

    def insert(self, value: Any) -> None:
        """
        Insert something into the tree.  A value equal to an existing
        one is placed in that node's right subtree.

        Args:
            value: the value to be inserted.

        >>> t = BinarySearchTree()
        >>> t.insert(10)
        >>> t.insert(20)
        >>> t.insert(5)
        >>> len(t)
        3
        >>> t.get_root().value
        10
        """
        if self.root is None:
            self.root = Node(value)
            self.count = 1
            self._on_insert(None, self.root)
        else:
            self._insert(value, self.root)

    def _insert(self, value: Any, node: Node) -> None:
        """Insertion helper: link a new node for ``value`` below ``node``.

        Walks down iteratively (rather than recursing once per level) so
        that inserting a long run of sorted values cannot exhaust the
        recursion limit.
        """
        while True:
            if value < node.value:
                if node.left is None:
                    node.left = Node(value)
                    self.count += 1
                    self._on_insert(node, node.left)
                    return
                node = node.left
            else:
                if node.right is None:
                    node.right = Node(value)
                    self.count += 1
                    self._on_insert(node, node.right)
                    return
                node = node.right

    def __getitem__(self, value: Any) -> Optional[Node]:
        """
        Find an item in the tree and return its Node.  Returns
        None (rather than raising KeyError) if the item is not in the tree.

        >>> t = BinarySearchTree()
        >>> t[99] is None
        True
        >>> t.insert(10)
        >>> t.insert(20)
        >>> t[20].value
        20
        >>> t[99] is None
        True
        """
        if self.root is not None:
            return self._find(value, self.root)
        return None

    def _find(self, value: Any, node: Node) -> Optional[Node]:
        """Find helper: the first node at or below ``node`` equal to ``value``."""
        current: Optional[Node] = node
        while current is not None:
            if value == current.value:
                return current
            if value < current.value:
                current = current.left
            elif value > current.value:
                current = current.right
            else:
                # Neither <, > nor == holds (e.g. NaN): nowhere to descend.
                return None
        return None

    def _parent_path(
        self, current: Optional[Node], target: Node
    ) -> List[Optional[Node]]:
        """Internal helper: the search path for ``target.value`` from ``current``.

        The path ends at the first node whose value equals the target's,
        or with None where the search fell off the tree.
        """
        path: List[Optional[Node]] = []
        while current is not None:
            path.append(current)
            if target.value == current.value:
                return path
            if target.value < current.value:
                current = current.left
            else:
                assert target.value > current.value
                current = current.right
        path.append(None)
        return path

    def parent_path(self, node: Node) -> List[Optional[Node]]:
        """Get a node's parent path.

        Args:
            node: the node to check

        Returns:
            a list of nodes representing the path from
            the tree's root to the node.

        .. note::

            The search is by value: the path stops at the first node
            whose value equals ``node.value``.  If there is no such node,
            the last element on the path will be None, and the path shows
            the ancestors the node would have were it to be inserted.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 60):
        ...     t.insert(v)
        >>> [x.value for x in t.parent_path(t[60])]
        [50, 75, 60]
        >>> [x if x is None else x.value for x in t.parent_path(Node(55))]
        [50, 75, 60, None]
        """
        return self._parent_path(self.root, node)

    def __delitem__(self, value: Any) -> bool:
        """
        Delete an item from the tree and preserve the BST property.

        Args:
            value: the value of the node to be deleted.

        Returns:
            True if the value was found and its associated node was
            successfully deleted and False otherwise.  Note that
            ``del t[v]`` discards this result; call ``t.__delitem__(v)``
            to observe it.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 60):
        ...     t.insert(v)
        >>> t.__delitem__(50)
        True
        >>> list(t.iterate_inorder())
        [25, 60, 75]
        >>> t.__delitem__(99)
        False
        >>> r = BinarySearchTree()
        >>> r.insert(10)
        >>> r.insert(20)
        >>> r.__delitem__(10)
        True
        >>> r.get_root().value
        20
        """
        if self.root is None:
            return False
        deleted = self._delete(value, None, self.root)
        if deleted:
            self.count -= 1
            if self.count == 0:
                self.root = None
        return deleted

    def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
        """Hook called just after ``deleted`` was unlinked from the tree.

        ``parent`` is the unlinked node's former parent (None if it was
        the root).  The default implementation does nothing.
        """

    def _replace_child(
        self, parent: Optional[Node], child: Node, replacement: Optional[Node]
    ) -> None:
        """Make whichever link pointed at ``child`` point at ``replacement``.

        That link is ``self.root`` when ``parent`` is None, otherwise the
        matching ``left``/``right`` field of ``parent``.
        """
        if parent is None:
            assert self.root is child
            self.root = replacement
        elif parent.left is child:
            parent.left = replacement
        else:
            assert parent.right is child
            parent.right = replacement

    def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
        """Delete helper: remove one node holding ``value`` below ``node``.

        Args:
            value: the value to remove.
            parent: ``node``'s parent, or None when ``node`` is the root.
            node: root of the subtree to search.

        Returns:
            True if a node was removed, False if ``value`` was not found.
        """
        if node.value == value:
            if node.left is None or node.right is None:
                # Zero or one child: splice the (possibly missing) child into
                # node's place.  With no parent this rewrites self.root, which
                # is what makes deleting a one-child root work.
                child = node.left if node.left is not None else node.right
                self._replace_child(parent, node, child)
                self._on_delete(parent, node)
                return True

            # Two children: copy the inorder successor's value (leftmost
            # node of the right subtree) into this node, then delete that
            # value from the right subtree.
            successor = node.right
            assert successor is not None
            while successor.left is not None:
                successor = successor.left
            node.value = successor.value
            return self._delete(node.value, node, node.right)

        if value < node.value and node.left is not None:
            return self._delete(value, node, node.left)
        if value > node.value and node.right is not None:
            return self._delete(value, node, node.right)
        return False

    def __len__(self) -> int:
        """
        Returns:
            The count of items in the tree.

        >>> t = BinarySearchTree()
        >>> len(t)
        0
        >>> t.insert(50)
        >>> len(t)
        1
        >>> t.__delitem__(50)
        True
        >>> len(t)
        0
        """
        return self.count

    def __contains__(self, value: Any) -> bool:
        """
        Returns:
            True if the item is in the tree; False otherwise.

        >>> t = BinarySearchTree()
        >>> t.insert(7)
        >>> 7 in t, 8 in t
        (True, False)
        """
        return self.__getitem__(value) is not None

    def _iterate_preorder(self, node: Node):
        """Yield values of the subtree at ``node``: node, left, right."""
        yield node.value
        if node.left is not None:
            yield from self._iterate_preorder(node.left)
        if node.right is not None:
            yield from self._iterate_preorder(node.right)

    def _iterate_inorder(self, node: Node):
        """Yield values of the subtree at ``node``: left, node, right."""
        if node.left is not None:
            yield from self._iterate_inorder(node.left)
        yield node.value
        if node.right is not None:
            yield from self._iterate_inorder(node.right)

    def _iterate_postorder(self, node: Node):
        """Yield values of the subtree at ``node``: left, right, node."""
        if node.left is not None:
            yield from self._iterate_postorder(node.left)
        if node.right is not None:
            yield from self._iterate_postorder(node.right)
        yield node.value

    def iterate_preorder(self) -> Generator[Any, None, None]:
        """
        Returns:
            A Generator that yields the tree's items in a
            preorder traversal sequence.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 10, 30):
        ...     t.insert(v)
        >>> list(t.iterate_preorder())
        [50, 25, 10, 30, 75]
        """
        if self.root is not None:
            yield from self._iterate_preorder(self.root)

    def iterate_inorder(self) -> Generator[Any, None, None]:
        """
        Returns:
            A Generator that yields the tree's items in an inorder
            traversal sequence (i.e. in ascending order).

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 10, 30):
        ...     t.insert(v)
        >>> list(t.iterate_inorder())
        [10, 25, 30, 50, 75]
        """
        if self.root is not None:
            yield from self._iterate_inorder(self.root)

    def iterate_postorder(self) -> Generator[Any, None, None]:
        """
        Returns:
            A Generator that yields the tree's items in a postorder
            traversal sequence.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 10, 30):
        ...     t.insert(v)
        >>> list(t.iterate_postorder())
        [10, 30, 25, 75, 50]
        """
        if self.root is not None:
            yield from self._iterate_postorder(self.root)

    def _iterate_leaves(self, node: Node):
        """Yield values of the leaves under ``node``, left to right."""
        if node.left is not None:
            yield from self._iterate_leaves(node.left)
        if node.right is not None:
            yield from self._iterate_leaves(node.right)
        if node.left is None and node.right is None:
            yield node.value

    def iterate_leaves(self) -> Generator[Any, None, None]:
        """
        Returns:
            A Generator that yields only the values of the leaf nodes in
            the tree, left to right.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 10, 30):
        ...     t.insert(v)
        >>> list(t.iterate_leaves())
        [10, 30, 75]
        """
        if self.root is not None:
            yield from self._iterate_leaves(self.root)

    def _iterate_by_depth(self, node: Node, depth: int):
        """Yield values ``depth`` levels below ``node``; nothing if depth < 0."""
        if depth == 0:
            yield node.value
        elif depth > 0:
            for child in (node.left, node.right):
                if child is not None:
                    yield from self._iterate_by_depth(child, depth - 1)

    def iterate_nodes_by_depth(self, depth: int) -> Generator[Any, None, None]:
        """
        Args:
            depth: the desired depth (0 is the root's level)

        Returns:
            A Generator that yields the values of the nodes at the
            prescribed depth, left to right.  Nothing is yielded for a
            negative depth or one deeper than the tree.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 10, 30):
        ...     t.insert(v)
        >>> list(t.iterate_nodes_by_depth(2))
        [10, 30]
        >>> list(t.iterate_nodes_by_depth(3))
        []
        >>> list(t.iterate_nodes_by_depth(-1))
        []
        """
        if self.root is not None:
            yield from self._iterate_by_depth(self.root, depth)

    def get_next_node(self, node: Node) -> Node:
        """
        Args:
            node: the node whose next greater successor is desired; it
                must be a node of this tree.

        Returns:
            The node that follows ``node`` in an inorder traversal.

        Raises:
            ValueError: ``node`` holds the greatest value (no successor).

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 10, 30):
        ...     t.insert(v)
        >>> t.get_next_node(t[30]).value
        50
        >>> t.get_next_node(t[50]).value
        75
        >>> t.get_next_node(t[75])
        Traceback (most recent call last):
        ...
        ValueError: 75 has no successor
        """
        if node.right is not None:
            # The successor is the leftmost node of the right subtree.
            successor = node.right
            while successor.left is not None:
                successor = successor.left
            return successor

        # Otherwise climb toward the root; the first ancestor we reach from
        # its left side is the successor.
        start = node
        path = self.parent_path(node)
        assert path[-1] is node
        for ancestor in reversed(path[:-1]):
            assert ancestor is not None
            if ancestor.right is not node:
                return ancestor
            node = ancestor
        raise ValueError(f"{start.value!r} has no successor")

    def _depth(self, node: Node, sofar: int) -> int:
        """Depth helper: ``sofar`` plus the node count of the longest path down from ``node``."""
        deepest = sofar + 1
        for child in (node.left, node.right):
            if child is not None:
                deepest = max(deepest, self._depth(child, sofar + 1))
        return deepest

    def depth(self) -> int:
        """
        Returns:
            The max depth of the tree, counted in nodes along the longest
            root-to-leaf path: 0 for an empty tree, 1 for a lone root.

        >>> t = BinarySearchTree()
        >>> t.depth()
        0
        >>> t.insert(50)
        >>> t.depth()
        1
        >>> t.insert(65)
        >>> t.insert(70)
        >>> t.depth()
        3
        """
        if self.root is None:
            return 0
        return self._depth(self.root, 0)

    def height(self) -> int:
        """Returns the height (i.e. max depth) of the tree; see :meth:`depth`."""
        return self.depth()

    def repr_traverse(
        self,
        padding: str,
        pointer: str,
        node: Optional[Node],
        has_right_sibling: bool,
    ) -> str:
        """Render ``node`` and its subtree for :meth:`__repr__`.

        Args:
            padding: indentation accumulated from the ancestors.
            pointer: connector drawn before this node's value
                (``"├──"`` or ``"└──"``).
            node: root of the subtree to render; None renders as "".
            has_right_sibling: whether a sibling is drawn after this
                node, in which case a ``"│"`` rail continues beside the
                node's subtree.

        Returns:
            One newline-prefixed line per node, in preorder.
        """
        if node is None:
            return ""
        viz = f"\n{padding}{pointer}{node.value}"
        padding += "│  " if has_right_sibling else "   "
        pointer_right = "└──"
        pointer_left = "├──" if node.right is not None else "└──"
        viz += self.repr_traverse(
            padding, pointer_left, node.left, node.right is not None
        )
        viz += self.repr_traverse(padding, pointer_right, node.right, False)
        return viz

    def __repr__(self) -> str:
        """
        Returns:
            A text drawing of the tree (box-drawing characters), one value
            per line; the empty string for an empty tree.

        >>> repr(BinarySearchTree())
        ''
        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 10, 30):
        ...     t.insert(v)
        >>> print(t)
        50
        └──25
           ├──10
           └──30
        """
        if self.root is None:
            return ""

        ret = f"{self.root.value}"
        pointer_right = "└──"
        pointer_left = "└──" if self.root.right is None else "├──"
        # The left child has a right sibling only when the root has a right
        # child; that decides whether a "│" rail is drawn beside its subtree.
        ret += self.repr_traverse(
            "", pointer_left, self.root.left, self.root.right is not None
        )
        ret += self.repr_traverse("", pointer_right, self.root.right, False)
        return ret
711 if __name__ == "__main__":