3 # © Copyright 2021-2023, Scott Gasch
5 """A binary search tree implementation."""
7 from abc import abstractmethod
8 from typing import Any, Generator, List, Optional, Protocol
class Comparable(Protocol):
    """Structural type for values that can be ordered against their peers.

    A type satisfies this protocol simply by providing ``__lt__``,
    ``__le__`` and ``__eq__`` that work against other instances of the
    same type; it does not have to inherit from this class.

    Check out :meth:`functools.total_ordering`
    (https://docs.python.org/3/library/functools.html#functools.total_ordering)
    for an easy way to make your type comparable.
    """

    @abstractmethod
    def __lt__(self, other: Any) -> bool:
        """Return True when ``self`` sorts strictly before ``other``."""

    @abstractmethod
    def __le__(self, other: Any) -> bool:
        """Return True when ``self`` sorts before or equal to ``other``."""

    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        """Return True when ``self`` and ``other`` compare as equal."""
class Node:
    """One vertex of a :class:`BinarySearchTree`."""

    def __init__(self, value: Comparable) -> None:
        """Create a childless node holding ``value``.

        Args:
            value: the payload stored in this node.  It can be any
                object, provided it is :class:`Comparable` with the
                other values placed in the same tree.

        Both child links start out empty (None); the tree that owns
        the node is responsible for wiring them up.
        """
        self.value: Comparable = value
        # Subtree holding values that compare strictly less than ours.
        self.left: Optional[Node] = None
        # Subtree holding values that compare greater than or equal to ours.
        self.right: Optional[Node] = None
