3 # © Copyright 2021-2022, Scott Gasch
5 """A binary search tree."""
7 from typing import Any, Generator, List, Optional
class Node:
    """One node of a :class:`BinarySearchTree`."""

    def __init__(self, value: Any) -> None:
        """
        Note: value can be anything as long as it is comparable.
        Check out @functools.total_ordering.

        Args:
            value: the item stored in this node.
        """
        self.left: Optional[Node] = None
        self.right: Optional[Node] = None
        self.value = value


class BinarySearchTree:
    """An unbalanced binary search tree.

    Values smaller than a node's value are stored in its left subtree;
    all other values (including duplicates) are stored in its right
    subtree.

    >>> t = BinarySearchTree()
    >>> for v in (50, 25, 75, 22, 29, 66, 85):
    ...     t.insert(v)
    >>> len(t)
    7
    >>> list(t.iterate_inorder())
    [22, 25, 29, 50, 66, 75, 85]
    >>> 29 in t
    True
    >>> del t[50]
    >>> t.get_root().value
    66
    >>> print(t)
    66
    ├──25
    │  ├──22
    │  └──29
    └──75
       └──85
    """

    def __init__(self) -> None:
        """Create an empty tree."""
        self.root: Optional[Node] = None
        self.count = 0

    def get_root(self) -> Optional[Node]:
        """Return the root node, or None if the tree is empty."""
        return self.root

    def insert(self, value: Any) -> None:
        """
        Insert something into the tree.

        Args:
            value: the item to insert; it must be comparable with the
                items already in the tree.

        >>> t = BinarySearchTree()
        >>> t.insert(10)
        >>> t.insert(20)
        >>> t.insert(5)
        >>> len(t)
        3
        >>> t.get_root().value
        10
        """
        if self.root is None:
            self.root = Node(value)
        else:
            self._insert(value, self.root)
        # Only count the item once it is actually linked into the tree.
        self.count += 1

    def _insert(self, value: Any, node: Node) -> None:
        """Insertion helper: link a new node for value below node."""
        # A loop (rather than recursion) keeps insertion working on very
        # deep, degenerate trees such as those built from sorted input.
        while True:
            if value < node.value:
                if node.left is None:
                    node.left = Node(value)
                    return
                node = node.left
            else:
                if node.right is None:
                    node.right = Node(value)
                    return
                node = node.right

    def __getitem__(self, value: Any) -> Optional[Node]:
        """
        Find an item in the tree and return its Node.  Returns
        None if the item is not in the tree.

        >>> t = BinarySearchTree()
        >>> t[99]
        >>> t.insert(10)
        >>> t.insert(20)
        >>> t[20].value
        20
        >>> t[99]
        """
        if self.root is not None:
            return self._find(value, self.root)
        return None

    def _find(self, value: Any, node: Node) -> Optional[Node]:
        """Find helper: search the subtree rooted at node for value."""
        current: Optional[Node] = node
        while current is not None:
            if value == current.value:
                return current
            if value < current.value:
                current = current.left
            elif value > current.value:
                current = current.right
            else:
                # Neither equal, smaller nor greater: nowhere to look.
                return None
        return None

    def _parent_path(
        self, current: Optional[Node], target: Node
    ) -> List[Optional[Node]]:
        """Path helper: walk from current towards target's value."""
        path: List[Optional[Node]] = []
        while current is not None:
            path.append(current)
            if target.value == current.value:
                return path
            if target.value < current.value:
                current = current.left
            else:
                current = current.right
        # Fell off the tree: mark the spot where target would be linked.
        path.append(None)
        return path

    def parent_path(self, node: Node) -> List[Optional[Node]]:
        """Return a list of nodes representing the path from
        the tree's root to the node argument.  If the node does
        not exist in the tree for some reason, the last element
        on the path will be None but the path will indicate the
        ancestor path of that node were it inserted.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75):
        ...     t.insert(v)
        >>> [n.value for n in t.parent_path(t[75])]
        [50, 75]
        >>> for x in t.parent_path(Node(60)):
        ...     if x is not None:
        ...         print(x.value)
        ...     else:
        ...         print(x)
        50
        75
        None
        """
        return self._parent_path(self.root, node)

    def __delitem__(self, value: Any) -> bool:
        """
        Delete an item from the tree and preserve the BST property.

        Args:
            value: the item to remove.  If it occurs several times only
                one occurrence is removed.

        Returns:
            True if an item was removed, False if value was not found.

        >>> t = BinarySearchTree()
        >>> t.insert(10)
        >>> t.insert(20)
        >>> del t[10]       # Note: bool result is discarded
        >>> list(t.iterate_inorder())
        [20]
        >>> t.__delitem__(20)
        True
        >>> t.__delitem__(99)
        False
        >>> len(t)
        0
        """
        if self.root is None:
            return False
        removed = self._delete(value, None, self.root)
        if removed:
            self.count -= 1
        return removed

    def _replace_child(
        self, parent: Optional[Node], child: Node, replacement: Optional[Node]
    ) -> None:
        """Make replacement occupy the slot currently held by child."""
        if parent is None:
            # child has no parent, so it can only be the root.  Without
            # this relink, deleting a root with one child was a no-op.
            if child is self.root:
                self.root = replacement
        elif parent.left is child:
            parent.left = replacement
        else:
            parent.right = replacement

    def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
        """Delete helper: remove value from the subtree rooted at node.

        Args:
            value: the item to remove.
            parent: the parent of node, or None when node is the root.
            node: root of the subtree to search.

        Returns:
            True if a node was unlinked, False if value was not found.
        """
        # Locate the node holding value, remembering its parent.
        current: Optional[Node] = node
        while current is not None:
            if current.value == value:
                break
            if value < current.value:
                parent, current = current, current.left
            elif value > current.value:
                parent, current = current, current.right
            else:
                return False
        if current is None:
            return False

        if current.left is not None and current.right is not None:
            # Two children: take over the value of the in-order successor
            # (leftmost node of the right subtree), then unlink that
            # successor, which by construction has no left child.
            successor_parent = current
            successor = current.right
            while successor.left is not None:
                successor_parent = successor
                successor = successor.left
            current.value = successor.value
            self._replace_child(successor_parent, successor, successor.right)
            return True

        # Zero or one child: splice the (possibly missing) child into place.
        replacement = current.left if current.left is not None else current.right
        self._replace_child(parent, current, replacement)
        return True

    def __len__(self) -> int:
        """
        Returns the count of items in the tree.

        >>> t = BinarySearchTree()
        >>> len(t)
        0
        >>> t.insert(50)
        >>> len(t)
        1
        >>> t.__delitem__(50)
        True
        >>> len(t)
        0
        """
        return self.count

    def __contains__(self, value: Any) -> bool:
        """
        Returns True if the item is in the tree; False otherwise.
        """
        return self.__getitem__(value) is not None

    def _iterate_preorder(self, node: Node) -> Generator[Any, None, None]:
        """Yield node's value, then its left and right subtrees."""
        yield node.value
        if node.left is not None:
            yield from self._iterate_preorder(node.left)
        if node.right is not None:
            yield from self._iterate_preorder(node.right)

    def _iterate_inorder(self, node: Node) -> Generator[Any, None, None]:
        """Yield the left subtree, node's value, then the right subtree."""
        if node.left is not None:
            yield from self._iterate_inorder(node.left)
        yield node.value
        if node.right is not None:
            yield from self._iterate_inorder(node.right)

    def _iterate_postorder(self, node: Node) -> Generator[Any, None, None]:
        """Yield the left and right subtrees, then node's value."""
        if node.left is not None:
            yield from self._iterate_postorder(node.left)
        if node.right is not None:
            yield from self._iterate_postorder(node.right)
        yield node.value

    def iterate_preorder(self) -> Generator[Any, None, None]:
        """
        Yield the tree's items in a preorder traversal sequence.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75):
        ...     t.insert(v)
        >>> list(t.iterate_preorder())
        [50, 25, 75]
        """
        if self.root is not None:
            yield from self._iterate_preorder(self.root)

    def iterate_inorder(self) -> Generator[Any, None, None]:
        """
        Yield the tree's items in an inorder traversal sequence, i.e.
        in ascending order.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75):
        ...     t.insert(v)
        >>> list(t.iterate_inorder())
        [25, 50, 75]
        """
        if self.root is not None:
            yield from self._iterate_inorder(self.root)

    def iterate_postorder(self) -> Generator[Any, None, None]:
        """
        Yield the tree's items in a postorder traversal sequence.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75):
        ...     t.insert(v)
        >>> list(t.iterate_postorder())
        [25, 75, 50]
        """
        if self.root is not None:
            yield from self._iterate_postorder(self.root)

    def _iterate_leaves(self, node: Node) -> Generator[Any, None, None]:
        """Yield the values of the childless nodes below (or at) node."""
        if node.left is not None:
            yield from self._iterate_leaves(node.left)
        if node.right is not None:
            yield from self._iterate_leaves(node.right)
        if node.left is None and node.right is None:
            yield node.value

    def iterate_leaves(self) -> Generator[Any, None, None]:
        """
        Iterate only the leaf nodes in the tree.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 22):
        ...     t.insert(v)
        >>> list(t.iterate_leaves())
        [22, 75]
        """
        if self.root is not None:
            yield from self._iterate_leaves(self.root)

    def _iterate_by_depth(
        self, node: Node, depth: int
    ) -> Generator[Node, None, None]:
        """Yield the nodes exactly depth levels below node."""
        if depth < 0:
            # A negative depth can never be reached; skip the walk.
            return
        if depth == 0:
            yield node
            return
        if node.left is not None:
            yield from self._iterate_by_depth(node.left, depth - 1)
        if node.right is not None:
            yield from self._iterate_by_depth(node.right, depth - 1)

    def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
        """
        Iterate only the nodes at the given depth of the tree, from left
        to right.  The root is at depth 0.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 22):
        ...     t.insert(v)
        >>> for node in t.iterate_nodes_by_depth(1):
        ...     print(node.value)
        25
        75
        >>> [node.value for node in t.iterate_nodes_by_depth(2)]
        [22]
        """
        if self.root is not None:
            yield from self._iterate_by_depth(self.root, depth)

    def get_next_node(self, node: Node) -> Node:
        """
        Given a tree node, get the next greater node in the tree.

        Args:
            node: a node that belongs to this tree.

        Returns:
            The in-order successor of node.

        Raises:
            ValueError: if node is not found in this tree or if it has
                no successor (it holds the greatest value).

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 29):
        ...     t.insert(v)
        >>> t.get_next_node(t[25]).value
        29
        >>> t.get_next_node(t[29]).value
        50
        """
        if node.right is not None:
            # Successor is the smallest item of the right subtree.
            successor = node.right
            while successor.left is not None:
                successor = successor.left
            return successor

        path = self.parent_path(node)
        if path[-1] is not node:
            raise ValueError("node is not part of this tree")

        # Climb until we leave a left subtree: that ancestor is next.
        child = node
        for ancestor in reversed(path[:-1]):
            assert ancestor is not None
            if ancestor.right is not child:
                return ancestor
            child = ancestor
        raise ValueError("node holds the greatest value; it has no successor")

    def _depth(self, node: Node, sofar: int) -> int:
        """Depth helper: sofar is the number of nodes above node."""
        depth_left = sofar + 1
        depth_right = sofar + 1
        if node.left is not None:
            depth_left = self._depth(node.left, sofar + 1)
        if node.right is not None:
            depth_right = self._depth(node.right, sofar + 1)
        return max(depth_left, depth_right)

    def depth(self) -> int:
        """
        Returns the max height (depth) of the tree, counted as the
        number of nodes on the longest path from the root down to a
        leaf.  An empty tree has depth 0; a lone root has depth 1.

        >>> t = BinarySearchTree()
        >>> t.depth()
        0
        >>> t.insert(50)
        >>> t.depth()
        1
        >>> t.insert(65)
        >>> t.insert(55)
        >>> t.depth()
        3
        """
        if self.root is None:
            return 0
        return self._depth(self.root, 0)

    def height(self) -> int:
        """Alias for :meth:`depth`."""
        return self.depth()

    def repr_traverse(
        self,
        padding: str,
        pointer: str,
        node: Optional[Node],
        has_right_sibling: bool,
    ) -> str:
        """Render node and its descendants, one item per line.

        Args:
            padding: prefix inherited from the ancestors' columns.
            pointer: the branch glyph drawn in front of node.
            node: the subtree to draw; None draws nothing.
            has_right_sibling: whether a sibling is drawn below node,
                in which case a vertical bar continues through the
                lines of node's descendants.

        Returns:
            The drawing; every line starts with a newline.
        """
        if node is None:
            return ""
        viz = f'\n{padding}{pointer}{node.value}'
        padding += "│  " if has_right_sibling else "   "
        pointer_right = "└──"
        pointer_left = "├──" if node.right is not None else "└──"
        viz += self.repr_traverse(
            padding, pointer_left, node.left, node.right is not None
        )
        viz += self.repr_traverse(padding, pointer_right, node.right, False)
        return viz

    def __repr__(self) -> str:
        """
        Draw the tree in ASCII.

        >>> t = BinarySearchTree()
        >>> for v in (50, 25, 75, 22):
        ...     t.insert(v)
        >>> t
        50
        ├──25
        │  └──22
        └──75
        """
        if self.root is None:
            return ""

        has_right = self.root.right is not None
        ret = f'{self.root.value}'
        pointer_right = "└──"
        pointer_left = "├──" if has_right else "└──"
        # The left child is followed by a sibling only when the root has
        # a right child; testing the left child here drew a stray bar.
        ret += self.repr_traverse('', pointer_left, self.root.left, has_right)
        ret += self.repr_traverse('', pointer_right, self.root.right, False)
        return ret