3 # © Copyright 2021-2023, Scott Gasch
5 """A binary search tree implementation."""
7 from abc import ABCMeta, abstractmethod
8 from typing import Any, Generator, List, Optional, TypeVar
class Comparable(metaclass=ABCMeta):
    """Interface for values that can be ordered and tested for equality.

    Tree payloads are bound to this interface (see the
    ``ComparableNodeValue`` type variable), so anything stored in a node
    must support these three comparisons against values of its own type.
    """

    @abstractmethod
    def __lt__(self, other: Any) -> bool:
        """Return True when self orders strictly before other."""

    @abstractmethod
    def __le__(self, other: Any) -> bool:
        """Return True when self orders before, or equal to, other."""

    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        """Return True when self and other are equal."""
# Type variable for node payloads: any type implementing Comparable.
ComparableNodeValue = TypeVar("ComparableNodeValue", bound=Comparable)
class Node:
    """A single node of a binary search tree."""

    def __init__(self, value: ComparableNodeValue) -> None:
        """Create a childless BST node.

        The value may be of any type whose instances can be compared
        with one another; :meth:`functools.total_ordering`
        (https://docs.python.org/3/library/functools.html#functools.total_ordering)
        is a convenient way to make a class comparable.

        Args:
            value: a reference to the value of the node.
        """
        # Both child links start out empty; the tree attaches children.
        self.left: Optional[Node] = None
        self.right: Optional[Node] = None
        self.value: ComparableNodeValue = value
44 class BinarySearchTree(object):
def get_root(self) -> Optional[Node]:
    """
    Returns:
        The root Node of the tree, or None when the tree is empty.
    """
    root = self.root
    return root
def _on_insert(self, parent: Optional[Node], new: Node) -> None:
    """Hook invoked immediately _after_ new has been linked into the tree.

    parent is the node new was attached under, or None when new became
    the root.  The base implementation does nothing.
    """
# Public insertion entry point.  When the tree is empty, a new Node becomes
# the root and _on_insert is told it has no parent (None); otherwise the
# work is delegated to _insert(), starting the descent at the root.
# NOTE(review): this listing is missing the docstring delimiters, most of
# the doctest lines, the empty-tree guard and `else:` around the branches
# below, and the line between creating the root and calling _on_insert
# (presumably initialising the item count) — restore from version control.
62 def insert(self, value: ComparableNodeValue) -> None:
64 Insert something into the tree.
67 value: the value to be inserted.
69 >>> t = BinarySearchTree()
76 >>> t.get_root().value
81 self.root = Node(value)
83 self._on_insert(None, self.root)
85 self._insert(value, self.root)
# Recursive insertion helper: descend from node and attach value as a new
# leaf.  Values comparing less than node.value go into the left subtree;
# presumably everything else (including equal values) goes into the right
# subtree.  After a new leaf is linked, _on_insert(parent, new_leaf) runs.
# NOTE(review): the `else:` lines and the line between each `Node(value)`
# assignment and its _on_insert() call (presumably an item-count
# increment) are missing from this listing — confirm.
# NOTE(review): recursion depth equals the insertion depth and nothing
# here rebalances the tree, so inserting already-sorted values builds a
# linear chain that can exceed Python's recursion limit on large inputs.
87 def _insert(self, value: ComparableNodeValue, node: Node):
88 """Insertion helper"""
89 if value < node.value:
90 if node.left is not None:
91 self._insert(value, node.left)
93 node.left = Node(value)
95 self._on_insert(node, node.left)
97 if node.right is not None:
98 self._insert(value, node.right)
100 node.right = Node(value)
102 self._on_insert(node, node.right)
def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
    """
    Look value up and return the Node that holds it, or None when no
    node in the tree compares equal to value.

    >>> t = BinarySearchTree()
    >>> t[50] is None
    True
    >>> for v in [50, 75, 25]:
    ...     t.insert(v)
    >>> t[75].value
    75
    >>> t[99] is None
    True
    """
    root = self.root
    if root is None:
        return None
    return self._find(value, root)
def _find(self, value: ComparableNodeValue, node: Node) -> Optional[Node]:
    """Search the subtree rooted at node for value.

    Returns the first node on the search path whose value equals value,
    or None when the search path runs out before a match is found.
    """
    current: Optional[Node] = node
    while current is not None:
        if value == current.value:
            return current
        if value < current.value:
            current = current.left
        elif value > current.value:
            current = current.right
        else:
            return None
    return None
