3 # © Copyright 2021-2023, Scott Gasch
5 """A binary search tree implementation."""
7 from abc import ABCMeta, abstractmethod
8 from typing import Any, Generator, List, Optional, TypeVar
class Comparable(metaclass=ABCMeta):
    """Abstract interface for values that can be stored in the tree.

    Concrete value types must implement ``<``, ``<=`` and ``==``.  This
    class is used as the bound of the ``ComparableNodeValue`` type
    variable.
    """

    @abstractmethod
    def __lt__(self, other: Any) -> bool:
        """Return True if this value orders strictly before *other*."""

    @abstractmethod
    def __le__(self, other: Any) -> bool:
        """Return True if this value orders before or equal to *other*."""

    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        """Return True if this value is equal to *other*."""
# Type variable for the values stored in tree nodes; bounded by Comparable so
# that type checkers require node values to support ordering comparisons.
ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable)
class Node:
    """A single node of a binary search tree."""

    def __init__(self, value: ComparableNodeValue) -> None:
        """A BST node.  Note that value can be anything as long as it
        is comparable with other instances of itself.  Check out
        :meth:`functools.total_ordering`
        (https://docs.python.org/3/library/functools.html#functools.total_ordering)

        Args:
            value: a reference to the value of the node.
        """
        # A new node starts out with no children.
        self.left: Optional[Node] = None
        self.right: Optional[Node] = None
        self.value: ComparableNodeValue = value
44 class BinarySearchTree(object):
def get_root(self) -> Optional[Node]:
    """
    Returns:
        The root node of the tree, or None if the tree has no root.
    """
    return self.root
def _on_insert(self, parent: Optional[Node], new: Node) -> None:
    """This is called immediately _after_ a new node is inserted.

    The base implementation does nothing; it exists as a hook that
    subclasses may override.

    Args:
        parent: the node the new node was attached to, or None when
            the new node became the root.
        new: the node that was just inserted.
    """
def insert(self, value: ComparableNodeValue) -> None:
    """
    Insert something into the tree.

    The first value inserted becomes the root; every later value is
    handed to :meth:`_insert`, which places it below the root.

    Args:
        value: the value to be inserted.

    >>> t = BinarySearchTree()
    >>> t.insert(10)
    >>> t.insert(20)
    >>> t.insert(5)
    >>> t.get_root().value
    10

    """
    if self.root is None:
        self.root = Node(value)
        # NOTE(review): the count bookkeeping line was not visible in the
        # reviewed listing; reconstructed -- confirm against __len__.
        self.count = 1
        self._on_insert(None, self.root)
    else:
        self._insert(value, self.root)
def _insert(self, value: ComparableNodeValue, node: Node):
    """Insertion helper.

    Walks down from ``node`` and attaches a new node holding ``value``
    at the first free child slot.  Values that compare less than a
    node's value go left; everything else (including equal values)
    goes right.  The walk is a loop rather than recursion so that a
    degenerate (list-shaped) tree cannot exhaust the interpreter's
    recursion limit.

    Args:
        value: the value to insert.
        node: the subtree root to insert beneath.
    """
    while True:
        if value < node.value:
            if node.left is None:
                node.left = Node(value)
                # NOTE(review): count bookkeeping reconstructed; the line
                # was not visible in the reviewed listing.
                self.count += 1
                self._on_insert(node, node.left)
                return
            node = node.left
        else:
            if node.right is None:
                node.right = Node(value)
                self.count += 1
                self._on_insert(node, node.right)
                return
            node = node.right
def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
    """
    Find an item in the tree and return its Node.  Returns
    None if the item is not in the tree.

    Args:
        value: the value to look for.

    >>> t = BinarySearchTree()
    >>> t[99]
    >>> t.insert(10)
    >>> t.insert(20)
    >>> t[10].value
    10
    >>> t[99]

    """
    if self.root is None:
        return None
    return self._find_exact(value, self.root)
