3 # © Copyright 2021-2022, Scott Gasch
5 """A binary search tree."""
7 from typing import Any, Generator, List, Optional
class Node(object):
    """One node of a BinarySearchTree: a value plus optional children."""

    def __init__(self, value: Any) -> None:
        """Note that value can be anything as long as it is
        comparable.  Check out @functools.total_ordering.

        Args:
            value: the payload stored in this node.  The tree compares
                values with ``<``, ``>`` and ``==``.
        """
        # Children are linked in by BinarySearchTree as values are inserted.
        self.left: Optional[Node] = None
        self.right: Optional[Node] = None
        self.value = value


class BinarySearchTree(object):
    """An unbalanced binary search tree.

    Values less than a node's value are stored in its left subtree; all
    other values (including duplicates) are stored in its right subtree.
    The tree is never rebalanced and the helpers below are recursive, so a
    very deep tree (e.g. one built from already-sorted input) can hit
    Python's recursion limit.
    """

    def __init__(self) -> None:
        self.root: Optional[Node] = None
        # Number of stored values; duplicates are counted individually.
        self.count = 0
        # Not used by this class; retained in case existing callers read it.
        self.traverse = None

    def get_root(self) -> Optional[Node]:
        """:returns the root of the BST (None when the tree is empty)."""
        return self.root

    def insert(self, value: Any) -> None:
        """
        Insert something into the tree.  Duplicates are allowed and are
        placed in the right subtree of the equal-valued node.

        >>> t = BinarySearchTree()
        >>> t.insert(10)
        >>> t.insert(20)
        >>> t.insert(5)
        >>> len(t)
        3

        >>> t.get_root().value
        10

        """
        if self.root is None:
            self.root = Node(value)
            self.count = 1
        else:
            self._insert(value, self.root)

    def _insert(self, value: Any, node: Node) -> None:
        """Insertion helper: walk down from node and attach a new leaf."""
        if value < node.value:
            if node.left is not None:
                self._insert(value, node.left)
            else:
                node.left = Node(value)
                self.count += 1
        else:
            # Equal values go right; _find/_delete search in the same order.
            if node.right is not None:
                self._insert(value, node.right)
            else:
                node.right = Node(value)
                self.count += 1

    def __getitem__(self, value: Any) -> Optional[Node]:
        """
        Find an item in the tree and return its Node.  Returns
        None if the item is not in the tree.

        >>> t = BinarySearchTree()
        >>> t[99] is None
        True
        >>> t.insert(10)
        >>> t.insert(20)
        >>> t.insert(5)
        >>> t[10].value
        10
        >>> t[99] is None
        True

        """
        if self.root is None:
            return None
        return self._find(value, self.root)

    def _find(self, value: Any, node: Node) -> Optional[Node]:
        """Find helper: return the first node on the search path from node
        whose value equals value, or None if the search falls off the tree."""
        if value == node.value:
            return node
        elif value < node.value and node.left is not None:
            return self._find(value, node.left)
        elif value > node.value and node.right is not None:
            return self._find(value, node.right)
        return None

    def _parent_path(
        self, current: Optional[Node], target: Node
    ) -> List[Optional[Node]]:
        """Path helper: nodes visited while searching for target.value from
        current; ends with None if the search falls off the tree."""
        if current is None:
            return [None]
        ret: List[Optional[Node]] = [current]
        if target.value == current.value:
            return ret
        elif target.value < current.value:
            ret.extend(self._parent_path(current.left, target))
            return ret
        else:
            assert target.value > current.value
            ret.extend(self._parent_path(current.right, target))
            return ret

    def parent_path(self, node: Node) -> List[Optional[Node]]:
        """Return a list of nodes representing the path from
        the tree's root to the node argument.  If the node does
        not exist in the tree for some reason, the last element
        on the path will be None but the path will indicate the
        ancestor path of that node were it inserted.

        >>> t = BinarySearchTree()
        >>> for v in (50, 75, 25, 66, 22, 13, 24, 69):
        ...     t.insert(v)
        >>> n = t[69]
        >>> [x.value for x in t.parent_path(n)]
        [50, 75, 66, 69]

        >>> del t[69]
        >>> [x.value if x is not None else None for x in t.parent_path(n)]
        [50, 75, 66, None]

        """
        return self._parent_path(self.root, node)

    def __delitem__(self, value: Any) -> bool:
        """
        Delete an item from the tree and preserve the BST property.

        Returns:
            True if a node holding value was removed, False if value
            was not found.  (``del t[x]`` discards this result.)

        >>> t = BinarySearchTree()
        >>> for v in (50, 75, 25, 66, 22, 13, 24):
        ...     t.insert(v)
        >>> list(t.iterate_inorder())
        [13, 22, 24, 25, 50, 66, 75]

        >>> del t[22]  # Note: bool result is discarded
        >>> list(t.iterate_inorder())
        [13, 24, 25, 50, 66, 75]

        >>> t.__delitem__(13)
        True
        >>> t.__delitem__(75)
        True
        >>> list(t.iterate_inorder())
        [24, 25, 50, 66]

        >>> t.__delitem__(99)
        False

        Deleting a root that has a single child promotes that child:

        >>> t.__delitem__(50)
        True
        >>> t.get_root().value
        66
        >>> t.__delitem__(66)
        True
        >>> t.get_root().value
        25
        >>> list(t.iterate_inorder()), len(t)
        ([24, 25], 2)

        """
        if self.root is None:
            return False
        deleted = self._delete(value, None, self.root)
        if deleted:
            self.count -= 1
        return deleted

    def _replace_child(
        self, parent: Optional[Node], node: Node, replacement: Optional[Node]
    ) -> None:
        """Make whatever referenced node (parent's link, or the root when
        parent is None) reference replacement instead."""
        if parent is None:
            assert self.root is node
            self.root = replacement
        elif parent.left is node:
            parent.left = replacement
        else:
            assert parent.right is node
            parent.right = replacement

    def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
        """Delete helper: remove the first node equal to value in the subtree
        rooted at node, whose parent is parent (None when node is the root).
        Returns True if a node was removed."""
        if node.value == value:
            if node.left is None or node.right is None:
                # Leaf or single child: splice the (possibly absent) child
                # into node's place.  This also covers node being the root,
                # which must update self.root rather than a parent link.
                child = node.left if node.left is not None else node.right
                self._replace_child(parent, node, child)
                return True

            # Two children: copy the in-order successor's value (leftmost
            # node of the right subtree) into node, then delete that
            # successor; it has no left child so it hits the case above.
            assert node.left is not None and node.right is not None
            successor = node.right
            while successor.left is not None:
                successor = successor.left
            node.value = successor.value
            return self._delete(node.value, node, node.right)
        elif value < node.value and node.left is not None:
            return self._delete(value, node, node.left)
        elif value > node.value and node.right is not None:
            return self._delete(value, node, node.right)
        return False

    def __len__(self) -> int:
        """
        Returns the count of items in the tree.

        >>> t = BinarySearchTree()
        >>> len(t)
        0
        >>> t.insert(50)
        >>> len(t)
        1
        >>> t.__delitem__(50)
        True
        >>> len(t)
        0
        >>> t.insert(75)
        >>> t.insert(25)
        >>> len(t)
        2

        """
        return self.count

    def __contains__(self, value: Any) -> bool:
        """
        Returns True if the item is in the tree; False otherwise.
        """
        return self[value] is not None

    def _iterate_preorder(self, node: Node) -> Generator[Any, None, None]:
        """Yield values of node's subtree: node, left subtree, right subtree."""
        yield node.value
        if node.left is not None:
            yield from self._iterate_preorder(node.left)
        if node.right is not None:
            yield from self._iterate_preorder(node.right)

    def _iterate_inorder(self, node: Node) -> Generator[Any, None, None]:
        """Yield values of node's subtree: left subtree, node, right subtree."""
        if node.left is not None:
            yield from self._iterate_inorder(node.left)
        yield node.value
        if node.right is not None:
            yield from self._iterate_inorder(node.right)

    def _iterate_postorder(self, node: Node) -> Generator[Any, None, None]:
        """Yield values of node's subtree: left subtree, right subtree, node."""
        if node.left is not None:
            yield from self._iterate_postorder(node.left)
        if node.right is not None:
            yield from self._iterate_postorder(node.right)
        yield node.value

    def iterate_preorder(self) -> Generator[Any, None, None]:
        """
        Yield the tree's items in a preorder traversal sequence (each node
        before its left subtree, then its right subtree).

        >>> t = BinarySearchTree()
        >>> for v in (50, 75, 25, 66, 22, 13, 24):
        ...     t.insert(v)
        >>> list(t.iterate_preorder())
        [50, 25, 22, 13, 24, 75, 66]

        """
        if self.root is not None:
            yield from self._iterate_preorder(self.root)

    def iterate_inorder(self) -> Generator[Any, None, None]:
        """
        Yield the tree's items in an inorder traversal sequence, i.e. in
        ascending sorted order.

        >>> t = BinarySearchTree()
        >>> for v in (50, 75, 25, 66, 22, 13, 24):
        ...     t.insert(v)
        >>> list(t.iterate_inorder())
        [13, 22, 24, 25, 50, 66, 75]

        """
        if self.root is not None:
            yield from self._iterate_inorder(self.root)

    def iterate_postorder(self) -> Generator[Any, None, None]:
        """
        Yield the tree's items in a postorder traversal sequence (both
        subtrees before the node itself).

        >>> t = BinarySearchTree()
        >>> for v in (50, 75, 25, 66, 22, 13, 24):
        ...     t.insert(v)
        >>> list(t.iterate_postorder())
        [13, 24, 22, 25, 66, 75, 50]

        """
        if self.root is not None:
            yield from self._iterate_postorder(self.root)

    def _iterate_leaves(self, node: Node) -> Generator[Any, None, None]:
        """Yield the values of the childless nodes under node, left to right."""
        if node.left is not None:
            yield from self._iterate_leaves(node.left)
        if node.right is not None:
            yield from self._iterate_leaves(node.right)
        if node.left is None and node.right is None:
            yield node.value

    def iterate_leaves(self) -> Generator[Any, None, None]:
        """
        Iterate only the leaf nodes in the tree (yielding their values).

        >>> t = BinarySearchTree()
        >>> for v in (50, 75, 25, 66, 22, 13, 24):
        ...     t.insert(v)
        >>> list(t.iterate_leaves())
        [13, 24, 66]

        """
        if self.root is not None:
            yield from self._iterate_leaves(self.root)

    def _iterate_by_depth(
        self, node: Node, depth: int
    ) -> Generator[Any, None, None]:
        """Yield the values of nodes exactly depth levels below node."""
        if depth == 0:
            yield node.value
        elif depth > 0:
            if node.left is not None:
                yield from self._iterate_by_depth(node.left, depth - 1)
            if node.right is not None:
                yield from self._iterate_by_depth(node.right, depth - 1)
        # A negative depth matches no nodes, so nothing is yielded.

    def iterate_nodes_by_depth(self, depth: int) -> Generator[Any, None, None]:
        """
        Iterate the values of the nodes at the given depth, left to right.
        The root is at depth 0; a negative depth yields nothing.

        >>> t = BinarySearchTree()
        >>> for v in (50, 75, 25, 66, 22, 13, 24):
        ...     t.insert(v)
        >>> list(t.iterate_nodes_by_depth(2))
        [22, 66]
        >>> list(t.iterate_nodes_by_depth(3))
        [13, 24]

        """
        if self.root is not None:
            yield from self._iterate_by_depth(self.root, depth)

    def get_next_node(self, node: Node) -> Node:
        """
        Given a tree node, get the next node in in-order (sorted) sequence.

        Raises:
            ValueError: if node is not reachable in this tree, or if it
                holds the largest value and so has no successor.

        >>> t = BinarySearchTree()
        >>> for v in (50, 75, 25, 66, 22, 13, 24):
        ...     t.insert(v)
        >>> t.get_next_node(t[24]).value
        25
        >>> t.get_next_node(t[50]).value
        66
        >>> t.get_next_node(t[75])
        Traceback (most recent call last):
        ...
        ValueError: 75 has no successor in the tree

        """
        if node.right is not None:
            # The successor is the leftmost node of the right subtree.
            successor = node.right
            while successor.left is not None:
                successor = successor.left
            return successor

        path = self.parent_path(node)
        if path[-1] is not node:
            raise ValueError(f'{node.value} is not a node in this tree')

        # Walk back toward the root; the first ancestor that we reach from
        # its left side is the successor.
        child = node
        for ancestor in reversed(path[:-1]):
            assert ancestor is not None
            if child is not ancestor.right:
                return ancestor
            child = ancestor
        raise ValueError(f'{node.value} has no successor in the tree')

    def _depth(self, node: Node, sofar: int) -> int:
        """Depth helper: sofar + number of levels in node's subtree."""
        depth_left = sofar + 1
        depth_right = sofar + 1
        if node.left is not None:
            depth_left = self._depth(node.left, sofar + 1)
        if node.right is not None:
            depth_right = self._depth(node.right, sofar + 1)
        return max(depth_left, depth_right)

    def depth(self) -> int:
        """
        Returns the max height (depth) of the tree, counted as the number
        of nodes on the longest root-to-leaf path: an empty tree is 0 and
        a lone root is 1.

        >>> t = BinarySearchTree()
        >>> t.depth()
        0
        >>> t.insert(50)
        >>> t.depth()
        1
        >>> for v in (75, 25, 66, 22, 13, 24):
        ...     t.insert(v)
        >>> t.depth()
        4

        """
        if self.root is None:
            return 0
        return self._depth(self.root, 0)

    def height(self) -> int:
        """Returns the height (i.e. max depth) of the tree"""
        return self.depth()

    def repr_traverse(
        self,
        padding: str,
        pointer: str,
        node: Optional[Node],
        has_right_sibling: bool,
    ) -> str:
        """Render node's subtree for __repr__, one line per node.

        Args:
            padding: indentation prefix inherited from the ancestors.
            pointer: connector drawn immediately before node's value.
            node: subtree root to draw; None renders as the empty string.
            has_right_sibling: whether a sibling is drawn after node, in
                which case node's descendants get a "│" guide line.
        """
        if node is None:
            return ""
        viz = f'\n{padding}{pointer}{node.value}'
        if has_right_sibling:
            padding += "│   "
        else:
            padding += '    '

        pointer_right = "└──"
        if node.right is not None:
            pointer_left = "├──"
        else:
            pointer_left = "└──"

        viz += self.repr_traverse(
            padding, pointer_left, node.left, node.right is not None
        )
        viz += self.repr_traverse(padding, pointer_right, node.right, False)
        return viz

    def __repr__(self) -> str:
        """
        Draw the tree in ASCII.

        >>> t = BinarySearchTree()
        >>> for v in (50, 75, 25, 66, 22, 13, 24):
        ...     t.insert(v)
        >>> t
        50
        ├──25
        │   └──22
        │       ├──13
        │       └──24
        └──75
            └──66

        """
        if self.root is None:
            return ""

        ret = f'{self.root.value}'
        pointer_right = "└──"
        if self.root.right is None:
            pointer_left = "└──"
        else:
            pointer_left = "├──"

        # The left subtree has a sibling drawn after it only when the root
        # has a right child -- the same rule repr_traverse uses internally.
        ret += self.repr_traverse(
            '', pointer_left, self.root.left, self.root.right is not None
        )
        ret += self.repr_traverse('', pointer_right, self.root.right, False)
        return ret