def _find_lowest_value_greater_than_or_equal_to(
    self, target: ComparableNodeValue, node: Optional[Node]
) -> Optional[Node]:
    """Find helper that returns the node holding the smallest value that
    is greater than or equal to target, searching the subtree rooted at
    node.  Returns None if target is greater than every value in that
    subtree (or node is None).

    >>> t = BinarySearchTree()
    >>> for v in [50, 75, 25, 66, 22, 13, 85]:
    ...     t.insert(v)
    >>> t._find_lowest_value_greater_than_or_equal_to(48, t.root).value
    50
    >>> t._find_lowest_value_greater_than_or_equal_to(55, t.root).value
    66
    >>> t._find_lowest_value_greater_than_or_equal_to(1, t.root).value
    13
    >>> t._find_lowest_value_greater_than_or_equal_to(24, t.root).value
    25
    >>> t._find_lowest_value_greater_than_or_equal_to(22, t.root).value
    22
    >>> t._find_lowest_value_greater_than_or_equal_to(70, t.root).value
    75
    >>> t._find_lowest_value_greater_than_or_equal_to(86, t.root) is None
    True
    """
    candidate: Optional[Node] = None
    while node is not None:
        if target == node.value:
            # An exact match cannot be beaten.
            return node
        if target > node.value:
            # node and its whole left subtree are too small.
            node = node.right
        else:
            # node qualifies; anything closer must be in its left subtree.
            candidate = node
            node = node.left
    return candidate
def _parent_path(
    self, current: Optional[Node], target: Node
) -> List[Optional[Node]]:
    """Internal helper: walk from current toward target.value, recording
    every node visited.

    The list ends with the first node whose value equals target.value,
    or with None when the walk falls off the bottom of the tree.
    """
    visited: List[Optional[Node]] = []
    while current is not None:
        visited.append(current)
        if target.value == current.value:
            return visited
        if target.value < current.value:
            current = current.left
        else:
            assert target.value > current.value
            current = current.right
    visited.append(None)
    return visited
def parent_path(self, node: Node) -> List[Optional[Node]]:
    """Get a node's parent path.

    Args:
        node: the node to check

    Returns:
        a list of nodes representing the path from the tree's root
        down to the node.

    .. note::

        If the node does not exist in the tree, the last element on the
        path will be None, and the rest of the path lists the ancestors
        the node would have if it were inserted.

    >>> t = BinarySearchTree()
    >>> for v in [50, 75, 25, 66, 22, 13, 85]:
    ...     t.insert(v)
    >>> [x.value for x in t.parent_path(t[13])]
    [50, 25, 22, 13]
    >>> [None if x is None else x.value for x in t.parent_path(Node(70))]
    [50, 75, 66, None]
    """
    start = self.root
    return self._parent_path(start, node)
# Remove one occurrence of value from the tree.  The structural work is
# done by _delete(), which is started at the root with no parent (None).
# NOTE(review): this listing is missing the docstring delimiters, the
# doctest output lines, and the statements after the _delete() call that
# consume `ret` (presumably adjusting the item count and returning `ret`,
# or False for an empty tree) — restore from version control.
269 def __delitem__(self, value: ComparableNodeValue) -> bool:
271 Delete an item from the tree and preserve the BST property.
274 value: the value of the node to be deleted.
277 True if the value was found and its associated node was
278 successfully deleted and False otherwise.
280 >>> t = BinarySearchTree()
297 >>> for value in t.iterate_inorder():
307 >>> del t[22] # Note: bool result is discarded
309 >>> for value in t.iterate_inorder():
318 >>> t.__delitem__(13)
320 >>> for value in t.iterate_inorder():
328 >>> t.__delitem__(75)
330 >>> for value in t.iterate_inorder():
342 >>> t.__delitem__(99)
346 if self.root is not None:
347 ret = self._delete(value, None, self.root)
def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
    """Hook invoked just after deleted has been unlinked from the tree.

    parent is the node deleted hung from, or None when deleted was the
    root.  The base implementation does nothing.
    """
