3 # © Copyright 2021-2023, Scott Gasch
5 """A binary search tree implementation."""
7 from abc import abstractmethod
8 from typing import Any, Generator, List, Optional, Protocol
11 class Comparable(Protocol):
13 def __lt__(self, other: Any) -> bool:
17 def __le__(self, other: Any) -> bool:
21 def __eq__(self, other: Any) -> bool:
26 def __init__(self, value: Comparable) -> None:
27 """A BST node. Note that value can be anything as long as it
28 is comparable with other instances of itself. Check out
29 :meth:`functools.total_ordering`
30 (https://docs.python.org/3/library/functools.html#functools.total_ordering)
33 value: a reference to the value of the node. Must be comparable
37 self.left: Optional[Node] = None
38 self.right: Optional[Node] = None
39 self.value: Comparable = value
42 class BinarySearchTree(object):
48 def get_root(self) -> Optional[Node]:
def _on_insert(self, parent: Optional[Node], new: Node) -> None:
    """Subclass hook invoked immediately *after* a node is inserted.

    The base implementation does nothing; subclasses may override it
    to react to structural changes.

    Args:
        parent: the node the new node was attached beneath, or None
            when the new node became the root of the tree.
        new: the freshly inserted node.
    """
    return None
60 def insert(self, value: Comparable) -> None:
62 Insert something into the tree.
65 value: the value to be inserted.
67 >>> t = BinarySearchTree()
74 >>> t.get_root().value
79 self.root = Node(value)
81 self._on_insert(None, self.root)
83 self._insert(value, self.root)
85 def _insert(self, value: Comparable, node: Node):
86 """Insertion helper"""
87 if value < node.value:
88 if node.left is not None:
89 self._insert(value, node.left)
91 node.left = Node(value)
93 self._on_insert(node, node.left)
95 if node.right is not None:
96 self._insert(value, node.right)
98 node.right = Node(value)
100 self._on_insert(node, node.right)
def __getitem__(self, value: Comparable) -> Optional[Node]:
    """
    Look up an item in the tree and return the Node holding it.

    Args:
        value: the value to search for.

    Returns:
        The node whose value matches ``value``, or None if the item is
        not in the tree (including when the tree is empty).
    """
    start = self.root
    if start is None:
        # Empty tree: nothing can match.
        return None
    return self._find_exact(value, start)
123 def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
124 """Recursively traverse the tree looking for a node with the
125 target value. Return that node if it exists, otherwise return
128 if target == node.value:
130 elif target < node.value and node.left is not None:
131 return self._find_exact(target, node.left)
132 elif target > node.value and node.right is not None:
133 return self._find_exact(target, node.right)
136 def _find_lowest_node_less_than_or_equal_to(
137 self, target: Comparable, node: Optional[Node]
139 """Find helper that returns the lowest node that is less
140 than or equal to the target value. Returns None if target is
141 lower than the lowest node in the tree.
143 >>> t = BinarySearchTree()
160 >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
162 >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
164 >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
166 >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
168 >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
170 >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
172 >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
174 >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
182 if target == node.value:
185 elif target > node.value:
186 if below := self._find_lowest_node_less_than_or_equal_to(
194 return self._find_lowest_node_less_than_or_equal_to(target, node.left)
196 def _find_lowest_node_greater_than_or_equal_to(
197 self, target: Comparable, node: Optional[Node]
199 """Find helper that returns the lowest node that is greater
200 than or equal to the target value. Returns None if target is
201 higher than the greatest node in the tree.
203 >>> t = BinarySearchTree()
220 >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
222 >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
224 >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
226 >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
228 >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
230 >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
232 >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
234 >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
242 if target == node.value:
245 elif target > node.value:
246 return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
248 # If target < this node's value, either this node is the
249 # answer or the answer is in this node's left subtree.
251 if below := self._find_lowest_node_greater_than_or_equal_to(
259 self, current: Optional[Node], target: Node
260 ) -> List[Optional[Node]]:
261 """Internal helper"""
264 ret: List[Optional[Node]] = [current]
265 if target.value == current.value:
267 elif target.value < current.value:
268 ret.extend(self._parent_path(current.left, target))
271 assert target.value > current.value
272 ret.extend(self._parent_path(current.right, target))
275 def parent_path(self, node: Node) -> List[Optional[Node]]:
276 """Get a node's parent path.
279 node: the node to check
282 a list of nodes representing the path from
283 the tree's root to the node.
287 If the node does not exist in the tree, the last element
288 on the path will be None but the path will indicate the
289 ancestor path of that node were it to be inserted.
291 >>> t = BinarySearchTree()
309 >>> for x in t.parent_path(n):
317 >>> for x in t.parent_path(n):
318 ... if x is not None:
328 return self._parent_path(self.root, node)
330 def __delitem__(self, value: Comparable) -> bool:
332 Delete an item from the tree and preserve the BST property.
335 value: the value of the node to be deleted.
338 True if the value was found and its associated node was
339 successfully deleted and False otherwise.
341 >>> t = BinarySearchTree()
358 >>> for value in t.iterate_inorder():
368 >>> del t[22] # Note: bool result is discarded
370 >>> for value in t.iterate_inorder():
379 >>> t.__delitem__(13)
381 >>> for value in t.iterate_inorder():
389 >>> t.__delitem__(75)
391 >>> for value in t.iterate_inorder():
403 >>> t.__delitem__(99)
407 if self.root is not None:
408 ret = self._delete(value, None, self.root)
def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
    """Subclass hook invoked just *after* a node is removed from the tree.

    The base implementation does nothing; subclasses may override it
    to react to structural changes.

    Args:
        parent: the removed node's former parent (typed Optional, so
            it may be None).
        deleted: the node that was unlinked from the tree.
    """
    return None
420 def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
422 if node.value == value:
424 # Deleting a leaf node
425 if node.left is None and node.right is None:
426 if parent is not None:
427 if parent.left == node:
430 assert parent.right == node
432 self._on_delete(parent, node)
435 # Node only has a right.
436 elif node.left is None:
437 assert node.right is not None
438 if parent is not None:
439 if parent.left == node:
440 parent.left = node.right
442 assert parent.right == node
443 parent.right = node.right
444 self._on_delete(parent, node)
447 # Node only has a left.
448 elif node.right is None:
449 assert node.left is not None
450 if parent is not None:
451 if parent.left == node:
452 parent.left = node.left
454 assert parent.right == node
455 parent.right = node.left
456 self._on_delete(parent, node)
459 # Node has both a left and right.
461 assert node.left is not None and node.right is not None
462 descendent = node.right
463 while descendent.left is not None:
464 descendent = descendent.left
465 node.value = descendent.value
466 return self._delete(node.value, node, node.right)
467 elif value < node.value and node.left is not None:
468 return self._delete(value, node, node.left)
469 elif value > node.value and node.right is not None:
470 return self._delete(value, node, node.right)
476 The count of items in the tree.
478 >>> t = BinarySearchTree()
484 >>> t.__delitem__(50)
def __contains__(self, value: Comparable) -> bool:
    """
    Support the ``value in tree`` membership test.

    Args:
        value: the value to look for.

    Returns:
        True if the item is in the tree; False otherwise.
    """
    match = self.__getitem__(value)
    return match is not None
507 def _iterate_preorder(self, node: Node):
509 if node.left is not None:
510 yield from self._iterate_preorder(node.left)
511 if node.right is not None:
512 yield from self._iterate_preorder(node.right)
514 def _iterate_inorder(self, node: Node):
515 if node.left is not None:
516 yield from self._iterate_inorder(node.left)
518 if node.right is not None:
519 yield from self._iterate_inorder(node.right)
521 def _iterate_postorder(self, node: Node):
522 if node.left is not None:
523 yield from self._iterate_postorder(node.left)
524 if node.right is not None:
525 yield from self._iterate_postorder(node.right)
def iterate_preorder(self):
    """
    Walk the tree in preorder: each node before its left and right subtrees.

    Returns:
        A Generator that yields the tree's items in a preorder
        traversal sequence. Nothing is yielded for an empty tree.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 13):
    ...     t.insert(v)
    >>> for value in t.iterate_preorder():
    ...     print(value)
    50
    25
    22
    13
    75
    66
    """
    start = self.root
    if start is None:
        return
    yield from self._iterate_preorder(start)
def iterate_inorder(self):
    """
    Walk the tree in inorder: left subtree, then the node, then the right subtree.

    Smaller values are stored to the left, so this visits the items in
    ascending (non-decreasing) order.

    Returns:
        A Generator that yields the tree's items in an inorder
        traversal sequence. Nothing is yielded for an empty tree.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 13):
    ...     t.insert(v)
    >>> for value in t.iterate_inorder():
    ...     print(value)
    13
    22
    25
    50
    66
    75
    """
    if self.root is not None:
        yield from self._iterate_inorder(self.root)
def iterate_postorder(self):
    """
    Walk the tree in postorder: left subtree, then right subtree, then the node.

    Returns:
        A Generator that yields the tree's items in a postorder
        traversal sequence. Nothing is yielded for an empty tree.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 13):
    ...     t.insert(v)
    >>> for value in t.iterate_postorder():
    ...     print(value)
    13
    22
    25
    66
    75
    50
    """
    if self.root is not None:
        yield from self._iterate_postorder(self.root)
619 def _iterate_leaves(self, node: Node):
620 if node.left is not None:
621 yield from self._iterate_leaves(node.left)
622 if node.right is not None:
623 yield from self._iterate_leaves(node.right)
624 if node.left is None and node.right is None:
def iterate_leaves(self):
    """
    Walk only the leaves of the tree, i.e. nodes with no children.

    Returns:
        A Generator that yields only the items stored in the tree's
        leaf nodes, from left to right. Nothing is yielded for an
        empty tree.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 13):
    ...     t.insert(v)
    >>> for value in t.iterate_leaves():
    ...     print(value)
    13
    66
    """
    if self.root is not None:
        yield from self._iterate_leaves(self.root)
650 def _iterate_by_depth(self, node: Node, depth: int):
655 if node.left is not None:
656 yield from self._iterate_by_depth(node.left, depth - 1)
657 if node.right is not None:
658 yield from self._iterate_by_depth(node.right, depth - 1)
def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
    """
    Walk a single level of the tree, from left to right.

    Args:
        depth: the desired depth. The helper is called with
            ``depth - 1`` for each step down, so presumably 0 means
            the root (confirm against ``_iterate_by_depth``).

    Returns:
        A Generator that yields nodes at the prescribed depth in the
        tree. Nothing is yielded for an empty tree.

    NOTE(review): the annotation promises Node objects; confirm whether
    ``_iterate_by_depth`` yields nodes or their values.
    """
    start = self.root
    if start is None:
        return
    yield from self._iterate_by_depth(start, depth)
690 def get_next_node(self, node: Node) -> Optional[Node]:
693 node: the node whose next greater successor is desired
696 Given a tree node, returns the next greater node in the tree.
697 If the given node is the greatest node in the tree, returns None.
699 >>> t = BinarySearchTree()
717 >>> t.get_next_node(n).value
721 >>> t.get_next_node(n).value
725 >>> t.get_next_node(n) is None
729 if node.right is not None:
731 while x.left is not None:
735 path = self.parent_path(node)
736 assert path[-1] is not None
737 assert path[-1] == node
740 for ancestor in path:
741 assert ancestor is not None
742 if node != ancestor.right:
747 def get_nodes_in_range_inclusive(self, lower: Comparable, upper: Comparable):
749 >>> t = BinarySearchTree()
766 >>> for node in t.get_nodes_in_range_inclusive(21, 74):
767 ... print(node.value)
774 node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
778 if lower <= node.value <= upper:
780 node = self.get_next_node(node)
def _depth(self, node: Node, sofar: int) -> int:
    """Return the deepest level reachable from ``node``.

    Levels are counted in nodes. ``node`` itself sits at level
    ``sofar + 1`` and each child is one level further down, so a lone
    node with ``sofar == 0`` gives 1.

    Args:
        node: root of the subtree to measure (must not be None).
        sofar: number of levels already above ``node``.

    Returns:
        ``sofar`` plus the height of the subtree rooted at ``node``,
        measured in nodes.
    """
    deepest = sofar + 1
    # Explicit work stack of (node, level) pairs instead of recursion;
    # the maximum level seen is the answer.
    pending = [(node, sofar + 1)]
    while pending:
        current, level = pending.pop()
        if level > deepest:
            deepest = level
        for child in (current.left, current.right):
            if child is not None:
                pending.append((child, level + 1))
    return deepest
791 def depth(self) -> int:
794 The max height (depth) of the tree in plies (edge distance
797 >>> t = BinarySearchTree()
819 if self.root is None:
821 return self._depth(self.root, 0)
823 def height(self) -> int:
824 """Returns the height (i.e. max depth) of the tree"""
831 node: Optional[Node],
832 has_right_sibling: bool,
835 viz = f"\n{padding}{pointer}{node.value}"
836 if has_right_sibling:
841 pointer_right = "└──"
842 if node.right is not None:
847 viz += self.repr_traverse(
848 padding, pointer_left, node.left, node.right is not None
850 viz += self.repr_traverse(padding, pointer_right, node.right, False)
857 An ASCII string representation of the tree.
859 >>> t = BinarySearchTree()
876 if self.root is None:
879 ret = f"{self.root.value}"
880 pointer_right = "└──"
881 if self.root.right is None:
886 ret += self.repr_traverse(
887 "", pointer_left, self.root.left, self.root.left is not None
889 ret += self.repr_traverse("", pointer_right, self.root.right, False)
893 if __name__ == "__main__":