3 from typing import Any, Generator, List, Optional
7 def __init__(self, value: Any) -> None:
9 Note: value can be anything as long as it is comparable.
10 Check out @functools.total_ordering.
12 self.left: Optional[Node] = None
13 self.right: Optional[Node] = None
17 class BinarySearchTree(object):
23 def get_root(self) -> Optional[Node]:
26 def insert(self, value: Any):
28 Insert something into the tree.
30 >>> t = BinarySearchTree()
37 >>> t.get_root().value
42 self.root = Node(value)
45 self._insert(value, self.root)
47 def _insert(self, value: Any, node: Node):
48 """Insertion helper"""
49 if value < node.value:
50 if node.left is not None:
51 self._insert(value, node.left)
53 node.left = Node(value)
56 if node.right is not None:
57 self._insert(value, node.right)
59 node.right = Node(value)
62 def __getitem__(self, value: Any) -> Optional[Node]:
64 Find an item in the tree and return its Node. Returns
65 None if the item is not in the tree.
67 >>> t = BinarySearchTree()
79 if self.root is not None:
80 return self._find(value, self.root)
83 def _find(self, value: Any, node: Node) -> Optional[Node]:
85 if value == node.value:
87 elif value < node.value and node.left is not None:
88 return self._find(value, node.left)
89 elif value > node.value and node.right is not None:
90 return self._find(value, node.right)
93 def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]:
96 ret: List[Optional[Node]] = [current]
97 if target.value == current.value:
99 elif target.value < current.value:
100 ret.extend(self._parent_path(current.left, target))
103 assert target.value > current.value
104 ret.extend(self._parent_path(current.right, target))
107 def parent_path(self, node: Node) -> List[Optional[Node]]:
108 """Return a list of nodes representing the path from
109 the tree's root to the node argument. If the node does
110 not exist in the tree for some reason, the last element
111 on the path will be None but the path will indicate the
112 ancestor path of that node were it inserted.
114 >>> t = BinarySearchTree()
132 >>> for x in t.parent_path(n):
140 >>> for x in t.parent_path(n):
141 ... if x is not None:
151 return self._parent_path(self.root, node)
153 def __delitem__(self, value: Any) -> bool:
155 Delete an item from the tree and preserve the BST property.
157 >>> t = BinarySearchTree()
174 >>> for value in t.iterate_inorder():
184 >>> del t[22] # Note: bool result is discarded
186 >>> for value in t.iterate_inorder():
195 >>> t.__delitem__(13)
197 >>> for value in t.iterate_inorder():
205 >>> t.__delitem__(75)
207 >>> for value in t.iterate_inorder():
219 >>> t.__delitem__(99)
223 if self.root is not None:
224 ret = self._delete(value, None, self.root)
232 def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
234 if node.value == value:
235 # Deleting a leaf node
236 if node.left is None and node.right is None:
237 if parent is not None:
238 if parent.left == node:
241 assert parent.right == node
245 # Node only has a right.
246 elif node.left is None:
247 assert node.right is not None
248 if parent is not None:
249 if parent.left == node:
250 parent.left = node.right
252 assert parent.right == node
253 parent.right = node.right
256 # Node only has a left.
257 elif node.right is None:
258 assert node.left is not None
259 if parent is not None:
260 if parent.left == node:
261 parent.left = node.left
263 assert parent.right == node
264 parent.right = node.left
267 # Node has both a left and right.
269 assert node.left is not None and node.right is not None
270 descendent = node.right
271 while descendent.left is not None:
272 descendent = descendent.left
273 node.value = descendent.value
274 return self._delete(node.value, node, node.right)
275 elif value < node.value and node.left is not None:
276 return self._delete(value, node, node.left)
277 elif value > node.value and node.right is not None:
278 return self._delete(value, node, node.right)
283 Returns the count of items in the tree.
285 >>> t = BinarySearchTree()
291 >>> t.__delitem__(50)
307 def __contains__(self, value: Any) -> bool:
309 Returns True if the item is in the tree; False otherwise.
312 return self.__getitem__(value) is not None
314 def _iterate_preorder(self, node: Node):
316 if node.left is not None:
317 yield from self._iterate_preorder(node.left)
318 if node.right is not None:
319 yield from self._iterate_preorder(node.right)
321 def _iterate_inorder(self, node: Node):
322 if node.left is not None:
323 yield from self._iterate_inorder(node.left)
325 if node.right is not None:
326 yield from self._iterate_inorder(node.right)
328 def _iterate_postorder(self, node: Node):
329 if node.left is not None:
330 yield from self._iterate_postorder(node.left)
331 if node.right is not None:
332 yield from self._iterate_postorder(node.right)
335 def iterate_preorder(self):
337 Yield the tree's items in a preorder traversal sequence.
339 >>> t = BinarySearchTree()
347 >>> for value in t.iterate_preorder():
357 if self.root is not None:
358 yield from self._iterate_preorder(self.root)
360 def iterate_inorder(self):
362 Yield the tree's items in a preorder traversal sequence.
364 >>> t = BinarySearchTree()
381 >>> for value in t.iterate_inorder():
392 if self.root is not None:
393 yield from self._iterate_inorder(self.root)
395 def iterate_postorder(self):
397 Yield the tree's items in a preorder traversal sequence.
399 >>> t = BinarySearchTree()
407 >>> for value in t.iterate_postorder():
417 if self.root is not None:
418 yield from self._iterate_postorder(self.root)
420 def _iterate_leaves(self, node: Node):
421 if node.left is not None:
422 yield from self._iterate_leaves(node.left)
423 if node.right is not None:
424 yield from self._iterate_leaves(node.right)
425 if node.left is None and node.right is None:
428 def iterate_leaves(self):
430 Iterate only the leaf nodes in the tree.
432 >>> t = BinarySearchTree()
440 >>> for value in t.iterate_leaves():
446 if self.root is not None:
447 yield from self._iterate_leaves(self.root)
449 def _iterate_by_depth(self, node: Node, depth: int):
454 if node.left is not None:
455 yield from self._iterate_by_depth(node.left, depth - 1)
456 if node.right is not None:
457 yield from self._iterate_by_depth(node.right, depth - 1)
459 def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
461 Iterate only the leaf nodes in the tree.
463 >>> t = BinarySearchTree()
471 >>> for value in t.iterate_nodes_by_depth(2):
476 >>> for value in t.iterate_nodes_by_depth(3):
481 if self.root is not None:
482 yield from self._iterate_by_depth(self.root, depth)
484 def get_next_node(self, node: Node) -> Node:
486 Given a tree node, get the next greater node in the tree.
488 >>> t = BinarySearchTree()
506 >>> t.get_next_node(n).value
510 >>> t.get_next_node(n).value
514 if node.right is not None:
516 while x.left is not None:
520 path = self.parent_path(node)
521 assert path[-1] is not None
522 assert path[-1] == node
525 for ancestor in path:
526 assert ancestor is not None
527 if node != ancestor.right:
532 def _depth(self, node: Node, sofar: int) -> int:
533 depth_left = sofar + 1
534 depth_right = sofar + 1
535 if node.left is not None:
536 depth_left = self._depth(node.left, sofar + 1)
537 if node.right is not None:
538 depth_right = self._depth(node.right, sofar + 1)
539 return max(depth_left, depth_right)
543 Returns the max height (depth) of the tree in plies (edge distance
546 >>> t = BinarySearchTree()
568 if self.root is None:
570 return self._depth(self.root, 0)
579 node: Optional[Node],
580 has_right_sibling: bool,
583 viz = f'\n{padding}{pointer}{node.value}'
584 if has_right_sibling:
589 pointer_right = "└──"
590 if node.right is not None:
595 viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
596 viz += self.repr_traverse(padding, pointer_right, node.right, False)
602 Draw the tree in ASCII.
604 >>> t = BinarySearchTree()
621 if self.root is None:
624 ret = f'{self.root.value}'
625 pointer_right = "└──"
626 if self.root.right is None:
631 ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
632 ret += self.repr_traverse('', pointer_right, self.root.right, False)
636 if __name__ == '__main__':