def _delete(
    self, value: ComparableNodeValue, parent: Optional[Node], node: Node
) -> bool:
    """Delete helper.

    Removes the first node equal to value on the search path of the
    subtree rooted at node, preserving the BST property.

    Args:
        value: the value to delete.
        parent: node's parent, or None when node is the tree's root.
        node: the root of the subtree to search.

    Returns:
        True if a node was removed; False if value was not found.

    Deleting a root that has a single child promotes that child:

    >>> t = BinarySearchTree()
    >>> t.insert(10)
    >>> t.insert(20)
    >>> t.__delitem__(10)
    True
    >>> t.get_root().value
    20
    >>> 10 in t
    False
    >>> t.insert(5)
    >>> t.__delitem__(20)
    True
    >>> t.get_root().value
    5
    """
    if node.value == value:
        if node.left is not None and node.right is not None:
            # Two children: overwrite this node's value with its in-order
            # successor (leftmost node of the right subtree), then remove
            # that successor's value from the right subtree.
            successor = node.right
            while successor.left is not None:
                successor = successor.left
            node.value = successor.value
            return self._delete(node.value, node, node.right)

        # Zero or one child: splice the only child (None for a leaf) into
        # node's place.  With no parent, node is the root, so the root
        # pointer itself must move -- previously a root with a single
        # child was never unlinked.
        child = node.left if node.left is not None else node.right
        if parent is None:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            assert parent.right is node
            parent.right = child
        self._on_delete(parent, node)
        return True
    if value < node.value and node.left is not None:
        return self._delete(value, node, node.left)
    if value > node.value and node.right is not None:
        return self._delete(value, node, node.right)
    return False
# NOTE(review): this method's `def` line, docstring delimiters, most of its
# doctest lines and its return statement are missing from this listing.
# The surviving docstring text suggests it is `__len__`, returning the
# number of items stored in the tree — restore from version control.
417 The count of items in the tree.
419 >>> t = BinarySearchTree()
425 >>> t.__delitem__(50)
def __contains__(self, value: ComparableNodeValue) -> bool:
    """Membership test backing the ``in`` operator.

    Returns:
        True if the item is in the tree; False otherwise.

    >>> t = BinarySearchTree()
    >>> 7 in t
    False
    >>> t.insert(7)
    >>> 7 in t
    True
    """
    match = self.__getitem__(value)
    return match is not None
