3 # © Copyright 2021-2023, Scott Gasch
5 """A binary search tree implementation."""
7 from abc import abstractmethod
8 from typing import Any, Generator, List, Optional, Protocol
11 class Comparable(Protocol):
12 """Anything that implements basic comparison methods such that it
13 can be compared to other instances of the same type.
15 Check out :meth:`functools.total_ordering`
16 (https://docs.python.org/3/library/functools.html#functools.total_ordering)
17 for an easy way to make your type comparable.
21 def __lt__(self, other: Any) -> bool:
25 def __le__(self, other: Any) -> bool:
29 def __eq__(self, other: Any) -> bool:
34 def __init__(self, value: Comparable) -> None:
35 """A BST node. Just a left and right reference along with a
36 value. Note that value can be anything as long as it
37 is :class:`Comparable` with other instances of itself.
40 value: a reference to the value of the node. Must be
41 :class:`Comparable` to other values.
44 self.left: Optional[Node] = None
45 self.right: Optional[Node] = None
46 self.value: Comparable = value
49 class BinarySearchTree(object):
55 def get_root(self) -> Optional[Node]:
63 def _on_insert(self, parent: Optional[Node], new: Node) -> None:
64 """This is called immediately _after_ a new node is inserted."""
67 def insert(self, value: Comparable) -> None:
69 Insert something into the tree.
72 value: the value to be inserted.
74 >>> t = BinarySearchTree()
81 >>> t.get_root().value
86 self.root = Node(value)
88 self._on_insert(None, self.root)
90 self._insert(value, self.root)
92 def _insert(self, value: Comparable, node: Node):
93 """Insertion helper"""
94 if value < node.value:
95 if node.left is not None:
96 self._insert(value, node.left)
98 node.left = Node(value)
100 self._on_insert(node, node.left)
102 if node.right is not None:
103 self._insert(value, node.right)
105 node.right = Node(value)
107 self._on_insert(node, node.right)
109 def __getitem__(self, value: Comparable) -> Optional[Node]:
111 Find an item in the tree and return its Node. Returns
112 None if the item is not in the tree.
114 >>> t = BinarySearchTree()
126 if self.root is not None:
127 return self._find_exact(value, self.root)
130 def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
131 """Recursively traverse the tree looking for a node with the
132 target value. Return that node if it exists, otherwise return
135 if target == node.value:
137 elif target < node.value and node.left is not None:
138 return self._find_exact(target, node.left)
139 elif target > node.value and node.right is not None:
140 return self._find_exact(target, node.right)
143 def _find_lowest_node_less_than_or_equal_to(
144 self, target: Comparable, node: Optional[Node]
146 """Find helper that returns the lowest node that is less
147 than or equal to the target value. Returns None if target is
148 lower than the lowest node in the tree.
150 >>> t = BinarySearchTree()
167 >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
169 >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
171 >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
173 >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
175 >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
177 >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
179 >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
181 >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
189 if target == node.value:
192 elif target > node.value:
193 if below := self._find_lowest_node_less_than_or_equal_to(
201 return self._find_lowest_node_less_than_or_equal_to(target, node.left)
203 def _find_lowest_node_greater_than_or_equal_to(
204 self, target: Comparable, node: Optional[Node]
206 """Find helper that returns the lowest node that is greater
207 than or equal to the target value. Returns None if target is
208 higher than the greatest node in the tree.
210 >>> t = BinarySearchTree()
227 >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
229 >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
231 >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
233 >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
235 >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
237 >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
239 >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
241 >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
249 if target == node.value:
252 elif target > node.value:
253 return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
255 # If target < this node's value, either this node is the
256 # answer or the answer is in this node's left subtree.
258 if below := self._find_lowest_node_greater_than_or_equal_to(
266 self, current: Optional[Node], target: Node
267 ) -> List[Optional[Node]]:
268 """Internal helper"""
271 ret: List[Optional[Node]] = [current]
272 if target.value == current.value:
274 elif target.value < current.value:
275 ret.extend(self._parent_path(current.left, target))
278 assert target.value > current.value
279 ret.extend(self._parent_path(current.right, target))
282 def parent_path(self, node: Node) -> List[Optional[Node]]:
283 """Get a node's parent path.
286 node: the node to check
289 a list of nodes representing the path from
290 the tree's root to the node.
294 If the node does not exist in the tree, the last element
295 on the path will be None but the path will indicate the
296 ancestor path of that node were it to be inserted.
298 >>> t = BinarySearchTree()
316 >>> for x in t.parent_path(n):
324 >>> for x in t.parent_path(n):
325 ... if x is not None:
335 return self._parent_path(self.root, node)
337 def __delitem__(self, value: Comparable) -> bool:
339 Delete an item from the tree and preserve the BST property.
342 value: the value of the node to be deleted.
345 True if the value was found and its associated node was
346 successfully deleted and False otherwise.
348 >>> t = BinarySearchTree()
365 >>> for value in t.iterate_inorder():
375 >>> del t[22] # Note: bool result is discarded
377 >>> for value in t.iterate_inorder():
386 >>> t.__delitem__(13)
388 >>> for value in t.iterate_inorder():
396 >>> t.__delitem__(75)
398 >>> for value in t.iterate_inorder():
410 >>> t.__delitem__(99)
414 if self.root is not None:
415 ret = self._delete(value, None, self.root)
423 def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
424 """This is called just after deleted was deleted from the tree"""
427 def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
429 if node.value == value:
431 # Deleting a leaf node
432 if node.left is None and node.right is None:
433 if parent is not None:
434 if parent.left == node:
437 assert parent.right == node
439 self._on_delete(parent, node)
442 # Node only has a right.
443 elif node.left is None:
444 assert node.right is not None
445 if parent is not None:
446 if parent.left == node:
447 parent.left = node.right
449 assert parent.right == node
450 parent.right = node.right
451 self._on_delete(parent, node)
454 # Node only has a left.
455 elif node.right is None:
456 assert node.left is not None
457 if parent is not None:
458 if parent.left == node:
459 parent.left = node.left
461 assert parent.right == node
462 parent.right = node.left
463 self._on_delete(parent, node)
466 # Node has both a left and right.
468 assert node.left is not None and node.right is not None
469 descendent = node.right
470 while descendent.left is not None:
471 descendent = descendent.left
472 node.value = descendent.value
473 return self._delete(node.value, node, node.right)
474 elif value < node.value and node.left is not None:
475 return self._delete(value, node, node.left)
476 elif value > node.value and node.right is not None:
477 return self._delete(value, node, node.right)
483 The count of items in the tree.
485 >>> t = BinarySearchTree()
491 >>> t.__delitem__(50)
507 def __contains__(self, value: Comparable) -> bool:
510 True if the item is in the tree; False otherwise.
512 return self.__getitem__(value) is not None
514 def _iterate_preorder(self, node: Node):
516 if node.left is not None:
517 yield from self._iterate_preorder(node.left)
518 if node.right is not None:
519 yield from self._iterate_preorder(node.right)
521 def _iterate_inorder(self, node: Node):
522 if node.left is not None:
523 yield from self._iterate_inorder(node.left)
525 if node.right is not None:
526 yield from self._iterate_inorder(node.right)
528 def _iterate_postorder(self, node: Node):
529 if node.left is not None:
530 yield from self._iterate_postorder(node.left)
531 if node.right is not None:
532 yield from self._iterate_postorder(node.right)
535 def iterate_preorder(self):
538 A Generator that yields the tree's items in a
539 preorder traversal sequence.
541 >>> t = BinarySearchTree()
549 >>> for value in t.iterate_preorder():
559 if self.root is not None:
560 yield from self._iterate_preorder(self.root)
562 def iterate_inorder(self):
565 A Generator that yield the tree's items in a preorder
568 >>> t = BinarySearchTree()
585 >>> for value in t.iterate_inorder():
596 if self.root is not None:
597 yield from self._iterate_inorder(self.root)
599 def iterate_postorder(self):
602 A Generator that yield the tree's items in a preorder
605 >>> t = BinarySearchTree()
613 >>> for value in t.iterate_postorder():
623 if self.root is not None:
624 yield from self._iterate_postorder(self.root)
626 def _iterate_leaves(self, node: Node):
627 if node.left is not None:
628 yield from self._iterate_leaves(node.left)
629 if node.right is not None:
630 yield from self._iterate_leaves(node.right)
631 if node.left is None and node.right is None:
634 def iterate_leaves(self):
637 A Gemerator that yielde only the leaf nodes in the
640 >>> t = BinarySearchTree()
648 >>> for value in t.iterate_leaves():
654 if self.root is not None:
655 yield from self._iterate_leaves(self.root)
657 def _iterate_by_depth(self, node: Node, depth: int):
662 if node.left is not None:
663 yield from self._iterate_by_depth(node.left, depth - 1)
664 if node.right is not None:
665 yield from self._iterate_by_depth(node.right, depth - 1)
667 def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
670 depth: the desired depth
673 A Generator that yields nodes at the prescribed depth in
676 >>> t = BinarySearchTree()
684 >>> for value in t.iterate_nodes_by_depth(2):
689 >>> for value in t.iterate_nodes_by_depth(3):
694 if self.root is not None:
695 yield from self._iterate_by_depth(self.root, depth)
697 def get_next_node(self, node: Node) -> Optional[Node]:
700 node: the node whose next greater successor is desired
703 Given a tree node, returns the next greater node in the tree.
704 If the given node is the greatest node in the tree, returns None.
706 >>> t = BinarySearchTree()
724 >>> t.get_next_node(n).value
728 >>> t.get_next_node(n).value
732 >>> t.get_next_node(n) is None
736 if node.right is not None:
738 while x.left is not None:
742 path = self.parent_path(node)
743 assert path[-1] is not None
744 assert path[-1] == node
747 for ancestor in path:
748 assert ancestor is not None
749 if node != ancestor.right:
754 def get_nodes_in_range_inclusive(self, lower: Comparable, upper: Comparable):
756 >>> t = BinarySearchTree()
773 >>> for node in t.get_nodes_in_range_inclusive(21, 74):
774 ... print(node.value)
781 node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
785 if lower <= node.value <= upper:
787 node = self.get_next_node(node)
789 def _depth(self, node: Node, sofar: int) -> int:
790 depth_left = sofar + 1
791 depth_right = sofar + 1
792 if node.left is not None:
793 depth_left = self._depth(node.left, sofar + 1)
794 if node.right is not None:
795 depth_right = self._depth(node.right, sofar + 1)
796 return max(depth_left, depth_right)
798 def depth(self) -> int:
801 The max height (depth) of the tree in plies (edge distance
804 >>> t = BinarySearchTree()
826 if self.root is None:
828 return self._depth(self.root, 0)
830 def height(self) -> int:
831 """Returns the height (i.e. max depth) of the tree"""
838 node: Optional[Node],
839 has_right_sibling: bool,
842 viz = f"\n{padding}{pointer}{node.value}"
843 if has_right_sibling:
848 pointer_right = "└──"
849 if node.right is not None:
854 viz += self.repr_traverse(
855 padding, pointer_left, node.left, node.right is not None
857 viz += self.repr_traverse(padding, pointer_right, node.right, False)
864 An ASCII string representation of the tree.
866 >>> t = BinarySearchTree()
883 if self.root is None:
886 ret = f"{self.root.value}"
887 pointer_right = "└──"
888 if self.root.right is None:
893 ret += self.repr_traverse(
894 "", pointer_left, self.root.left, self.root.left is not None
896 ret += self.repr_traverse("", pointer_right, self.root.right, False)
900 if __name__ == "__main__":