3 # © Copyright 2021-2023, Scott Gasch
5 """A binary search tree implementation."""
7 from typing import Any, Generator, List, Optional
11 def __init__(self, value: Any) -> None:
13 A BST node. Note that value can be anything as long as it
14 is comparable. Check out :meth:`functools.total_ordering`
15 (https://docs.python.org/3/library/functools.html#functools.total_ordering)
18 value: a reference to the value of the node.
20 self.left: Optional[Node] = None
21 self.right: Optional[Node] = None
25 class BinarySearchTree(object):
31 def get_root(self) -> Optional[Node]:
39 def _on_insert(self, parent: Optional[Node], new: Node) -> None:
40 """This is called immediately _after_ a new node is inserted."""
43 def insert(self, value: Any) -> None:
45 Insert something into the tree.
48 value: the value to be inserted.
50 >>> t = BinarySearchTree()
57 >>> t.get_root().value
62 self.root = Node(value)
64 self._on_insert(None, self.root)
66 self._insert(value, self.root)
68 def _insert(self, value: Any, node: Node):
69 """Insertion helper"""
70 if value < node.value:
71 if node.left is not None:
72 self._insert(value, node.left)
74 node.left = Node(value)
76 self._on_insert(node, node.left)
78 if node.right is not None:
79 self._insert(value, node.right)
81 node.right = Node(value)
83 self._on_insert(node, node.right)
85 def __getitem__(self, value: Any) -> Optional[Node]:
87 Find an item in the tree and return its Node. Returns
88 None if the item is not in the tree.
90 >>> t = BinarySearchTree()
102 if self.root is not None:
103 return self._find(value, self.root)
106 def _find(self, value: Any, node: Node) -> Optional[Node]:
108 if value == node.value:
110 elif value < node.value and node.left is not None:
111 return self._find(value, node.left)
112 elif value > node.value and node.right is not None:
113 return self._find(value, node.right)
116 def _find_lowest_value_greater_than_or_equal_to(
117 self, target: Any, node: Optional[Node]
119 """Find helper that returns the lowest node that is greater
120 than or equal to the target value. Returns None if target is
121 greater than the highest node in the tree.
123 >>> t = BinarySearchTree()
140 >>> t._find_lowest_value_greater_than_or_equal_to(48, t.root).value
142 >>> t._find_lowest_value_greater_than_or_equal_to(55, t.root).value
144 >>> t._find_lowest_value_greater_than_or_equal_to(1, t.root).value
146 >>> t._find_lowest_value_greater_than_or_equal_to(24, t.root).value
148 >>> t._find_lowest_value_greater_than_or_equal_to(20, t.root).value
150 >>> t._find_lowest_value_greater_than_or_equal_to(72, t.root).value
152 >>> t._find_lowest_value_greater_than_or_equal_to(78, t.root).value
154 >>> t._find_lowest_value_greater_than_or_equal_to(95, t.root) is None
162 if target == node.value:
164 elif target >= node.value:
165 return self._find_lowest_value_greater_than_or_equal_to(target, node.right)
167 assert target < node.value
168 if below := self._find_lowest_value_greater_than_or_equal_to(
176 self, current: Optional[Node], target: Node
177 ) -> List[Optional[Node]]:
178 """Internal helper"""
181 ret: List[Optional[Node]] = [current]
182 if target.value == current.value:
184 elif target.value < current.value:
185 ret.extend(self._parent_path(current.left, target))
188 assert target.value > current.value
189 ret.extend(self._parent_path(current.right, target))
192 def parent_path(self, node: Node) -> List[Optional[Node]]:
193 """Get a node's parent path.
196 node: the node to check
199 a list of nodes representing the path from
200 the tree's root to the node.
204 If the node does not exist in the tree, the last element
205 on the path will be None but the path will indicate the
206 ancestor path of that node were it to be inserted.
208 >>> t = BinarySearchTree()
226 >>> for x in t.parent_path(n):
234 >>> for x in t.parent_path(n):
235 ... if x is not None:
245 return self._parent_path(self.root, node)
247 def __delitem__(self, value: Any) -> bool:
249 Delete an item from the tree and preserve the BST property.
252 value: the value of the node to be deleted.
255 True if the value was found and its associated node was
256 successfully deleted and False otherwise.
258 >>> t = BinarySearchTree()
275 >>> for value in t.iterate_inorder():
285 >>> del t[22] # Note: bool result is discarded
287 >>> for value in t.iterate_inorder():
296 >>> t.__delitem__(13)
298 >>> for value in t.iterate_inorder():
306 >>> t.__delitem__(75)
308 >>> for value in t.iterate_inorder():
320 >>> t.__delitem__(99)
324 if self.root is not None:
325 ret = self._delete(value, None, self.root)
333 def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
334 """This is called just after deleted was deleted from the tree"""
337 def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
339 if node.value == value:
341 # Deleting a leaf node
342 if node.left is None and node.right is None:
343 if parent is not None:
344 if parent.left == node:
347 assert parent.right == node
349 self._on_delete(parent, node)
352 # Node only has a right.
353 elif node.left is None:
354 assert node.right is not None
355 if parent is not None:
356 if parent.left == node:
357 parent.left = node.right
359 assert parent.right == node
360 parent.right = node.right
361 self._on_delete(parent, node)
364 # Node only has a left.
365 elif node.right is None:
366 assert node.left is not None
367 if parent is not None:
368 if parent.left == node:
369 parent.left = node.left
371 assert parent.right == node
372 parent.right = node.left
373 self._on_delete(parent, node)
376 # Node has both a left and right.
378 assert node.left is not None and node.right is not None
379 descendent = node.right
380 while descendent.left is not None:
381 descendent = descendent.left
382 node.value = descendent.value
383 return self._delete(node.value, node, node.right)
384 elif value < node.value and node.left is not None:
385 return self._delete(value, node, node.left)
386 elif value > node.value and node.right is not None:
387 return self._delete(value, node, node.right)
393 The count of items in the tree.
395 >>> t = BinarySearchTree()
401 >>> t.__delitem__(50)
417 def __contains__(self, value: Any) -> bool:
420 True if the item is in the tree; False otherwise.
422 return self.__getitem__(value) is not None
424 def _iterate_preorder(self, node: Node):
426 if node.left is not None:
427 yield from self._iterate_preorder(node.left)
428 if node.right is not None:
429 yield from self._iterate_preorder(node.right)
431 def _iterate_inorder(self, node: Node):
432 if node.left is not None:
433 yield from self._iterate_inorder(node.left)
435 if node.right is not None:
436 yield from self._iterate_inorder(node.right)
438 def _iterate_postorder(self, node: Node):
439 if node.left is not None:
440 yield from self._iterate_postorder(node.left)
441 if node.right is not None:
442 yield from self._iterate_postorder(node.right)
445 def iterate_preorder(self):
448 A Generator that yields the tree's items in a
449 preorder traversal sequence.
451 >>> t = BinarySearchTree()
459 >>> for value in t.iterate_preorder():
469 if self.root is not None:
470 yield from self._iterate_preorder(self.root)
472 def iterate_inorder(self):
475 A Generator that yields the tree's items in an inorder
478 >>> t = BinarySearchTree()
495 >>> for value in t.iterate_inorder():
506 if self.root is not None:
507 yield from self._iterate_inorder(self.root)
509 def iterate_postorder(self):
512 A Generator that yields the tree's items in a postorder
515 >>> t = BinarySearchTree()
523 >>> for value in t.iterate_postorder():
533 if self.root is not None:
534 yield from self._iterate_postorder(self.root)
536 def _iterate_leaves(self, node: Node):
537 if node.left is not None:
538 yield from self._iterate_leaves(node.left)
539 if node.right is not None:
540 yield from self._iterate_leaves(node.right)
541 if node.left is None and node.right is None:
544 def iterate_leaves(self):
547 A Generator that yields only the leaf nodes in the
550 >>> t = BinarySearchTree()
558 >>> for value in t.iterate_leaves():
564 if self.root is not None:
565 yield from self._iterate_leaves(self.root)
567 def _iterate_by_depth(self, node: Node, depth: int):
572 if node.left is not None:
573 yield from self._iterate_by_depth(node.left, depth - 1)
574 if node.right is not None:
575 yield from self._iterate_by_depth(node.right, depth - 1)
577 def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
580 depth: the desired depth
583 A Generator that yields nodes at the prescribed depth in
586 >>> t = BinarySearchTree()
594 >>> for value in t.iterate_nodes_by_depth(2):
599 >>> for value in t.iterate_nodes_by_depth(3):
604 if self.root is not None:
605 yield from self._iterate_by_depth(self.root, depth)
607 def get_next_node(self, node: Node) -> Optional[Node]:
610 node: the node whose next greater successor is desired
613 Given a tree node, returns the next greater node in the tree.
614 If the given node is the greatest node in the tree, returns None.
616 >>> t = BinarySearchTree()
634 >>> t.get_next_node(n).value
638 >>> t.get_next_node(n).value
642 >>> t.get_next_node(n) is None
646 if node.right is not None:
648 while x.left is not None:
652 path = self.parent_path(node)
653 assert path[-1] is not None
654 assert path[-1] == node
657 for ancestor in path:
658 assert ancestor is not None
659 if node != ancestor.right:
664 def get_nodes_in_range_inclusive(self, lower: Any, upper: Any):
666 >>> t = BinarySearchTree()
683 >>> for node in t.get_nodes_in_range_inclusive(21, 74):
684 ... print(node.value)
691 node: Optional[Node] = self._find_lowest_value_greater_than_or_equal_to(
695 if lower <= node.value <= upper:
697 node = self.get_next_node(node)
699 def _depth(self, node: Node, sofar: int) -> int:
700 depth_left = sofar + 1
701 depth_right = sofar + 1
702 if node.left is not None:
703 depth_left = self._depth(node.left, sofar + 1)
704 if node.right is not None:
705 depth_right = self._depth(node.right, sofar + 1)
706 return max(depth_left, depth_right)
708 def depth(self) -> int:
711 The max height (depth) of the tree in plies (edge distance
714 >>> t = BinarySearchTree()
736 if self.root is None:
738 return self._depth(self.root, 0)
740 def height(self) -> int:
741 """Returns the height (i.e. max depth) of the tree"""
748 node: Optional[Node],
749 has_right_sibling: bool,
752 viz = f"\n{padding}{pointer}{node.value}"
753 if has_right_sibling:
758 pointer_right = "└──"
759 if node.right is not None:
764 viz += self.repr_traverse(
765 padding, pointer_left, node.left, node.right is not None
767 viz += self.repr_traverse(padding, pointer_right, node.right, False)
774 An ASCII string representation of the tree.
776 >>> t = BinarySearchTree()
793 if self.root is None:
796 ret = f"{self.root.value}"
797 pointer_right = "└──"
798 if self.root.right is None:
803 ret += self.repr_traverse(
804 "", pointer_left, self.root.left, self.root.left is not None
806 ret += self.repr_traverse("", pointer_right, self.root.right, False)
810 if __name__ == "__main__":