def _find_exact(self, target: ComparableNodeValue, node: Node) -> Optional[Node]:
    """Traverse the tree looking for a node with the target value.
    Return that node if it exists, otherwise return None.

    The search is a loop rather than recursion so that a degenerate
    (list-shaped) tree cannot exhaust the recursion limit.

    Args:
        target: the value to look for.
        node: the subtree root to start searching from.
    """
    current: Optional[Node] = node
    while current is not None:
        if target == current.value:
            return current
        elif target < current.value:
            current = current.left
        elif target > current.value:
            current = current.right
        else:
            # Neither equal, less nor greater: treat as not found.
            return None
    return None
def _find_lowest_node_less_than_or_equal_to(
    self, target: ComparableNodeValue, node: Optional[Node]
) -> Optional[Node]:
    """Find helper that returns the node closest to the target from
    below: the node with the greatest value that is less than or
    equal to the target value.  Returns None if target is lower than
    the lowest node in the tree.

    Args:
        target: the value to search for.
        node: the subtree root to search (may be None).

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 13, 85):
    ...     t.insert(v)

    >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
    25
    >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
    50
    >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
    85
    >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
    22
    >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
    13
    >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
    66
    >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
    75
    >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
    True

    """
    best: Optional[Node] = None
    while node is not None:
        if target == node.value:
            return node
        if target > node.value:
            # This node qualifies; anything better must be to its right.
            best = node
            node = node.right
        else:
            # This node is too big; the answer, if any, is to its left.
            node = node.left
    return best
def _find_lowest_node_greater_than_or_equal_to(
    self, target: ComparableNodeValue, node: Optional[Node]
) -> Optional[Node]:
    """Find helper that returns the lowest node that is greater
    than or equal to the target value.  Returns None if target is
    higher than the greatest node in the tree.

    Args:
        target: the value to search for.
        node: the subtree root to search (may be None).

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 13, 85):
    ...     t.insert(v)

    >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
    50
    >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
    66
    >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
    13
    >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
    25
    >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
    22
    >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
    75
    >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
    85
    >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
    True

    """
    best: Optional[Node] = None
    while node is not None:
        if target == node.value:
            return node
        if target > node.value:
            # This node is too small; the answer, if any, is to its right.
            node = node.right
        else:
            # If target < this node's value, either this node is the
            # answer or the answer is in this node's left subtree.
            best = node
            node = node.left
    return best
def _parent_path(
    self, current: Optional[Node], target: Node
) -> List[Optional[Node]]:
    """Internal helper for :meth:`parent_path`.

    Walks down from ``current`` comparing node *values* with
    ``target.value`` and records every node visited.  The walk stops
    at the first node whose value equals ``target.value``; if the walk
    runs off the tree instead, the returned list ends with None.

    Args:
        current: the subtree root to start from (may be None).
        target: the node whose value is being searched for.

    Returns:
        The list of nodes visited, in root-to-target order.
    """
    path: List[Optional[Node]] = []
    while current is not None:
        path.append(current)
        if target.value == current.value:
            return path
        if target.value < current.value:
            current = current.left
        else:
            assert target.value > current.value
            current = current.right
    # Ran off the bottom of the tree: the target value is not present.
    path.append(None)
    return path
def parent_path(self, node: Node) -> List[Optional[Node]]:
    """Get a node's parent path.

    Args:
        node: the node to check

    Returns:
        a list of nodes representing the path from
        the tree's root to the node.

    .. note::

        If the node does not exist in the tree, the last element
        on the path will be None but the path will indicate the
        ancestor path of that node were it to be inserted.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 66):
    ...     t.insert(v)

    >>> n = t[66]
    >>> for x in t.parent_path(n):
    ...     print(x.value)
    50
    75
    66

    >>> n = Node(4)
    >>> for x in t.parent_path(n):
    ...     if x is not None:
    ...         print(x.value)
    ...     else:
    ...         print(x)
    50
    None

    """
    return self._parent_path(self.root, node)
