3 # © Copyright 2021-2022, Scott Gasch
5 """Binary search tree."""
7 from typing import Any, Generator, List, Optional
11 def __init__(self, value: Any) -> None:
13 Note: value can be anything as long as it is comparable.
14 Check out @functools.total_ordering.
16 self.left: Optional[Node] = None
17 self.right: Optional[Node] = None
21 class BinarySearchTree(object):
27 def get_root(self) -> Optional[Node]:
30 def insert(self, value: Any):
32 Insert something into the tree.
34 >>> t = BinarySearchTree()
41 >>> t.get_root().value
46 self.root = Node(value)
49 self._insert(value, self.root)
51 def _insert(self, value: Any, node: Node):
52 """Insertion helper"""
53 if value < node.value:
54 if node.left is not None:
55 self._insert(value, node.left)
57 node.left = Node(value)
60 if node.right is not None:
61 self._insert(value, node.right)
63 node.right = Node(value)
66 def __getitem__(self, value: Any) -> Optional[Node]:
68 Find an item in the tree and return its Node. Returns
69 None if the item is not in the tree.
71 >>> t = BinarySearchTree()
83 if self.root is not None:
84 return self._find(value, self.root)
87 def _find(self, value: Any, node: Node) -> Optional[Node]:
89 if value == node.value:
91 elif value < node.value and node.left is not None:
92 return self._find(value, node.left)
93 elif value > node.value and node.right is not None:
94 return self._find(value, node.right)
97 def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]:
100 ret: List[Optional[Node]] = [current]
101 if target.value == current.value:
103 elif target.value < current.value:
104 ret.extend(self._parent_path(current.left, target))
107 assert target.value > current.value
108 ret.extend(self._parent_path(current.right, target))
111 def parent_path(self, node: Node) -> List[Optional[Node]]:
112 """Return a list of nodes representing the path from
113 the tree's root to the node argument. If the node does
114 not exist in the tree for some reason, the last element
115 on the path will be None but the path will indicate the
116 ancestor path of that node were it inserted.
118 >>> t = BinarySearchTree()
136 >>> for x in t.parent_path(n):
144 >>> for x in t.parent_path(n):
145 ... if x is not None:
155 return self._parent_path(self.root, node)
157 def __delitem__(self, value: Any) -> bool:
159 Delete an item from the tree and preserve the BST property.
161 >>> t = BinarySearchTree()
178 >>> for value in t.iterate_inorder():
188 >>> del t[22] # Note: bool result is discarded
190 >>> for value in t.iterate_inorder():
199 >>> t.__delitem__(13)
201 >>> for value in t.iterate_inorder():
209 >>> t.__delitem__(75)
211 >>> for value in t.iterate_inorder():
223 >>> t.__delitem__(99)
227 if self.root is not None:
228 ret = self._delete(value, None, self.root)
236 def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
238 if node.value == value:
239 # Deleting a leaf node
240 if node.left is None and node.right is None:
241 if parent is not None:
242 if parent.left == node:
245 assert parent.right == node
249 # Node only has a right.
250 elif node.left is None:
251 assert node.right is not None
252 if parent is not None:
253 if parent.left == node:
254 parent.left = node.right
256 assert parent.right == node
257 parent.right = node.right
260 # Node only has a left.
261 elif node.right is None:
262 assert node.left is not None
263 if parent is not None:
264 if parent.left == node:
265 parent.left = node.left
267 assert parent.right == node
268 parent.right = node.left
271 # Node has both a left and right.
273 assert node.left is not None and node.right is not None
274 descendent = node.right
275 while descendent.left is not None:
276 descendent = descendent.left
277 node.value = descendent.value
278 return self._delete(node.value, node, node.right)
279 elif value < node.value and node.left is not None:
280 return self._delete(value, node, node.left)
281 elif value > node.value and node.right is not None:
282 return self._delete(value, node, node.right)
287 Returns the count of items in the tree.
289 >>> t = BinarySearchTree()
295 >>> t.__delitem__(50)
311 def __contains__(self, value: Any) -> bool:
313 Returns True if the item is in the tree; False otherwise.
316 return self.__getitem__(value) is not None
318 def _iterate_preorder(self, node: Node):
320 if node.left is not None:
321 yield from self._iterate_preorder(node.left)
322 if node.right is not None:
323 yield from self._iterate_preorder(node.right)
325 def _iterate_inorder(self, node: Node):
326 if node.left is not None:
327 yield from self._iterate_inorder(node.left)
329 if node.right is not None:
330 yield from self._iterate_inorder(node.right)
332 def _iterate_postorder(self, node: Node):
333 if node.left is not None:
334 yield from self._iterate_postorder(node.left)
335 if node.right is not None:
336 yield from self._iterate_postorder(node.right)
339 def iterate_preorder(self):
341 Yield the tree's items in a preorder traversal sequence.
343 >>> t = BinarySearchTree()
351 >>> for value in t.iterate_preorder():
361 if self.root is not None:
362 yield from self._iterate_preorder(self.root)
364 def iterate_inorder(self):
366 Yield the tree's items in a preorder traversal sequence.
368 >>> t = BinarySearchTree()
385 >>> for value in t.iterate_inorder():
396 if self.root is not None:
397 yield from self._iterate_inorder(self.root)
399 def iterate_postorder(self):
401 Yield the tree's items in a preorder traversal sequence.
403 >>> t = BinarySearchTree()
411 >>> for value in t.iterate_postorder():
421 if self.root is not None:
422 yield from self._iterate_postorder(self.root)
424 def _iterate_leaves(self, node: Node):
425 if node.left is not None:
426 yield from self._iterate_leaves(node.left)
427 if node.right is not None:
428 yield from self._iterate_leaves(node.right)
429 if node.left is None and node.right is None:
432 def iterate_leaves(self):
434 Iterate only the leaf nodes in the tree.
436 >>> t = BinarySearchTree()
444 >>> for value in t.iterate_leaves():
450 if self.root is not None:
451 yield from self._iterate_leaves(self.root)
453 def _iterate_by_depth(self, node: Node, depth: int):
458 if node.left is not None:
459 yield from self._iterate_by_depth(node.left, depth - 1)
460 if node.right is not None:
461 yield from self._iterate_by_depth(node.right, depth - 1)
463 def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
465 Iterate only the leaf nodes in the tree.
467 >>> t = BinarySearchTree()
475 >>> for value in t.iterate_nodes_by_depth(2):
480 >>> for value in t.iterate_nodes_by_depth(3):
485 if self.root is not None:
486 yield from self._iterate_by_depth(self.root, depth)
488 def get_next_node(self, node: Node) -> Node:
490 Given a tree node, get the next greater node in the tree.
492 >>> t = BinarySearchTree()
510 >>> t.get_next_node(n).value
514 >>> t.get_next_node(n).value
518 if node.right is not None:
520 while x.left is not None:
524 path = self.parent_path(node)
525 assert path[-1] is not None
526 assert path[-1] == node
529 for ancestor in path:
530 assert ancestor is not None
531 if node != ancestor.right:
536 def _depth(self, node: Node, sofar: int) -> int:
537 depth_left = sofar + 1
538 depth_right = sofar + 1
539 if node.left is not None:
540 depth_left = self._depth(node.left, sofar + 1)
541 if node.right is not None:
542 depth_right = self._depth(node.right, sofar + 1)
543 return max(depth_left, depth_right)
547 Returns the max height (depth) of the tree in plies (edge distance
550 >>> t = BinarySearchTree()
572 if self.root is None:
574 return self._depth(self.root, 0)
583 node: Optional[Node],
584 has_right_sibling: bool,
587 viz = f'\n{padding}{pointer}{node.value}'
588 if has_right_sibling:
593 pointer_right = "└──"
594 if node.right is not None:
599 viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
600 viz += self.repr_traverse(padding, pointer_right, node.right, False)
606 Draw the tree in ASCII.
608 >>> t = BinarySearchTree()
625 if self.root is None:
628 ret = f'{self.root.value}'
629 pointer_right = "└──"
630 if self.root.right is None:
635 ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
636 ret += self.repr_traverse('', pointer_right, self.root.right, False)
640 if __name__ == '__main__':