class BinarySearchTree(object):
    """An unbalanced binary search tree of :class:`Comparable` values.

    A value that compares strictly less than a node is stored in that
    node's left subtree; a value that is greater than *or equal to* it
    is stored in the right subtree, so duplicate values are allowed.
    The tree is never rebalanced, so searching, inserting and deleting
    cost :math:`O(h)` where :math:`h` is the height of the tree.  That
    is :math:`O(log_2 n)` for a reasonably balanced tree, but degrades
    to :math:`O(n)` when, for example, values arrive in sorted order.

    Subclasses may override :meth:`_on_insert` and :meth:`_on_delete`
    to observe structural changes.
    """

    def __init__(self):
        self.root: Optional[Node] = None
        self.count: int = 0
        # Not read by any method of this class; kept so that external
        # code referencing the attribute keeps working.
        self.traverse = None

    def get_root(self) -> Optional[Node]:
        """
        Returns:
            The root node of the tree, or None if the tree is empty.
        """
        return self.root

    def _on_insert(self, parent: Optional[Node], new: Node) -> None:
        """This is called immediately _after_ a new node is inserted.

        Args:
            parent: the new node's parent, or None if ``new`` became
                the root.
            new: the node that was just linked into the tree.

        The default implementation does nothing; it exists so that
        subclasses can react to insertions.
        """
        pass

    def insert(self, value: Comparable) -> None:
        """
        Insert something into the tree in :math:`O(h)` time, where
        :math:`h` is the tree's height.  A value equal to one already
        in the tree is placed in that node's right subtree.

        Args:
            value: the value to be inserted.

        >>> t = BinarySearchTree()
        >>> t.insert(10)
        >>> t.insert(20)
        >>> t.insert(5)
        >>> t.get_root().value
        10
        >>> len(t)
        3
        """
        if self.root is None:
            self.root = Node(value)
            self.count = 1
            self._on_insert(None, self.root)
        else:
            self._insert(value, self.root)

    def _insert(self, value: Comparable, node: Node):
        """Insertion helper: add ``value`` somewhere below ``node``."""
        if value < node.value:
            if node.left is not None:
                self._insert(value, node.left)
            else:
                node.left = Node(value)
                self.count += 1
                self._on_insert(node, node.left)
        else:
            # Equal values go right, which is how duplicates are kept.
            if node.right is not None:
                self._insert(value, node.right)
            else:
                node.right = Node(value)
                self.count += 1
                self._on_insert(node, node.right)

    def __getitem__(self, value: Comparable) -> Optional[Node]:
        """
        Find an item in the tree and return its Node in :math:`O(h)`
        time.  Returns None if the item is not in the tree.

        >>> t = BinarySearchTree()
        >>> t[99] is None
        True
        >>> t.insert(10)
        >>> t.insert(20)
        >>> t.insert(5)
        >>> t[20].value
        20
        >>> t[99] is None
        True
        """
        if self.root is not None:
            return self._find_exact(value, self.root)
        return None

    def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
        """Recursively traverse the tree looking for a node with the
        target value.  Return that node if it exists, otherwise return
        None.
        """
        if target == node.value:
            return node
        elif target < node.value and node.left is not None:
            return self._find_exact(target, node.left)
        elif target > node.value and node.right is not None:
            return self._find_exact(target, node.right)
        return None

    def _find_lowest_node_less_than_or_equal_to(
        self, target: Comparable, node: Optional[Node]
    ) -> Optional[Node]:
        """Find helper that returns the *greatest* node whose value is
        less than or equal to the target value (the "floor" of target)
        within the subtree rooted at ``node``.  Returns None if target
        is lower than every value in that subtree.  Despite the name,
        this does not return the lowest such node; the name is kept for
        compatibility.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
        33
        >>> t._find_lowest_node_less_than_or_equal_to(50, t.root).value
        50
        >>> t._find_lowest_node_less_than_or_equal_to(70, t.root).value
        66
        >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
        88
        >>> t._find_lowest_node_less_than_or_equal_to(11, t.root) is None
        True
        """
        if node is None:
            return None

        if target == node.value:
            return node

        elif target > node.value:
            # This node qualifies, but a larger qualifying node may live
            # in its right subtree.
            below = self._find_lowest_node_less_than_or_equal_to(target, node.right)
            return below if below is not None else node

        else:
            # This node is too big; only the left subtree can qualify.
            return self._find_lowest_node_less_than_or_equal_to(target, node.left)

    def _find_lowest_node_greater_than_or_equal_to(
        self, target: Comparable, node: Optional[Node]
    ) -> Optional[Node]:
        """Find helper that returns the lowest node whose value is
        greater than or equal to the target value (the "ceiling" of
        target) within the subtree rooted at ``node``.  Returns None if
        target is higher than every value in that subtree.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
        50
        >>> t._find_lowest_node_greater_than_or_equal_to(50, t.root).value
        50
        >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
        12
        >>> t._find_lowest_node_greater_than_or_equal_to(70, t.root).value
        75
        >>> t._find_lowest_node_greater_than_or_equal_to(89, t.root) is None
        True
        """
        if node is None:
            return None

        if target == node.value:
            return node

        elif target > node.value:
            return self._find_lowest_node_greater_than_or_equal_to(target, node.right)

        # If target < this node's value, either this node is the
        # answer or the answer is in this node's left subtree.
        else:
            below = self._find_lowest_node_greater_than_or_equal_to(target, node.left)
            return below if below is not None else node

    def _parent_path(
        self, current: Optional[Node], target: Node
    ) -> List[Optional[Node]]:
        """Internal helper for :meth:`parent_path`.

        Walks down from ``current`` toward where ``target`` lives (or
        would be inserted), following the same rule as :meth:`_insert`:
        strictly smaller values to the left, equal or greater values to
        the right.  Nodes are matched by identity, not by value, so that
        a duplicate-valued node is located correctly rather than
        stopping at an equal-valued ancestor.

        Returns:
            The nodes visited, ending with ``target`` if it was found,
            or with None if it is not in the subtree.
        """
        path: List[Optional[Node]] = []
        while current is not None:
            path.append(current)
            if current is target:
                return path
            current = current.left if target.value < current.value else current.right
        path.append(None)
        return path

    def parent_path(self, node: Node) -> List[Optional[Node]]:
        """Get a node's parent path in :math:`O(h)` time.

        Args:
            node: the node whose parent path should be returned.

        Returns:
            a list of nodes representing the path from
            the tree's root to the given node.

        If the node does not exist in the tree, the last element
        on the path will be None but the path will indicate the
        ancestor path of that node were it to be inserted.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> [x.value for x in t.parent_path(t[33])]
        [50, 25, 33]
        >>> [x if x is None else x.value for x in t.parent_path(Node(70))]
        [50, 75, 66, None]
        """
        return self._parent_path(self.root, node)

    def __delitem__(self, value: Comparable) -> bool:
        """
        Delete an item from the tree and preserve the BST property in
        :math:`O(h)` time.

        Args:
            value: the value of the node to be deleted.

        Returns:
            True if the value was found and its associated node was
            successfully deleted and False otherwise.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> del t[25]  # a node with two children
        >>> list(t.iterate_inorder())
        [12, 33, 50, 66, 75, 88]
        >>> t.__delitem__(12)  # a leaf
        True
        >>> t.__delitem__(99)  # not present
        False
        >>> len(t)
        5

        Deleting a root that has only one child promotes that child:

        >>> t = BinarySearchTree()
        >>> t.insert(50)
        >>> t.insert(75)
        >>> del t[50]
        >>> t.get_root().value
        75
        >>> 50 in t
        False
        >>> len(t)
        1
        """
        if self.root is None:
            return False
        deleted = self._delete(value, None, self.root)
        if deleted:
            self.count -= 1
        return deleted

    def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
        """This is called just after deleted was deleted from the tree.

        Args:
            parent: the removed node's former parent, or None if it was
                the root.
            deleted: the node that was unlinked.

        The default implementation does nothing; it exists so that
        subclasses can react to deletions.
        """
        pass

    def _replace_child(
        self, parent: Optional[Node], child: Node, replacement: Optional[Node]
    ) -> None:
        """Put ``replacement`` where ``child`` hangs under ``parent``.

        A None ``parent`` means ``child`` is the root, so the root
        reference itself is updated.
        """
        if parent is None:
            assert self.root is child
            self.root = replacement
        elif parent.left is child:
            parent.left = replacement
        else:
            assert parent.right is child
            parent.right = replacement

    def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
        """Delete helper: remove one node holding ``value`` from the
        subtree rooted at ``node``.

        Args:
            value: the value to remove.
            parent: ``node``'s parent, or None when ``node`` is the root.
            node: the root of the subtree to search.

        Returns:
            True if a node was removed, False if ``value`` was not found.
        """
        if node.value == value:
            # Zero or one child: splice the (possibly absent) child into
            # this node's place.  This also covers the root, whose
            # replacement must become self.root.
            if node.left is None or node.right is None:
                child = node.left if node.left is not None else node.right
                self._replace_child(parent, node, child)
                self._on_delete(parent, node)
                return True

            # Node has both a left and right; get the successor node
            # to this one and put it here (deleting the successor's
            # old node).  Because these operations are happening only
            # in the subtree underneath of node, this is still an O(h)
            # operation.
            successor = self.get_next_node(node)
            assert successor is not None
            node.value = successor.value
            return self._delete(node.value, node, node.right)

        elif value < node.value and node.left is not None:
            return self._delete(value, node, node.left)
        elif value > node.value and node.right is not None:
            return self._delete(value, node, node.right)
        return False

    def __len__(self):
        """
        Returns:
            The count of items in the tree in :math:`O(1)` time.

        >>> t = BinarySearchTree()
        >>> len(t)
        0
        >>> t.insert(50)
        >>> t.insert(25)
        >>> len(t)
        2
        >>> del t[50]
        >>> len(t)
        1
        """
        return self.count

    def __contains__(self, value: Comparable) -> bool:
        """
        Returns:
            True if the item is in the tree; False otherwise.

        >>> t = BinarySearchTree()
        >>> t.insert(50)
        >>> 50 in t, 51 in t
        (True, False)
        """
        return self.__getitem__(value) is not None

    def _iterate_preorder(self, node: Node):
        """Yield values of the subtree at ``node``: node, left, right."""
        yield node.value
        if node.left is not None:
            yield from self._iterate_preorder(node.left)
        if node.right is not None:
            yield from self._iterate_preorder(node.right)

    def _iterate_inorder(self, node: Node):
        """Yield values of the subtree at ``node``: left, node, right."""
        if node.left is not None:
            yield from self._iterate_inorder(node.left)
        yield node.value
        if node.right is not None:
            yield from self._iterate_inorder(node.right)

    def _iterate_postorder(self, node: Node):
        """Yield values of the subtree at ``node``: left, right, node."""
        if node.left is not None:
            yield from self._iterate_postorder(node.left)
        if node.right is not None:
            yield from self._iterate_postorder(node.right)
        yield node.value

    def iterate_preorder(self) -> Generator[Comparable, None, None]:
        """
        Returns:
            A Generator that yields the tree's items in a
            preorder traversal sequence.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> list(t.iterate_preorder())
        [50, 25, 12, 33, 75, 66, 88]
        """
        if self.root is not None:
            yield from self._iterate_preorder(self.root)

    def iterate_inorder(self) -> Generator[Comparable, None, None]:
        """
        Returns:
            A Generator that yields the tree's items in an inorder
            traversal sequence, i.e. in sorted order.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> list(t.iterate_inorder())
        [12, 25, 33, 50, 66, 75, 88]
        """
        if self.root is not None:
            yield from self._iterate_inorder(self.root)

    def iterate_postorder(self) -> Generator[Comparable, None, None]:
        """
        Returns:
            A Generator that yields the tree's items in a postorder
            traversal sequence.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> list(t.iterate_postorder())
        [12, 33, 25, 66, 88, 75, 50]
        """
        if self.root is not None:
            yield from self._iterate_postorder(self.root)

    def _iterate_leaves(self, node: Node):
        """Yield the values of the childless nodes under ``node``, left to right."""
        if node.left is not None:
            yield from self._iterate_leaves(node.left)
        if node.right is not None:
            yield from self._iterate_leaves(node.right)
        if node.left is None and node.right is None:
            yield node.value

    def iterate_leaves(self) -> Generator[Comparable, None, None]:
        """
        Returns:
            A Generator that yields only the values of the leaf nodes
            in the tree, from left to right.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> list(t.iterate_leaves())
        [12, 33, 66, 88]
        """
        if self.root is not None:
            yield from self._iterate_leaves(self.root)

    def _iterate_by_depth(self, node: Node, depth: int):
        """Yield values of the nodes ``depth`` levels below ``node``."""
        if depth == 0:
            yield node.value
        else:
            assert depth > 0
            if node.left is not None:
                yield from self._iterate_by_depth(node.left, depth - 1)
            if node.right is not None:
                yield from self._iterate_by_depth(node.right, depth - 1)

    def iterate_nodes_by_depth(self, depth: int) -> Generator[Comparable, None, None]:
        """
        Args:
            depth: the desired depth; 0 is the root.

        Returns:
            A Generator that yields the values of the nodes at the
            prescribed depth in the tree, from left to right.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> list(t.iterate_nodes_by_depth(1))
        [25, 75]
        >>> list(t.iterate_nodes_by_depth(2))
        [12, 33, 66, 88]
        >>> list(t.iterate_nodes_by_depth(3))
        []
        """
        if self.root is not None:
            yield from self._iterate_by_depth(self.root, depth)

    def get_next_node(self, node: Node) -> Optional[Node]:
        """
        Args:
            node: the node whose next greater successor is desired

        Returns:
            Given a tree node, returns the node that follows it in an
            inorder traversal (the next greater node, or an equal-valued
            duplicate).  If the given node is the last node in the
            tree, returns None.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> t.get_next_node(t[33]).value
        50
        >>> t.get_next_node(t[50]).value
        66
        >>> t.get_next_node(t[88]) is None
        True
        """
        # With a right subtree, the successor is that subtree's leftmost node.
        if node.right is not None:
            x = node.right
            while x.left is not None:
                x = x.left
            return x

        # Otherwise climb toward the root; the first ancestor we reach
        # from its left side is the successor.
        path = self.parent_path(node)
        assert path[-1] is node, "node is not in this tree"
        for ancestor in reversed(path[:-1]):
            assert ancestor is not None
            if ancestor.right is not node:
                return ancestor
            node = ancestor
        return None

    def get_nodes_in_range_inclusive(
        self, lower: Comparable, upper: Comparable
    ) -> Generator[Node, None, None]:
        """
        Args:
            lower: the lower bound of the desired range.
            upper: the upper bound of the desired range.

        Returns:
            Generates a sequence of nodes in the desired range, in
            ascending order.  The walk stops at the first node above
            ``upper`` rather than visiting the rest of the tree.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> [node.value for node in t.get_nodes_in_range_inclusive(21, 74)]
        [25, 33, 50, 66]

        Duplicate values are reported once per occurrence:

        >>> t.insert(33)
        >>> [node.value for node in t.get_nodes_in_range_inclusive(30, 70)]
        [33, 33, 50, 66]
        """
        # The ceiling of ``lower`` is the first candidate; every later
        # inorder successor is >= it, so once one exceeds ``upper`` all
        # of the rest do too.
        node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
            lower, self.root
        )
        while node is not None and node.value <= upper:
            yield node
            node = self.get_next_node(node)

    def _depth(self, node: Node, sofar: int) -> int:
        """Return ``sofar`` plus the number of levels in the subtree at ``node``."""
        depth_left = sofar + 1
        depth_right = sofar + 1
        if node.left is not None:
            depth_left = self._depth(node.left, sofar + 1)
        if node.right is not None:
            depth_right = self._depth(node.right, sofar + 1)
        return max(depth_left, depth_right)

    def depth(self) -> int:
        """
        Returns:
            The number of levels in the tree: 0 for an empty tree, 1
            for a tree holding only a root, and so on.  Every node is
            visited, so this takes :math:`O(n)` time.

        >>> t = BinarySearchTree()
        >>> t.depth()
        0
        >>> t.insert(50)
        >>> t.depth()
        1
        >>> for v in (25, 75, 12):
        ...     t.insert(v)
        >>> t.depth()
        3
        """
        if self.root is None:
            return 0
        return self._depth(self.root, 0)

    def height(self) -> int:
        """Returns the height (i.e. max depth) of the tree; same as :meth:`depth`."""
        return self.depth()

    def repr_traverse(
        self,
        padding: str,
        pointer: str,
        node: Optional[Node],
        has_right_sibling: bool,
    ) -> str:
        """Render ``node`` and its subtree as lines of ASCII art.

        Args:
            padding: the prefix accumulated for this depth.
            pointer: the connector drawn just before this node's value.
            node: the subtree to render; None renders as "".
            has_right_sibling: whether a sibling is drawn below this
                node, in which case a vertical bar must continue down
                this column for the node's descendants.

        Returns:
            The rendered subtree, with each line preceded by a newline.
        """
        if node is None:
            return ""

        viz = f"\n{padding}{pointer}{node.value}"
        if has_right_sibling:
            padding += "│  "
        else:
            padding += "   "

        pointer_right = "└──"
        if node.right is not None:
            pointer_left = "├──"
        else:
            pointer_left = "└──"

        viz += self.repr_traverse(
            padding, pointer_left, node.left, node.right is not None
        )
        viz += self.repr_traverse(padding, pointer_right, node.right, False)
        return viz

    def __repr__(self):
        """
        Returns:
            An ASCII string representation of the tree.

        >>> t = BinarySearchTree()
        >>> repr(t)
        ''
        >>> for v in (50, 25, 75, 12, 33, 66, 88):
        ...     t.insert(v)
        >>> t
        50
        ├──25
        │  ├──12
        │  └──33
        └──75
           ├──66
           └──88

        A root with only a left child draws no dangling vertical bar:

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 12):
        ...     t.insert(v)
        >>> t
        50
        └──25
           └──12
        """
        if self.root is None:
            return ""

        ret = f"{self.root.value}"
        pointer_right = "└──"
        if self.root.right is None:
            pointer_left = "└──"
        else:
            pointer_left = "├──"

        # The left child has a sibling below it only if a right child
        # exists; this mirrors the call made in repr_traverse.
        ret += self.repr_traverse(
            "", pointer_left, self.root.left, self.root.right is not None
        )
        ret += self.repr_traverse("", pointer_right, self.root.right, False)
        return ret
918 if __name__ == "__main__":