# Recursive helper for iterate_preorder(): recurses into node's left
# subtree, then its right subtree.
# NOTE(review): the statement yielding the current node itself (presumably
# before both recursions, as preorder requires) is missing from this
# listing; whether it yields the Node or node.value cannot be told from
# here — confirm.
448 def _iterate_preorder(self, node: Node):
450 if node.left is not None:
451 yield from self._iterate_preorder(node.left)
452 if node.right is not None:
453 yield from self._iterate_preorder(node.right)
# Recursive helper for iterate_inorder(): recurses into node's left
# subtree, then its right subtree.
# NOTE(review): the statement yielding the current node itself (presumably
# between the two recursions, as inorder requires) is missing from this
# listing; whether it yields the Node or node.value cannot be told from
# here — confirm.
455 def _iterate_inorder(self, node: Node):
456 if node.left is not None:
457 yield from self._iterate_inorder(node.left)
459 if node.right is not None:
460 yield from self._iterate_inorder(node.right)
# Recursive helper for iterate_postorder(): recurses into node's left
# subtree, then its right subtree.
# NOTE(review): the statement yielding the current node itself (presumably
# after both recursions, as postorder requires) is missing from this
# listing; whether it yields the Node or node.value cannot be told from
# here — confirm.
462 def _iterate_postorder(self, node: Node):
463 if node.left is not None:
464 yield from self._iterate_postorder(node.left)
465 if node.right is not None:
466 yield from self._iterate_postorder(node.right)
# Public preorder traversal; an empty tree (root is None) yields nothing.
# NOTE(review): docstring delimiters, the tree-building doctest lines and
# the expected doctest output are missing from this listing.
469 def iterate_preorder(self):
472 A Generator that yields the tree's items in a
473 preorder traversal sequence.
475 >>> t = BinarySearchTree()
483 >>> for value in t.iterate_preorder():
493 if self.root is not None:
494 yield from self._iterate_preorder(self.root)
# Public inorder traversal; an empty tree (root is None) yields nothing.
# NOTE(review): docstring delimiters, the tree-building doctest lines and
# the expected doctest output are missing from this listing.
496 def iterate_inorder(self):
499 A Generator that yields the tree's items in inorder (sorted,
non-decreasing) traversal sequence.
502 >>> t = BinarySearchTree()
519 >>> for value in t.iterate_inorder():
530 if self.root is not None:
531 yield from self._iterate_inorder(self.root)
# Public postorder traversal; an empty tree (root is None) yields nothing.
# NOTE(review): docstring delimiters, the tree-building doctest lines and
# the expected doctest output are missing from this listing.
533 def iterate_postorder(self):
536 A Generator that yields the tree's items in a postorder
traversal sequence.
539 >>> t = BinarySearchTree()
547 >>> for value in t.iterate_postorder():
557 if self.root is not None:
558 yield from self._iterate_postorder(self.root)
# Recursive helper for iterate_leaves(): visits the left subtree before the
# right subtree, so leaves are produced from left to right; a node with no
# children is a leaf.
# NOTE(review): the yield statement under the final `if` is missing from
# this listing; whether it yields the Node or node.value cannot be told
# from here — confirm.
560 def _iterate_leaves(self, node: Node):
561 if node.left is not None:
562 yield from self._iterate_leaves(node.left)
563 if node.right is not None:
564 yield from self._iterate_leaves(node.right)
565 if node.left is None and node.right is None:
# Public leaf traversal; an empty tree (root is None) yields nothing.
# NOTE(review): docstring delimiters, the tree-building doctest lines and
# the expected doctest output are missing from this listing.
568 def iterate_leaves(self):
571 A Generator that yields only the leaf nodes in the tree, from left
to right.
574 >>> t = BinarySearchTree()
582 >>> for value in t.iterate_leaves():
588 if self.root is not None:
589 yield from self._iterate_leaves(self.root)
# Recursive helper for iterate_nodes_by_depth(): descends into both
# children (left first) with depth reduced by one per level.
# NOTE(review): the base case (presumably yielding this node when depth
# reaches 0, plus any guard against negative depth) is missing from this
# listing — confirm.
591 def _iterate_by_depth(self, node: Node, depth: int):
596 if node.left is not None:
597 yield from self._iterate_by_depth(node.left, depth - 1)
598 if node.right is not None:
599 yield from self._iterate_by_depth(node.right, depth - 1)
# Public entry point: starts _iterate_by_depth at the root with the
# requested depth, which is decremented once per level, so depth 0
# presumably selects the root itself — confirm.  An empty tree yields
# nothing.
# NOTE(review): docstring delimiters, the tree-building doctest lines and
# the expected doctest output are missing from this listing.
601 def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
604 depth: the desired depth
607 A Generator that yields nodes at the prescribed depth in
610 >>> t = BinarySearchTree()
618 >>> for value in t.iterate_nodes_by_depth(2):
623 >>> for value in t.iterate_nodes_by_depth(3):
628 if self.root is not None:
629 yield from self._iterate_by_depth(self.root, depth)
# Return node's in-order successor (the node holding the next greater
# value), or None when node holds the greatest value.  Two cases:
#   1. node has a right subtree: the successor is that subtree's leftmost
#      node (the `while x.left` walk below).
#   2. otherwise: walk back up node's ancestors, obtained from
#      parent_path(), and return the first ancestor that node is not the
#      right child of.
# NOTE(review): the lines that initialise/advance/return `x`, trim and
# reverse `path` so the loop walks upward, and return the ancestor (or
# None) are missing from this listing — restore from version control.
# NOTE(review): parent_path() locates node by value and stops at the first
# equal value, so with duplicate values the `path[-1] == node` assertion
# can fail for a node that is not the first match on the search path.
631 def get_next_node(self, node: Node) -> Optional[Node]:
634 node: the node whose next greater successor is desired
637 Given a tree node, returns the next greater node in the tree.
638 If the given node is the greatest node in the tree, returns None.
640 >>> t = BinarySearchTree()
658 >>> t.get_next_node(n).value
662 >>> t.get_next_node(n).value
666 >>> t.get_next_node(n) is None
670 if node.right is not None:
672 while x.left is not None:
676 path = self.parent_path(node)
677 assert path[-1] is not None
678 assert path[-1] == node
681 for ancestor in path:
682 assert ancestor is not None
683 if node != ancestor.right:
def get_nodes_in_range_inclusive(
    self, lower: ComparableNodeValue, upper: ComparableNodeValue
) -> Generator[Node, None, None]:
    """
    Args:
        lower: the smallest value of interest (inclusive).
        upper: the largest value of interest (inclusive).

    Returns:
        A Generator that yields, in ascending order, every node whose
        value lies within [lower, upper].  Nothing is yielded when the
        tree is empty or lower > upper.

    >>> t = BinarySearchTree()
    >>> for v in [50, 75, 25, 66, 22, 13, 85]:
    ...     t.insert(v)
    >>> for node in t.get_nodes_in_range_inclusive(21, 74):
    ...     print(node.value)
    22
    25
    50
    66
    >>> list(t.get_nodes_in_range_inclusive(90, 99))
    []
    """
    node: Optional[Node] = self._find_lowest_value_greater_than_or_equal_to(
        lower, self.root
    )
    # The starting node is >= lower and successors only grow, so only the
    # upper bound needs checking.  Stop as soon as it is exceeded instead
    # of walking the remainder of the tree.
    while node is not None and node.value <= upper:
        yield node
        node = self.get_next_node(node)
