3 # © Copyright 2021-2023, Scott Gasch
5 """A binary search tree implementation."""
7 from abc import ABCMeta, abstractmethod
8 from typing import Any, Generator, List, Optional, TypeVar
11 class Comparable(metaclass=ABCMeta):
13 def __lt__(self, other: Any) -> bool:
17 def __le__(self, other: Any) -> bool:
21 def __eq__(self, other: Any) -> bool:
25 ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable)
29 def __init__(self, value: ComparableNodeValue) -> None:
30 """A BST node. Note that value can be anything as long as it
31 is comparable with other instances of itself. Check out
32 :func:`functools.total_ordering`
33 (https://docs.python.org/3/library/functools.html#functools.total_ordering)
36 value: a reference to the value of the node.
39 self.left: Optional[Node] = None
40 self.right: Optional[Node] = None
41 self.value: ComparableNodeValue = value
44 class BinarySearchTree(object):
50 def get_root(self) -> Optional[Node]:
58 def _on_insert(self, parent: Optional[Node], new: Node) -> None:
59 """This is called immediately _after_ a new node is inserted."""
62 def insert(self, value: ComparableNodeValue) -> None:
64 Insert something into the tree.
67 value: the value to be inserted.
69 >>> t = BinarySearchTree()
76 >>> t.get_root().value
81 self.root = Node(value)
83 self._on_insert(None, self.root)
85 self._insert(value, self.root)
87 def _insert(self, value: ComparableNodeValue, node: Node):
88 """Insertion helper"""
89 if value < node.value:
90 if node.left is not None:
91 self._insert(value, node.left)
93 node.left = Node(value)
95 self._on_insert(node, node.left)
97 if node.right is not None:
98 self._insert(value, node.right)
100 node.right = Node(value)
102 self._on_insert(node, node.right)
104 def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
106 Find an item in the tree and return its Node. Returns
107 None if the item is not in the tree.
109 >>> t = BinarySearchTree()
121 if self.root is not None:
122 return self._find(value, self.root)
125 def _find(self, value: ComparableNodeValue, node: Node) -> Optional[Node]:
127 if value == node.value:
129 elif value < node.value and node.left is not None:
130 return self._find(value, node.left)
131 elif value > node.value and node.right is not None:
132 return self._find(value, node.right)
135 def _find_lowest_value_greater_than_or_equal_to(
136 self, target: ComparableNodeValue, node: Optional[Node]
138 """Find helper that returns the lowest node that is greater
139 than or equal to the target value. Returns None if target is
140 greater than the highest node in the tree.
142 >>> t = BinarySearchTree()
159 >>> t._find_lowest_value_greater_than_or_equal_to(48, t.root).value
161 >>> t._find_lowest_value_greater_than_or_equal_to(55, t.root).value
163 >>> t._find_lowest_value_greater_than_or_equal_to(1, t.root).value
165 >>> t._find_lowest_value_greater_than_or_equal_to(24, t.root).value
167 >>> t._find_lowest_value_greater_than_or_equal_to(20, t.root).value
169 >>> t._find_lowest_value_greater_than_or_equal_to(72, t.root).value
171 >>> t._find_lowest_value_greater_than_or_equal_to(78, t.root).value
173 >>> t._find_lowest_value_greater_than_or_equal_to(95, t.root) is None
181 if target == node.value:
183 elif target >= node.value:
184 return self._find_lowest_value_greater_than_or_equal_to(target, node.right)
186 assert target < node.value
187 if below := self._find_lowest_value_greater_than_or_equal_to(
195 self, current: Optional[Node], target: Node
196 ) -> List[Optional[Node]]:
197 """Internal helper"""
200 ret: List[Optional[Node]] = [current]
201 if target.value == current.value:
203 elif target.value < current.value:
204 ret.extend(self._parent_path(current.left, target))
207 assert target.value > current.value
208 ret.extend(self._parent_path(current.right, target))
211 def parent_path(self, node: Node) -> List[Optional[Node]]:
212 """Get a node's parent path.
215 node: the node to check
218 a list of nodes representing the path from
219 the tree's root to the node.
223 If the node does not exist in the tree, the last element
224 on the path will be None but the path will indicate the
225 ancestor path of that node were it to be inserted.
227 >>> t = BinarySearchTree()
245 >>> for x in t.parent_path(n):
253 >>> for x in t.parent_path(n):
254 ... if x is not None:
264 return self._parent_path(self.root, node)
266 def __delitem__(self, value: ComparableNodeValue) -> bool:
268 Delete an item from the tree and preserve the BST property.
271 value: the value of the node to be deleted.
274 True if the value was found and its associated node was
275 successfully deleted and False otherwise.
277 >>> t = BinarySearchTree()
294 >>> for value in t.iterate_inorder():
304 >>> del t[22] # Note: bool result is discarded
306 >>> for value in t.iterate_inorder():
315 >>> t.__delitem__(13)
317 >>> for value in t.iterate_inorder():
325 >>> t.__delitem__(75)
327 >>> for value in t.iterate_inorder():
339 >>> t.__delitem__(99)
343 if self.root is not None:
344 ret = self._delete(value, None, self.root)
352 def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
353 """This is called just after deleted was deleted from the tree"""
357 self, value: ComparableNodeValue, parent: Optional[Node], node: Node
360 if node.value == value:
362 # Deleting a leaf node
363 if node.left is None and node.right is None:
364 if parent is not None:
365 if parent.left == node:
368 assert parent.right == node
370 self._on_delete(parent, node)
373 # Node only has a right.
374 elif node.left is None:
375 assert node.right is not None
376 if parent is not None:
377 if parent.left == node:
378 parent.left = node.right
380 assert parent.right == node
381 parent.right = node.right
382 self._on_delete(parent, node)
385 # Node only has a left.
386 elif node.right is None:
387 assert node.left is not None
388 if parent is not None:
389 if parent.left == node:
390 parent.left = node.left
392 assert parent.right == node
393 parent.right = node.left
394 self._on_delete(parent, node)
397 # Node has both a left and right.
399 assert node.left is not None and node.right is not None
400 descendent = node.right
401 while descendent.left is not None:
402 descendent = descendent.left
403 node.value = descendent.value
404 return self._delete(node.value, node, node.right)
405 elif value < node.value and node.left is not None:
406 return self._delete(value, node, node.left)
407 elif value > node.value and node.right is not None:
408 return self._delete(value, node, node.right)
414 The count of items in the tree.
416 >>> t = BinarySearchTree()
422 >>> t.__delitem__(50)
438 def __contains__(self, value: ComparableNodeValue) -> bool:
441 True if the item is in the tree; False otherwise.
443 return self.__getitem__(value) is not None
445 def _iterate_preorder(self, node: Node):
447 if node.left is not None:
448 yield from self._iterate_preorder(node.left)
449 if node.right is not None:
450 yield from self._iterate_preorder(node.right)
452 def _iterate_inorder(self, node: Node):
453 if node.left is not None:
454 yield from self._iterate_inorder(node.left)
456 if node.right is not None:
457 yield from self._iterate_inorder(node.right)
459 def _iterate_postorder(self, node: Node):
460 if node.left is not None:
461 yield from self._iterate_postorder(node.left)
462 if node.right is not None:
463 yield from self._iterate_postorder(node.right)
466 def iterate_preorder(self):
469 A Generator that yields the tree's items in a
470 preorder traversal sequence.
472 >>> t = BinarySearchTree()
480 >>> for value in t.iterate_preorder():
490 if self.root is not None:
491 yield from self._iterate_preorder(self.root)
493 def iterate_inorder(self):
496 A Generator that yields the tree's items in an inorder
499 >>> t = BinarySearchTree()
516 >>> for value in t.iterate_inorder():
527 if self.root is not None:
528 yield from self._iterate_inorder(self.root)
530 def iterate_postorder(self):
533 A Generator that yields the tree's items in a postorder
536 >>> t = BinarySearchTree()
544 >>> for value in t.iterate_postorder():
554 if self.root is not None:
555 yield from self._iterate_postorder(self.root)
557 def _iterate_leaves(self, node: Node):
558 if node.left is not None:
559 yield from self._iterate_leaves(node.left)
560 if node.right is not None:
561 yield from self._iterate_leaves(node.right)
562 if node.left is None and node.right is None:
565 def iterate_leaves(self):
568 A Generator that yields only the leaf nodes in the
571 >>> t = BinarySearchTree()
579 >>> for value in t.iterate_leaves():
585 if self.root is not None:
586 yield from self._iterate_leaves(self.root)
588 def _iterate_by_depth(self, node: Node, depth: int):
593 if node.left is not None:
594 yield from self._iterate_by_depth(node.left, depth - 1)
595 if node.right is not None:
596 yield from self._iterate_by_depth(node.right, depth - 1)
598 def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
601 depth: the desired depth
604 A Generator that yields nodes at the prescribed depth in
607 >>> t = BinarySearchTree()
615 >>> for value in t.iterate_nodes_by_depth(2):
620 >>> for value in t.iterate_nodes_by_depth(3):
625 if self.root is not None:
626 yield from self._iterate_by_depth(self.root, depth)
628 def get_next_node(self, node: Node) -> Optional[Node]:
631 node: the node whose next greater successor is desired
634 Given a tree node, returns the next greater node in the tree.
635 If the given node is the greatest node in the tree, returns None.
637 >>> t = BinarySearchTree()
655 >>> t.get_next_node(n).value
659 >>> t.get_next_node(n).value
663 >>> t.get_next_node(n) is None
667 if node.right is not None:
669 while x.left is not None:
673 path = self.parent_path(node)
674 assert path[-1] is not None
675 assert path[-1] == node
678 for ancestor in path:
679 assert ancestor is not None
680 if node != ancestor.right:
685 def get_nodes_in_range_inclusive(
686 self, lower: ComparableNodeValue, upper: ComparableNodeValue
689 >>> t = BinarySearchTree()
706 >>> for node in t.get_nodes_in_range_inclusive(21, 74):
707 ... print(node.value)
714 node: Optional[Node] = self._find_lowest_value_greater_than_or_equal_to(
718 if lower <= node.value <= upper:
720 node = self.get_next_node(node)
722 def _depth(self, node: Node, sofar: int) -> int:
723 depth_left = sofar + 1
724 depth_right = sofar + 1
725 if node.left is not None:
726 depth_left = self._depth(node.left, sofar + 1)
727 if node.right is not None:
728 depth_right = self._depth(node.right, sofar + 1)
729 return max(depth_left, depth_right)
731 def depth(self) -> int:
734 The max height (depth) of the tree in plies (edge distance
737 >>> t = BinarySearchTree()
759 if self.root is None:
761 return self._depth(self.root, 0)
763 def height(self) -> int:
764 """Returns the height (i.e. max depth) of the tree"""
771 node: Optional[Node],
772 has_right_sibling: bool,
775 viz = f"\n{padding}{pointer}{node.value}"
776 if has_right_sibling:
781 pointer_right = "└──"
782 if node.right is not None:
787 viz += self.repr_traverse(
788 padding, pointer_left, node.left, node.right is not None
790 viz += self.repr_traverse(padding, pointer_right, node.right, False)
797 An ASCII string representation of the tree.
799 >>> t = BinarySearchTree()
816 if self.root is None:
819 ret = f"{self.root.value}"
820 pointer_right = "└──"
821 if self.root.right is None:
826 ret += self.repr_traverse(
827 "", pointer_left, self.root.left, self.root.left is not None
829 ret += self.repr_traverse("", pointer_right, self.root.right, False)
833 if __name__ == "__main__":