def __delitem__(self, value: ComparableNodeValue) -> bool:
    """
    Delete an item from the tree and preserve the BST property.

    Args:
        value: the value of the node to be deleted.

    Returns:
        True if the value was found and its associated node was
        successfully deleted and False otherwise.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 85, 13):
    ...     t.insert(v)

    >>> del t[22]  # Note: bool result is discarded
    >>> list(t.iterate_inorder())
    [13, 25, 50, 66, 75, 85]

    >>> t.__delitem__(13)
    True
    >>> t.__delitem__(75)
    True
    >>> list(t.iterate_inorder())
    [25, 50, 66, 85]

    >>> t.__delitem__(99)
    False

    """
    if self.root is None:
        return False
    ret = self._delete(value, None, self.root)
    if ret:
        # NOTE(review): the bookkeeping lines below were not visible in
        # the reviewed listing; reconstructed -- confirm.
        self.count -= 1
        if self.count == 0:
            self.root = None
    return ret
def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
    """This is called just after deleted was deleted from the tree.

    The base implementation does nothing; it exists as a hook that
    subclasses may override.

    Args:
        parent: the parent of the removed node, or None if the removed
            node had no parent.
        deleted: the node that was unlinked from the tree.
    """
def _delete(
    self, value: ComparableNodeValue, parent: Optional[Node], node: Node
) -> bool:
    """Delete helper.

    Searches the subtree rooted at ``node`` for ``value`` and unlinks
    the node holding it.

    Args:
        value: the value to delete.
        parent: the parent of ``node``, or None when ``node`` is the
            tree's root.
        node: the subtree root to search.

    Returns:
        True if a node was removed, False if the value was not found.
    """
    if node.value == value:
        # Node has both a left and right: overwrite its value with its
        # in-order successor (the leftmost node of the right subtree),
        # then delete that successor node from the right subtree.
        if node.left is not None and node.right is not None:
            descendent = node.right
            while descendent.left is not None:
                descendent = descendent.left
            node.value = descendent.value
            return self._delete(node.value, node, node.right)

        # Leaf (replacement is None) or a single child (the child
        # takes the deleted node's place).
        replacement = node.left if node.left is not None else node.right
        if parent is None:
            # Deleting the root: the tree must be re-rooted, otherwise
            # the deleted value would remain reachable.
            self.root = replacement
        elif parent.left is node:
            parent.left = replacement
        else:
            assert parent.right is node
            parent.right = replacement
        self._on_delete(parent, node)
        return True
    elif value < node.value and node.left is not None:
        return self._delete(value, node, node.left)
    elif value > node.value and node.right is not None:
        return self._delete(value, node, node.right)
    return False
def __len__(self):
    """
    Returns:
        The count of items in the tree.

    >>> t = BinarySearchTree()
    >>> t.insert(50)
    >>> t.insert(75)
    >>> len(t)
    2
    >>> t.__delitem__(50)
    True
    >>> len(t)
    1

    """
    # NOTE(review): the method header and body were not visible in the
    # reviewed listing; reconstructed from the docstring -- confirm.
    return self.count
def __contains__(self, value: ComparableNodeValue) -> bool:
    """
    Args:
        value: the value to look for.

    Returns:
        True if the item is in the tree; False otherwise.
    """
    return self.__getitem__(value) is not None
def _iterate_preorder(self, node: Node):
    """Yield the values under ``node`` in preorder: the node's own
    value first, then its left subtree, then its right subtree."""
    yield node.value
    if node.left is not None:
        yield from self._iterate_preorder(node.left)
    if node.right is not None:
        yield from self._iterate_preorder(node.right)
def _iterate_inorder(self, node: Node):
    """Yield the values under ``node`` in order: the left subtree
    first, then the node's own value, then the right subtree."""
    if node.left is not None:
        yield from self._iterate_inorder(node.left)
    yield node.value
    if node.right is not None:
        yield from self._iterate_inorder(node.right)
def _iterate_postorder(self, node: Node):
    """Yield the values under ``node`` in postorder: the left subtree
    first, then the right subtree, then the node's own value."""
    if node.left is not None:
        yield from self._iterate_postorder(node.left)
    if node.right is not None:
        yield from self._iterate_postorder(node.right)
    yield node.value
def iterate_preorder(self):
    """
    Returns:
        A Generator that yields the tree's items in a
        preorder traversal sequence.  Yields nothing for an empty
        tree.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 85):
    ...     t.insert(v)

    >>> for value in t.iterate_preorder():
    ...     print(value)
    50
    25
    22
    75
    66
    85

    """
    if self.root is not None:
        yield from self._iterate_preorder(self.root)
