3 from typing import Any, Generator, List, Optional
7 def __init__(self, value: Any) -> None:
9 Note: value can be anything as long as it is comparable.
10 Check out @functools.total_ordering.
12 self.left: Optional[Node] = None
13 self.right: Optional[Node] = None
17 class BinarySearchTree(object):
23 def get_root(self) -> Optional[Node]:
26 def insert(self, value: Any):
28 Insert something into the tree.
30 >>> t = BinarySearchTree()
37 >>> t.get_root().value
42 self.root = Node(value)
45 self._insert(value, self.root)
47 def _insert(self, value: Any, node: Node):
48 """Insertion helper"""
49 if value < node.value:
50 if node.left is not None:
51 self._insert(value, node.left)
53 node.left = Node(value)
56 if node.right is not None:
57 self._insert(value, node.right)
59 node.right = Node(value)
62 def __getitem__(self, value: Any) -> Optional[Node]:
64 Find an item in the tree and return its Node. Returns
65 None if the item is not in the tree.
67 >>> t = BinarySearchTree()
79 if self.root is not None:
80 return self._find(value, self.root)
83 def _find(self, value: Any, node: Node) -> Optional[Node]:
85 if value == node.value:
87 elif value < node.value and node.left is not None:
88 return self._find(value, node.left)
89 elif value > node.value and node.right is not None:
90 return self._find(value, node.right)
94 self, current: Optional[Node], target: Node
95 ) -> List[Optional[Node]]:
98 ret: List[Optional[Node]] = [current]
99 if target.value == current.value:
101 elif target.value < current.value:
102 ret.extend(self._parent_path(current.left, target))
105 assert target.value > current.value
106 ret.extend(self._parent_path(current.right, target))
109 def parent_path(self, node: Node) -> List[Optional[Node]]:
110 """Return a list of nodes representing the path from
111 the tree's root to the node argument. If the node does
112 not exist in the tree for some reason, the last element
113 on the path will be None but the path will indicate the
114 ancestor path of that node were it inserted.
116 >>> t = BinarySearchTree()
134 >>> for x in t.parent_path(n):
142 >>> for x in t.parent_path(n):
143 ... if x is not None:
153 return self._parent_path(self.root, node)
155 def __delitem__(self, value: Any) -> bool:
157 Delete an item from the tree and preserve the BST property.
159 >>> t = BinarySearchTree()
176 >>> for value in t.iterate_inorder():
186 >>> del t[22] # Note: bool result is discarded
188 >>> for value in t.iterate_inorder():
197 >>> t.__delitem__(13)
199 >>> for value in t.iterate_inorder():
207 >>> t.__delitem__(75)
209 >>> for value in t.iterate_inorder():
221 >>> t.__delitem__(99)
225 if self.root is not None:
226 ret = self._delete(value, None, self.root)
234 def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
236 if node.value == value:
237 # Deleting a leaf node
238 if node.left is None and node.right is None:
239 if parent is not None:
240 if parent.left == node:
243 assert parent.right == node
247 # Node only has a right.
248 elif node.left is None:
249 assert node.right is not None
250 if parent is not None:
251 if parent.left == node:
252 parent.left = node.right
254 assert parent.right == node
255 parent.right = node.right
258 # Node only has a left.
259 elif node.right is None:
260 assert node.left is not None
261 if parent is not None:
262 if parent.left == node:
263 parent.left = node.left
265 assert parent.right == node
266 parent.right = node.left
269 # Node has both a left and right.
271 assert node.left is not None and node.right is not None
272 descendent = node.right
273 while descendent.left is not None:
274 descendent = descendent.left
275 node.value = descendent.value
276 return self._delete(node.value, node, node.right)
277 elif value < node.value and node.left is not None:
278 return self._delete(value, node, node.left)
279 elif value > node.value and node.right is not None:
280 return self._delete(value, node, node.right)
285 Returns the count of items in the tree.
287 >>> t = BinarySearchTree()
293 >>> t.__delitem__(50)
309 def __contains__(self, value: Any) -> bool:
311 Returns True if the item is in the tree; False otherwise.
314 return self.__getitem__(value) is not None
316 def _iterate_preorder(self, node: Node):
318 if node.left is not None:
319 yield from self._iterate_preorder(node.left)
320 if node.right is not None:
321 yield from self._iterate_preorder(node.right)
323 def _iterate_inorder(self, node: Node):
324 if node.left is not None:
325 yield from self._iterate_inorder(node.left)
327 if node.right is not None:
328 yield from self._iterate_inorder(node.right)
330 def _iterate_postorder(self, node: Node):
331 if node.left is not None:
332 yield from self._iterate_postorder(node.left)
333 if node.right is not None:
334 yield from self._iterate_postorder(node.right)
337 def iterate_preorder(self):
339 Yield the tree's items in a preorder traversal sequence.
341 >>> t = BinarySearchTree()
349 >>> for value in t.iterate_preorder():
359 if self.root is not None:
360 yield from self._iterate_preorder(self.root)
362 def iterate_inorder(self):
Yield the tree's items in an inorder traversal sequence.
366 >>> t = BinarySearchTree()
383 >>> for value in t.iterate_inorder():
394 if self.root is not None:
395 yield from self._iterate_inorder(self.root)
397 def iterate_postorder(self):
Yield the tree's items in a postorder traversal sequence.
401 >>> t = BinarySearchTree()
409 >>> for value in t.iterate_postorder():
419 if self.root is not None:
420 yield from self._iterate_postorder(self.root)
422 def _iterate_leaves(self, node: Node):
423 if node.left is not None:
424 yield from self._iterate_leaves(node.left)
425 if node.right is not None:
426 yield from self._iterate_leaves(node.right)
427 if node.left is None and node.right is None:
430 def iterate_leaves(self):
432 Iterate only the leaf nodes in the tree.
434 >>> t = BinarySearchTree()
442 >>> for value in t.iterate_leaves():
448 if self.root is not None:
449 yield from self._iterate_leaves(self.root)
451 def _iterate_by_depth(self, node: Node, depth: int):
456 if node.left is not None:
457 yield from self._iterate_by_depth(node.left, depth - 1)
458 if node.right is not None:
459 yield from self._iterate_by_depth(node.right, depth - 1)
461 def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
Iterate only the nodes at the given depth in the tree.
465 >>> t = BinarySearchTree()
473 >>> for value in t.iterate_nodes_by_depth(2):
478 >>> for value in t.iterate_nodes_by_depth(3):
483 if self.root is not None:
484 yield from self._iterate_by_depth(self.root, depth)
486 def get_next_node(self, node: Node) -> Node:
488 Given a tree node, get the next greater node in the tree.
490 >>> t = BinarySearchTree()
508 >>> t.get_next_node(n).value
512 >>> t.get_next_node(n).value
516 if node.right is not None:
518 while x.left is not None:
522 path = self.parent_path(node)
523 assert path[-1] is not None
524 assert path[-1] == node
527 for ancestor in path:
528 assert ancestor is not None
529 if node != ancestor.right:
534 def _depth(self, node: Node, sofar: int) -> int:
535 depth_left = sofar + 1
536 depth_right = sofar + 1
537 if node.left is not None:
538 depth_left = self._depth(node.left, sofar + 1)
539 if node.right is not None:
540 depth_right = self._depth(node.right, sofar + 1)
541 return max(depth_left, depth_right)
545 Returns the max height (depth) of the tree in plies (edge distance
548 >>> t = BinarySearchTree()
570 if self.root is None:
572 return self._depth(self.root, 0)
578 self, padding: str, pointer: str, node: Optional[Node], has_right_sibling: bool
581 viz = f'\n{padding}{pointer}{node.value}'
582 if has_right_sibling:
587 pointer_right = "└──"
588 if node.right is not None:
593 viz += self.repr_traverse(
594 padding, pointer_left, node.left, node.right is not None
596 viz += self.repr_traverse(padding, pointer_right, node.right, False)
602 Draw the tree in ASCII.
604 >>> t = BinarySearchTree()
621 if self.root is None:
624 ret = f'{self.root.value}'
625 pointer_right = "└──"
626 if self.root.right is None:
631 ret += self.repr_traverse(
632 '', pointer_left, self.root.left, self.root.left is not None
634 ret += self.repr_traverse('', pointer_right, self.root.right, False)
638 if __name__ == '__main__':