def _depth(self, node: Node, sofar: int) -> int:
    """Return sofar plus the number of levels in the subtree at node.

    A leaf contributes one level, so _depth(leaf, k) == k + 1.
    """
    level = sofar + 1
    deepest = level
    for child in (node.left, node.right):
        if child is not None:
            deepest = max(deepest, self._depth(child, level))
    return deepest
# Height of the tree measured in levels, computed by _depth() from the
# root with a starting count of 0.
# NOTE(review): the empty-tree return statement after the None check
# (presumably `return 0`), the docstring delimiters and the doctest lines
# are missing from this listing — confirm.
734 def depth(self) -> int:
737 The max height (depth) of the tree, counted in levels: the number
of nodes on the longest root-to-leaf path (a one-node tree has depth 1).
740 >>> t = BinarySearchTree()
762 if self.root is None:
764 return self._depth(self.root, 0)
def height(self) -> int:
    """Returns the height of the tree; an alias for depth()."""
    max_depth = self.depth()
    return max_depth
# Recursive helper for the tree's ASCII rendering.  Produces one line,
# "\n" + padding + pointer + node.value, for node and then renders node's
# left and right children beneath it.  Parameters, in call order as used
# below: padding (prefix for this line), pointer (connector drawn before
# the value), node, and has_right_sibling (whether a sibling is drawn
# below this subtree, which presumably decides whether a vertical rule
# continues through the children's padding).
# NOTE(review): the `def repr_traverse(self, padding, pointer,` header
# lines, the return-type line, the None-node guard, the padding updates
# under `if has_right_sibling:`, both `pointer_left = ...` assignments and
# the return statements are missing from this listing — restore from
# version control.
774 node: Optional[Node],
775 has_right_sibling: bool,
778 viz = f"\n{padding}{pointer}{node.value}"
779 if has_right_sibling:
784 pointer_right = "└──"
785 if node.right is not None:
790 viz += self.repr_traverse(
791 padding, pointer_left, node.left, node.right is not None
793 viz += self.repr_traverse(padding, pointer_right, node.right, False)
def __repr__(self):
    """
    Returns:
        An ASCII string representation of the tree ("" when empty).

    >>> repr(BinarySearchTree())
    ''
    >>> t = BinarySearchTree()
    >>> for v in [50, 25, 22, 13]:
    ...     t.insert(v)
    >>> print(t)
    50
    └──25
       └──22
          └──13
    >>> for v in [75, 66, 85]:
    ...     t.insert(v)
    >>> print(t)
    50
    ├──25
    │  └──22
    │     └──13
    └──75
       ├──66
       └──85
    """
    if self.root is None:
        return ""

    ret = f"{self.root.value}"
    pointer_right = "└──"
    if self.root.right is None:
        pointer_left = "└──"
    else:
        pointer_left = "├──"

    # The root's left child has a sibling drawn below it exactly when the
    # root has a right child.  Passing `self.root.left is not None` here
    # drew a stray vertical rule under a left-only root.
    ret += self.repr_traverse(
        "", pointer_left, self.root.left, self.root.right is not None
    )
    ret += self.repr_traverse("", pointer_right, self.root.right, False)
    return ret
# Script entry point.
# NOTE(review): the body of this guard is missing from this listing
# (presumably it runs the doctests embedded in this module) — restore it
# from version control.
836 if __name__ == "__main__":