def iterate_inorder(self):
    """
    Returns:
        A Generator that yields the tree's items in an inorder
        traversal sequence.  Yields nothing for an empty tree.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 85, 13):
    ...     t.insert(v)

    >>> for value in t.iterate_inorder():
    ...     print(value)
    13
    22
    25
    50
    66
    75
    85

    """
    if self.root is not None:
        yield from self._iterate_inorder(self.root)
def iterate_postorder(self):
    """
    Returns:
        A Generator that yields the tree's items in a postorder
        traversal sequence.  Yields nothing for an empty tree.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 85):
    ...     t.insert(v)

    >>> for value in t.iterate_postorder():
    ...     print(value)
    22
    25
    66
    85
    75
    50

    """
    if self.root is not None:
        yield from self._iterate_postorder(self.root)
def _iterate_leaves(self, node: Node):
    """Yield the values of the leaves (nodes with no children) under
    ``node``, left subtree before right subtree."""
    if node.left is not None:
        yield from self._iterate_leaves(node.left)
    if node.right is not None:
        yield from self._iterate_leaves(node.right)
    if node.left is None and node.right is None:
        yield node.value
def iterate_leaves(self):
    """
    Returns:
        A Generator that yields only the leaves of the tree (nodes
        that have no children).  Yields nothing for an empty tree.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 85):
    ...     t.insert(v)

    >>> for value in t.iterate_leaves():
    ...     print(value)
    22
    66
    85

    """
    if self.root is not None:
        yield from self._iterate_leaves(self.root)
def _iterate_by_depth(self, node: Node, depth: int):
    """Yield the items found exactly ``depth`` levels below ``node``
    (``depth == 0`` means ``node`` itself), left to right.

    NOTE(review): the yield statement was not visible in the reviewed
    listing; it is reconstructed as yielding the node's value --
    confirm against the callers' expectations.
    """
    if depth == 0:
        yield node.value
        return
    if node.left is not None:
        yield from self._iterate_by_depth(node.left, depth - 1)
    if node.right is not None:
        yield from self._iterate_by_depth(node.right, depth - 1)
def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
    """
    Args:
        depth: the desired depth; 0 selects the root itself.

    Returns:
        A Generator over the items at the prescribed depth in
        the tree, as produced by :meth:`_iterate_by_depth`.  Yields
        nothing for an empty tree or when no node is that deep.

    NOTE(review): the return annotation says ``Node``; confirm that it
    matches what :meth:`_iterate_by_depth` actually yields.
    """
    if self.root is not None:
        yield from self._iterate_by_depth(self.root, depth)
def get_next_node(self, node: Node) -> Optional[Node]:
    """
    Args:
        node: the node whose next greater successor is desired

    Returns:
        Given a tree node, returns the next greater node in the tree.
        If the given node is the greatest node in the tree, returns None.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 13, 85):
    ...     t.insert(v)

    >>> n = t[50]
    >>> t.get_next_node(n).value
    66

    >>> n = t[22]
    >>> t.get_next_node(n).value
    25

    >>> n = t[85]
    >>> t.get_next_node(n) is None
    True

    """
    # With a right subtree, the successor is that subtree's leftmost node.
    if node.right is not None:
        successor = node.right
        while successor.left is not None:
            successor = successor.left
        return successor

    # Otherwise climb towards the root: the successor is the first
    # ancestor that we reach from its left side.
    path = self.parent_path(node)
    assert path[-1] is not None
    assert path[-1] is node
    for ancestor in reversed(path[:-1]):
        assert ancestor is not None
        if node is not ancestor.right:
            return ancestor
        node = ancestor
    return None
