3 # © Copyright 2021-2023, Scott Gasch
5 """A binary search tree implementation."""
7 from abc import abstractmethod
8 from typing import Any, Generator, List, Optional, Protocol
11 class Comparable(Protocol):
13 def __lt__(self, other: Any) -> bool:
17 def __le__(self, other: Any) -> bool:
21 def __eq__(self, other: Any) -> bool:
26 def __init__(self, value: Comparable) -> None:
27 """A BST node. Note that value can be anything as long as it
28 is comparable with other instances of itself. Check out
29 :meth:`functools.total_ordering`
30 (https://docs.python.org/3/library/functools.html#functools.total_ordering)
33 value: a reference to the value of the node.
36 self.left: Optional[Node] = None
37 self.right: Optional[Node] = None
38 self.value: Comparable = value
41 class BinarySearchTree(object):
47 def get_root(self) -> Optional[Node]:
55 def _on_insert(self, parent: Optional[Node], new: Node) -> None:
56 """This is called immediately _after_ a new node is inserted."""
59 def insert(self, value: Comparable) -> None:
61 Insert something into the tree.
64 value: the value to be inserted.
66 >>> t = BinarySearchTree()
73 >>> t.get_root().value
78 self.root = Node(value)
80 self._on_insert(None, self.root)
82 self._insert(value, self.root)
84 def _insert(self, value: Comparable, node: Node):
85 """Insertion helper"""
86 if value < node.value:
87 if node.left is not None:
88 self._insert(value, node.left)
90 node.left = Node(value)
92 self._on_insert(node, node.left)
94 if node.right is not None:
95 self._insert(value, node.right)
97 node.right = Node(value)
99 self._on_insert(node, node.right)
101 def __getitem__(self, value: Comparable) -> Optional[Node]:
103 Find an item in the tree and return its Node. Returns
104 None if the item is not in the tree.
106 >>> t = BinarySearchTree()
118 if self.root is not None:
119 return self._find_exact(value, self.root)
122 def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
123 """Recursively traverse the tree looking for a node with the
124 target value. Return that node if it exists, otherwise return
127 if target == node.value:
129 elif target < node.value and node.left is not None:
130 return self._find_exact(target, node.left)
131 elif target > node.value and node.right is not None:
132 return self._find_exact(target, node.right)
135 def _find_lowest_node_less_than_or_equal_to(
136 self, target: Comparable, node: Optional[Node]
138 """Find helper that returns the lowest node that is less
139 than or equal to the target value. Returns None if target is
140 lower than the lowest node in the tree.
142 >>> t = BinarySearchTree()
159 >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
161 >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
163 >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
165 >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
167 >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
169 >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
171 >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
173 >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
181 if target == node.value:
184 elif target > node.value:
185 if below := self._find_lowest_node_less_than_or_equal_to(
193 return self._find_lowest_node_less_than_or_equal_to(target, node.left)
195 def _find_lowest_node_greater_than_or_equal_to(
196 self, target: Comparable, node: Optional[Node]
198 """Find helper that returns the lowest node that is greater
199 than or equal to the target value. Returns None if target is
200 higher than the greatest node in the tree.
202 >>> t = BinarySearchTree()
219 >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
221 >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
223 >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
225 >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
227 >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
229 >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
231 >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
233 >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
241 if target == node.value:
244 elif target > node.value:
245 return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
247 # If target < this node's value, either this node is the
248 # answer or the answer is in this node's left subtree.
250 if below := self._find_lowest_node_greater_than_or_equal_to(
258 self, current: Optional[Node], target: Node
259 ) -> List[Optional[Node]]:
260 """Internal helper"""
263 ret: List[Optional[Node]] = [current]
264 if target.value == current.value:
266 elif target.value < current.value:
267 ret.extend(self._parent_path(current.left, target))
270 assert target.value > current.value
271 ret.extend(self._parent_path(current.right, target))
274 def parent_path(self, node: Node) -> List[Optional[Node]]:
275 """Get a node's parent path.
278 node: the node to check
281 a list of nodes representing the path from
282 the tree's root to the node.
286 If the node does not exist in the tree, the last element
287 on the path will be None but the path will indicate the
288 ancestor path of that node were it to be inserted.
290 >>> t = BinarySearchTree()
308 >>> for x in t.parent_path(n):
316 >>> for x in t.parent_path(n):
317 ... if x is not None:
327 return self._parent_path(self.root, node)
329 def __delitem__(self, value: Comparable) -> bool:
331 Delete an item from the tree and preserve the BST property.
334 value: the value of the node to be deleted.
337 True if the value was found and its associated node was
338 successfully deleted and False otherwise.
340 >>> t = BinarySearchTree()
357 >>> for value in t.iterate_inorder():
367 >>> del t[22] # Note: bool result is discarded
369 >>> for value in t.iterate_inorder():
378 >>> t.__delitem__(13)
380 >>> for value in t.iterate_inorder():
388 >>> t.__delitem__(75)
390 >>> for value in t.iterate_inorder():
402 >>> t.__delitem__(99)
406 if self.root is not None:
407 ret = self._delete(value, None, self.root)
415 def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
416 """This is called just after deleted was deleted from the tree"""
419 def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
421 if node.value == value:
423 # Deleting a leaf node
424 if node.left is None and node.right is None:
425 if parent is not None:
426 if parent.left == node:
429 assert parent.right == node
431 self._on_delete(parent, node)
434 # Node only has a right.
435 elif node.left is None:
436 assert node.right is not None
437 if parent is not None:
438 if parent.left == node:
439 parent.left = node.right
441 assert parent.right == node
442 parent.right = node.right
443 self._on_delete(parent, node)
446 # Node only has a left.
447 elif node.right is None:
448 assert node.left is not None
449 if parent is not None:
450 if parent.left == node:
451 parent.left = node.left
453 assert parent.right == node
454 parent.right = node.left
455 self._on_delete(parent, node)
458 # Node has both a left and right.
460 assert node.left is not None and node.right is not None
461 descendent = node.right
462 while descendent.left is not None:
463 descendent = descendent.left
464 node.value = descendent.value
465 return self._delete(node.value, node, node.right)
466 elif value < node.value and node.left is not None:
467 return self._delete(value, node, node.left)
468 elif value > node.value and node.right is not None:
469 return self._delete(value, node, node.right)
475 The count of items in the tree.
477 >>> t = BinarySearchTree()
483 >>> t.__delitem__(50)
499 def __contains__(self, value: Comparable) -> bool:
502 True if the item is in the tree; False otherwise.
504 return self.__getitem__(value) is not None
506 def _iterate_preorder(self, node: Node):
508 if node.left is not None:
509 yield from self._iterate_preorder(node.left)
510 if node.right is not None:
511 yield from self._iterate_preorder(node.right)
513 def _iterate_inorder(self, node: Node):
514 if node.left is not None:
515 yield from self._iterate_inorder(node.left)
517 if node.right is not None:
518 yield from self._iterate_inorder(node.right)
520 def _iterate_postorder(self, node: Node):
521 if node.left is not None:
522 yield from self._iterate_postorder(node.left)
523 if node.right is not None:
524 yield from self._iterate_postorder(node.right)
527 def iterate_preorder(self):
530 A Generator that yields the tree's items in a
531 preorder traversal sequence.
533 >>> t = BinarySearchTree()
541 >>> for value in t.iterate_preorder():
551 if self.root is not None:
552 yield from self._iterate_preorder(self.root)
554 def iterate_inorder(self):
A Generator that yields the tree's items in an inorder
560 >>> t = BinarySearchTree()
577 >>> for value in t.iterate_inorder():
588 if self.root is not None:
589 yield from self._iterate_inorder(self.root)
591 def iterate_postorder(self):
A Generator that yields the tree's items in a postorder
597 >>> t = BinarySearchTree()
605 >>> for value in t.iterate_postorder():
615 if self.root is not None:
616 yield from self._iterate_postorder(self.root)
618 def _iterate_leaves(self, node: Node):
619 if node.left is not None:
620 yield from self._iterate_leaves(node.left)
621 if node.right is not None:
622 yield from self._iterate_leaves(node.right)
623 if node.left is None and node.right is None:
626 def iterate_leaves(self):
A Generator that yields only the leaf nodes in the
632 >>> t = BinarySearchTree()
640 >>> for value in t.iterate_leaves():
646 if self.root is not None:
647 yield from self._iterate_leaves(self.root)
649 def _iterate_by_depth(self, node: Node, depth: int):
654 if node.left is not None:
655 yield from self._iterate_by_depth(node.left, depth - 1)
656 if node.right is not None:
657 yield from self._iterate_by_depth(node.right, depth - 1)
659 def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
662 depth: the desired depth
665 A Generator that yields nodes at the prescribed depth in
668 >>> t = BinarySearchTree()
676 >>> for value in t.iterate_nodes_by_depth(2):
681 >>> for value in t.iterate_nodes_by_depth(3):
686 if self.root is not None:
687 yield from self._iterate_by_depth(self.root, depth)
689 def get_next_node(self, node: Node) -> Optional[Node]:
692 node: the node whose next greater successor is desired
695 Given a tree node, returns the next greater node in the tree.
696 If the given node is the greatest node in the tree, returns None.
698 >>> t = BinarySearchTree()
716 >>> t.get_next_node(n).value
720 >>> t.get_next_node(n).value
724 >>> t.get_next_node(n) is None
728 if node.right is not None:
730 while x.left is not None:
734 path = self.parent_path(node)
735 assert path[-1] is not None
736 assert path[-1] == node
739 for ancestor in path:
740 assert ancestor is not None
741 if node != ancestor.right:
746 def get_nodes_in_range_inclusive(self, lower: Comparable, upper: Comparable):
748 >>> t = BinarySearchTree()
765 >>> for node in t.get_nodes_in_range_inclusive(21, 74):
766 ... print(node.value)
773 node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
777 if lower <= node.value <= upper:
779 node = self.get_next_node(node)
def _depth(self, node: Node, sofar: int) -> int:
    """Measure the deepest downward path starting at node.

    Args:
        node: the subtree root to measure.
        sofar: the number of nodes already counted above node.

    Returns:
        sofar plus the number of nodes on the longest path from node
        down to a leaf (node itself counts as one).
    """
    here = sofar + 1
    # Any child's result is at least here + 1, so when a child exists its
    # result dominates; a node with no children simply reports here.
    child_depths = [
        self._depth(child, here)
        for child in (node.left, node.right)
        if child is not None
    ]
    return max(child_depths, default=here)
790 def depth(self) -> int:
793 The max height (depth) of the tree in plies (edge distance
796 >>> t = BinarySearchTree()
818 if self.root is None:
820 return self._depth(self.root, 0)
822 def height(self) -> int:
823 """Returns the height (i.e. max depth) of the tree"""
830 node: Optional[Node],
831 has_right_sibling: bool,
834 viz = f"\n{padding}{pointer}{node.value}"
835 if has_right_sibling:
840 pointer_right = "└──"
841 if node.right is not None:
846 viz += self.repr_traverse(
847 padding, pointer_left, node.left, node.right is not None
849 viz += self.repr_traverse(padding, pointer_right, node.right, False)
856 An ASCII string representation of the tree.
858 >>> t = BinarySearchTree()
875 if self.root is None:
878 ret = f"{self.root.value}"
879 pointer_right = "└──"
880 if self.root.right is None:
885 ret += self.repr_traverse(
886 "", pointer_left, self.root.left, self.root.left is not None
888 ret += self.repr_traverse("", pointer_right, self.root.right, False)
892 if __name__ == "__main__":