def get_nodes_in_range_inclusive(
    self, lower: ComparableNodeValue, upper: ComparableNodeValue
) -> Generator[Node, None, None]:
    """
    Args:
        lower: the lower bound of the desired range (inclusive).
        upper: the upper bound of the desired range (inclusive).

    Returns:
        A Generator that yields, in ascending order, every node whose
        value lies between lower and upper.

    >>> t = BinarySearchTree()
    >>> for v in (50, 75, 25, 66, 22, 13, 85):
    ...     t.insert(v)

    >>> for node in t.get_nodes_in_range_inclusive(21, 74):
    ...     print(node.value)
    22
    25
    50
    66

    """
    node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
        lower, self.root
    )
    while node is not None:
        # Successors only get larger, so once we pass the upper bound
        # there is nothing left to find.
        if upper < node.value:
            break
        if lower <= node.value:
            yield node
        node = self.get_next_node(node)
def _depth(self, node: Node, sofar: int) -> int:
    """Return the number of nodes on the longest downward path that
    starts at ``node``, added to ``sofar`` (the number of nodes above
    ``node`` on the current path)."""
    # Counting this node itself.
    deepest = sofar + 1
    for child in (node.left, node.right):
        if child is not None:
            deepest = max(deepest, self._depth(child, sofar + 1))
    return deepest
def depth(self) -> int:
    """
    Returns:
        The max height (depth) of the tree, counted as the number of
        nodes on the longest path from the root down to a leaf: 0 for
        an empty tree and 1 for a tree holding only a root.

    >>> t = BinarySearchTree()
    >>> t.depth()
    0

    >>> t.insert(50)
    >>> t.depth()
    1

    >>> t.insert(25)
    >>> t.insert(75)
    >>> t.depth()
    2

    >>> t.insert(22)
    >>> t.depth()
    3

    """
    if self.root is None:
        return 0
    return self._depth(self.root, 0)
def height(self) -> int:
    """Returns the height (i.e. max depth) of the tree.

    This is an alias for :meth:`depth`.
    """
    return self.depth()
def repr_traverse(
    self,
    padding: str,
    pointer: str,
    node: Optional[Node],
    has_right_sibling: bool,
) -> str:
    """Render the subtree rooted at ``node`` as text, one line per
    node, each line preceded by a newline.

    Args:
        padding: the prefix printed before this node's pointer.
        pointer: the branch glyph printed directly before the value.
        node: the subtree root to render; None renders as "".
        has_right_sibling: whether a sibling will be drawn below this
            subtree, in which case a vertical bar is kept in the
            padding of this node's descendants.

    Returns:
        The rendered text, or "" when node is None.
    """
    if node is None:
        return ""

    viz = f"\n{padding}{pointer}{node.value}"

    # NOTE(review): the two padding literals and the "├──" pointer were
    # not visible in the reviewed listing; reconstructed as three-column
    # glyphs to line up with the visible "└──" pointer -- confirm.
    if has_right_sibling:
        padding += "│  "
    else:
        padding += "   "

    pointer_right = "└──"
    if node.right is not None:
        pointer_left = "├──"
    else:
        pointer_left = "└──"

    viz += self.repr_traverse(
        padding, pointer_left, node.left, node.right is not None
    )
    viz += self.repr_traverse(padding, pointer_right, node.right, False)
    return viz
def __repr__(self):
    """
    Returns:
        An ASCII string representation of the tree.

    >>> t = BinarySearchTree()
    >>> t.insert(50)
    >>> t.insert(25)
    >>> t.insert(75)
    >>> print(t)
    50
    ├──25
    └──75

    """
    if self.root is None:
        return ""

    ret = f"{self.root.value}"
    pointer_right = "└──"
    # NOTE(review): the "├──" literal was not visible in the reviewed
    # listing; reconstructed to pair with the visible "└──" -- confirm.
    if self.root.right is None:
        pointer_left = "└──"
    else:
        pointer_left = "├──"

    # The left child has a right sibling exactly when the root has a
    # right child; passing anything else draws a stray vertical bar
    # beside the left subtree of a root that has no right child.
    ret += self.repr_traverse(
        "", pointer_left, self.root.left, self.root.right is not None
    )
    ret += self.repr_traverse("", pointer_right, self.root.right, False)
    return ret
if __name__ == "__main__":
    # NOTE(review): the body of this guard was not visible in the reviewed
    # listing; running the module's doctests is reconstructed from the
    # ">>>" examples used throughout the docstrings -- confirm.
    import doctest

    doctest.testmod()